Use of org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection in the Apache Ignite project.
Class ProjectionsCache, method projectionsOfRegions.
/**
 * Looks up the projections of the requested regions for one feature and returns those whose depth
 * is still below the given maximum, keyed by region index.
 * <p>
 * Projections are read from the cache in blocks of {@code blockSize} entries using a local peek; a block
 * is fetched again only when the next region index falls into a different block than the previous one.
 *
 * @param featureIdx Feature index.
 * @param maxDepth Max depth of decision tree; regions with depth not less than this value are skipped.
 * @param regionIndexes Indexes of the regions whose projections are requested.
 * @param blockSize Size of regions block.
 * @param affinity Affinity function; applied to the feature index to obtain the affinity part of the cache key.
 * @param trainingUUID UUID of training.
 * @param ignite Ignite instance.
 * @return Region projections in the form of map (regionIndex -> regionProjection).
 * @throws IllegalStateException If no block is available for one of the requested region indexes.
 */
public static Map<Integer, RegionProjection> projectionsOfRegions(int featureIdx, int maxDepth, IntStream regionIndexes, int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID, Ignite ignite) {
    Map<Integer, RegionProjection> res = new HashMap<>();
    IgniteCache<RegionKey, List<RegionProjection>> projCache = getOrCreate(ignite);
    PrimitiveIterator.OfInt idxIter = regionIndexes.iterator();

    // State of the most recently peeked block; -1 means that nothing has been loaded yet.
    int loadedBlockIdx = -1;
    List<RegionProjection> loadedBlock = null;

    Object affKey = affinity.apply(featureIdx);

    while (idxIter.hasNext()) {
        int regIdx = idxIter.nextInt();
        int neededBlockIdx = regIdx / blockSize;

        // Peek the cache only when we step into a block different from the one already at hand.
        if (neededBlockIdx != loadedBlockIdx) {
            loadedBlock = projCache.localPeek(key(featureIdx, neededBlockIdx, affKey, trainingUUID));
            loadedBlockIdx = neededBlockIdx;
        }

        // Checked on every iteration (not only right after a peek) so that a missing block is always
        // reported with the index that hit it.
        if (loadedBlock == null)
            throw new IllegalStateException("Unexpected null block at index " + regIdx);

        RegionProjection candidate = loadedBlock.get(regIdx % blockSize);

        if (candidate.depth() >= maxDepth)
            continue;

        res.put(regIdx, candidate);
    }

    return res;
}
Use of org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection in the Apache Ignite project.
Class CategoricalFeatureProcessor, method createInitialRegion.
/**
 * {@inheritDoc}
 */
@Override
public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] sampleIndexes, double[] values, double[] labels) {
    // Category set of the initial region: bits [0, catsCnt) are all switched on.
    BitSet allCats = new BitSet();
    allCats.set(0, catsCnt);

    // Impurity of the initial region is calculated over the whole label array.
    Double rootImpurity = calc.apply(Arrays.stream(labels));
    CategoricalRegionInfo rootInfo = new CategoricalRegionInfo(rootImpurity, allCats);

    return new RegionProjection<>(sampleIndexes, rootInfo, 0);
}
Use of org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection in the Apache Ignite project.
Class CategoricalFeatureProcessor, method performSplitGeneric.
/**
 * {@inheritDoc}
 */
@Override
public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, double[] values, RegionProjection<CategoricalRegionInfo> reg, RegionInfo leftData, RegionInfo rightData) {
    // Both children sit one level below the region being split.
    int childDepth = reg.depth() + 1;

    // The bit set cardinality is taken as the left child size; the right child gets the remaining samples.
    int leftCnt = bs.cardinality();
    int rightCnt = reg.sampleIndexes().length - leftCnt;

    IgniteBiTuple<Integer[], Integer[]> parts = splitByBitSet(leftCnt, rightCnt, reg.sampleIndexes(), bs);
    Integer[] leftSamples = parts.get1();
    Integer[] rightSamples = parts.get2();

    BitSet leftCats = calculateCats(leftSamples, values);
    CategoricalRegionInfo leftInfo = new CategoricalRegionInfo(leftData.impurity(), leftCats);

    // TODO: IGNITE-5892 Check how it will work with sparse data.
    BitSet rightCats = calculateCats(rightSamples, values);
    CategoricalRegionInfo rightInfo = new CategoricalRegionInfo(rightData.impurity(), rightCats);

    RegionProjection<CategoricalRegionInfo> leftPrj = new RegionProjection<>(leftSamples, leftInfo, childDepth);
    RegionProjection<CategoricalRegionInfo> rightPrj = new RegionProjection<>(rightSamples, rightInfo, childDepth);

    return new IgniteBiTuple<>(leftPrj, rightPrj);
}
Aggregations