
Example 1 with RegionProjection

Use of org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection in project ignite by apache.

From class ProjectionsCache, method projectionsOfRegions.

/**
 * Get region projections in the form of map (regionIndex -> regionProjections).
 *
 * @param featureIdx Feature index.
 * @param maxDepth Max depth of decision tree.
 * @param regionIndexes Indexes of regions for which we want to get projections.
 * @param blockSize Size of regions block.
 * @param affinity Affinity function.
 * @param trainingUUID UUID of training.
 * @param ignite Ignite instance.
 * @return Region projections in the form of map (regionIndex -> regionProjections).
 */
public static Map<Integer, RegionProjection> projectionsOfRegions(int featureIdx, int maxDepth, IntStream regionIndexes, int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID, Ignite ignite) {
    HashMap<Integer, RegionProjection> regsForSearch = new HashMap<>();
    IgniteCache<RegionKey, List<RegionProjection>> cache = getOrCreate(ignite);
    PrimitiveIterator.OfInt itr = regionIndexes.iterator();
    int curBlockIdx = -1;
    List<RegionProjection> block = null;
    Object affinityKey = affinity.apply(featureIdx);
    while (itr.hasNext()) {
        int i = itr.nextInt();
        int blockIdx = i / blockSize;
        if (blockIdx != curBlockIdx) {
            block = cache.localPeek(key(featureIdx, blockIdx, affinityKey, trainingUUID));
            curBlockIdx = blockIdx;
        }
        if (block == null)
            throw new IllegalStateException("Unexpected null block at index " + i);
        RegionProjection reg = block.get(i % blockSize);
        if (reg.depth() < maxDepth)
            regsForSearch.put(i, reg);
    }
    return regsForSearch;
}
Also used : PrimitiveIterator(java.util.PrimitiveIterator) HashMap(java.util.HashMap) RegionProjection(org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection) List(java.util.List)
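
The block arithmetic in projectionsOfRegions can be easier to follow without the cache. The following is a minimal, cache-free sketch that assumes a plain Map (block index to list of regions) in place of the IgniteCache and a hypothetical Region class exposing only depth(); RegionBlockLookup, Region, and regionsToSearch are illustrative names, not Ignite API. Region i is read from block i / blockSize at offset i % blockSize, a block is fetched only when the index crosses into a different block, and only regions shallower than maxDepth are returned.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PrimitiveIterator;
import java.util.TreeSet;
import java.util.stream.IntStream;

// Hypothetical, cache-free sketch of the lookup in projectionsOfRegions (not Ignite API).
final class RegionBlockLookup {
    // Hypothetical stand-in for RegionProjection: only the depth is modelled.
    static final class Region {
        private final int depth;

        Region(int depth) {
            this.depth = depth;
        }

        int depth() {
            return depth;
        }
    }

    // A plain Map (block index -> block of regions) replaces the Ignite cache.
    static Map<Integer, Region> regionsToSearch(Map<Integer, List<Region>> blocks, IntStream regionIndexes,
        int blockSize, int maxDepth) {
        Map<Integer, Region> res = new HashMap<>();
        PrimitiveIterator.OfInt itr = regionIndexes.iterator();
        int curBlockIdx = -1;
        List<Region> block = null;

        while (itr.hasNext()) {
            int i = itr.nextInt();
            int blockIdx = i / blockSize;

            // Look up a new block only when the index moves into a different block.
            if (blockIdx != curBlockIdx) {
                block = blocks.get(blockIdx);
                curBlockIdx = blockIdx;
            }

            if (block == null)
                throw new IllegalStateException("Unexpected null block at index " + i);

            // Region i sits at offset i % blockSize inside its block.
            Region reg = block.get(i % blockSize);

            // Keep only regions that may still be split.
            if (reg.depth() < maxDepth)
                res.put(i, reg);
        }

        return res;
    }

    public static void main(String[] args) {
        int blockSize = 2;
        int[] depths = {0, 3, 1, 2, 3};

        // Blocks: 0 -> regions {0, 1}, 1 -> {2, 3}, 2 -> {4}.
        Map<Integer, List<Region>> blocks = new HashMap<>();
        for (int i = 0; i < depths.length; i++)
            blocks.computeIfAbsent(i / blockSize, k -> new ArrayList<>()).add(new Region(depths[i]));

        Map<Integer, Region> res = regionsToSearch(blocks, IntStream.range(0, depths.length), blockSize, 3);
        System.out.println(new TreeSet<>(res.keySet())); // prints [0, 2, 3]
    }
}
```

When the indexes arrive in ascending order, each block is looked up once rather than once per region, which is the point of tracking curBlockIdx.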

Example 2 with RegionProjection

use of org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection in project ignite by apache.

From class CategoricalFeatureProcessor, method createInitialRegion.

/**
 * {@inheritDoc}
 */
@Override
public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] sampleIndexes, double[] values, double[] labels) {
    BitSet set = new BitSet();
    set.set(0, catsCnt);
    Double impurity = calc.apply(Arrays.stream(labels));
    return new RegionProjection<>(sampleIndexes, new CategoricalRegionInfo(impurity, set), 0);
}
Also used : CategoricalRegionInfo(org.apache.ignite.ml.trees.CategoricalRegionInfo) RegionProjection(org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection) FeatureVectorProcessorUtils.splitByBitSet(org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet) BitSet(java.util.BitSet) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet)
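
A self-contained sketch of what the initial region holds, under stated assumptions: CatRegion is a hypothetical stand-in for RegionProjection<CategoricalRegionInfo>, and label variance substitutes for the impurity calculator calc, whose actual measure is configured elsewhere and not shown here. The root region contains every sample, admits all catsCnt categories (bits 0 through catsCnt - 1), and starts at depth 0.

```java
import java.util.Arrays;
import java.util.BitSet;

// Hypothetical sketch of createInitialRegion (not the Ignite implementation).
final class InitialRegionSketch {
    // Hypothetical stand-in for RegionProjection<CategoricalRegionInfo>.
    static final class CatRegion {
        final int[] sampleIndexes;
        final double impurity;
        final BitSet cats;
        final int depth;

        CatRegion(int[] sampleIndexes, double impurity, BitSet cats, int depth) {
            this.sampleIndexes = sampleIndexes;
            this.impurity = impurity;
            this.cats = cats;
            this.depth = depth;
        }
    }

    // Assumed impurity measure: population variance of the labels.
    static double variance(double[] labels) {
        double mean = Arrays.stream(labels).average().orElse(0.0);
        return Arrays.stream(labels).map(l -> (l - mean) * (l - mean)).average().orElse(0.0);
    }

    static CatRegion createInitialRegion(int[] sampleIndexes, double[] labels, int catsCnt) {
        BitSet set = new BitSet();
        // The root region admits every category: bits 0 .. catsCnt - 1 are set.
        set.set(0, catsCnt);
        return new CatRegion(sampleIndexes, variance(labels), set, 0);
    }

    public static void main(String[] args) {
        CatRegion root = createInitialRegion(new int[] {0, 1, 2, 3}, new double[] {1.0, 1.0, 3.0, 3.0}, 3);
        System.out.println(root.cats + " " + root.impurity + " " + root.depth); // prints {0, 1, 2} 1.0 0
    }
}
```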

Example 3 with RegionProjection

use of org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection in project ignite by apache.

From class CategoricalFeatureProcessor, method performSplitGeneric.

/**
 * {@inheritDoc}
 */
@Override
public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, double[] values, RegionProjection<CategoricalRegionInfo> reg, RegionInfo leftData, RegionInfo rightData) {
    int depth = reg.depth();
    int lSize = bs.cardinality();
    int rSize = reg.sampleIndexes().length - lSize;
    IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs);
    BitSet leftCats = calculateCats(lrSamples.get1(), values);
    CategoricalRegionInfo lInfo = new CategoricalRegionInfo(leftData.impurity(), leftCats);
    // TODO: IGNITE-5892 Check how it will work with sparse data.
    BitSet rightCats = calculateCats(lrSamples.get2(), values);
    CategoricalRegionInfo rInfo = new CategoricalRegionInfo(rightData.impurity(), rightCats);
    RegionProjection<CategoricalRegionInfo> rPrj = new RegionProjection<>(lrSamples.get2(), rInfo, depth + 1);
    RegionProjection<CategoricalRegionInfo> lPrj = new RegionProjection<>(lrSamples.get1(), lInfo, depth + 1);
    return new IgniteBiTuple<>(lPrj, rPrj);
}
Also used : CategoricalRegionInfo(org.apache.ignite.ml.trees.CategoricalRegionInfo) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) RegionProjection(org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection) FeatureVectorProcessorUtils.splitByBitSet(org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet) BitSet(java.util.BitSet) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet)
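
The split can be sketched without Ignite or SparseBitSet as follows. This assumes that bit j of the split bit set sends the j-th entry of sampleIndexes to the left child and that values[s] stores the category id of sample s; java.util.BitSet replaces SparseBitSet, and CategoricalSplitSketch, CatRegion, and the helper bodies are hypothetical reconstructions, not the Ignite implementations of splitByBitSet or calculateCats. Impurities are left out because the original takes them from the leftData and rightData arguments.

```java
import java.util.Arrays;
import java.util.BitSet;

// Hypothetical sketch of performSplitGeneric (not the Ignite implementation).
final class CategoricalSplitSketch {
    // Hypothetical region: sample indexes, categories present, depth.
    static final class CatRegion {
        final int[] sampleIndexes;
        final BitSet cats;
        final int depth;

        CatRegion(int[] sampleIndexes, BitSet cats, int depth) {
            this.sampleIndexes = sampleIndexes;
            this.cats = cats;
            this.depth = depth;
        }
    }

    // Assumed semantics: bit j set means sampleIndexes[j] goes to the left child.
    static int[][] splitByBitSet(int[] sampleIndexes, BitSet bs) {
        // Count only bits that refer to existing positions.
        int lSize = bs.get(0, sampleIndexes.length).cardinality();
        int[] left = new int[lSize];
        int[] right = new int[sampleIndexes.length - lSize];
        int l = 0;
        int r = 0;
        for (int j = 0; j < sampleIndexes.length; j++) {
            if (bs.get(j))
                left[l++] = sampleIndexes[j];
            else
                right[r++] = sampleIndexes[j];
        }
        return new int[][] {left, right};
    }

    // Assumed: values[s] holds the category id of sample s.
    static BitSet calculateCats(int[] samples, double[] values) {
        BitSet cats = new BitSet();
        for (int s : samples)
            cats.set((int) values[s]);
        return cats;
    }

    static CatRegion[] performSplit(CatRegion reg, BitSet bs, double[] values) {
        int[][] lr = splitByBitSet(reg.sampleIndexes, bs);
        // Both children are one level deeper than the parent.
        CatRegion left = new CatRegion(lr[0], calculateCats(lr[0], values), reg.depth + 1);
        CatRegion right = new CatRegion(lr[1], calculateCats(lr[1], values), reg.depth + 1);
        return new CatRegion[] {left, right};
    }

    public static void main(String[] args) {
        double[] values = {0, 1, 2, 1, 0}; // category of samples 0..4
        BitSet all = new BitSet();
        all.set(0, 3);
        CatRegion root = new CatRegion(new int[] {0, 1, 2, 3, 4}, all, 0);

        BitSet goLeft = new BitSet();
        goLeft.set(0);
        goLeft.set(4);

        CatRegion[] ch = performSplit(root, goLeft, values);
        System.out.println(Arrays.toString(ch[0].sampleIndexes) + " " + ch[0].cats); // prints [0, 4] {0}
        System.out.println(Arrays.toString(ch[1].sampleIndexes) + " " + ch[1].cats); // prints [1, 2, 3] {1, 2}
    }
}
```

Each child records only the categories that actually occur among its own samples, rather than inheriting the parent's full category set.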

Aggregations

RegionProjection (org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection): 3
SparseBitSet (com.zaxxer.sparsebits.SparseBitSet): 2
BitSet (java.util.BitSet): 2
CategoricalRegionInfo (org.apache.ignite.ml.trees.CategoricalRegionInfo): 2
FeatureVectorProcessorUtils.splitByBitSet (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet): 2
HashMap (java.util.HashMap): 1
List (java.util.List): 1
PrimitiveIterator (java.util.PrimitiveIterator): 1
IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple): 1