Use of org.apache.ignite.ml.trees.CategoricalRegionInfo in the Apache Ignite project.
Class CategoricalFeatureProcessor, method createInitialRegion.
/**
 * {@inheritDoc}
 *
 * <p>The initial region has depth {@code 0}, keeps the given sample indexes as-is and
 * carries a category set with bits {@code [0, catsCnt)} switched on. Its impurity is
 * obtained by applying {@code calc} to the whole {@code labels} array. The {@code values}
 * argument is not read by this implementation.
 */
@Override
public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] sampleIndexes, double[] values, double[] labels) {
    // Impurity of the root region is measured over every label.
    Double rootImpurity = calc.apply(Arrays.stream(labels));

    // Root region admits every category index in [0, catsCnt).
    BitSet allCats = new BitSet();
    allCats.set(0, catsCnt);

    CategoricalRegionInfo rootInfo = new CategoricalRegionInfo(rootImpurity, allCats);

    return new RegionProjection<>(sampleIndexes, rootInfo, 0);
}
Use of org.apache.ignite.ml.trees.CategoricalRegionInfo in the Apache Ignite project.
Class CategoricalFeatureProcessor, method split.
/**
 * Evaluates a single candidate split of a region by category membership.
 *
 * <p>A sample goes to the left part when the bit of {@code leftCats} at position
 * {@code mapping.get((int) values[sample])} is set, and to the right part otherwise.
 * Impurity of each part is computed by applying {@code calc} to the labels of its samples,
 * and the information gain is
 * {@code impurity - leftSize / totalSize * leftImpurity - rightSize / totalSize * rightImpurity}.
 *
 * @param leftCats Bit set selecting which mapped categories go to the left part.
 * @param intervalIdx Index passed through unchanged to the resulting split info.
 * @param mapping Lookup from a feature value (cast to {@code int}) to a bit index in {@code leftCats}.
 * @param sampleIndexes Indexes of the samples belonging to the region being split.
 * @param values Feature values, addressed by sample index.
 * @param labels Labels, addressed by sample index.
 * @param impurity Impurity of the region before the split.
 * @return Split info holding both part impurities, {@code leftCats} and the information gain.
 */
private SplitInfo<CategoricalRegionInfo> split(BitSet leftCats, int intervalIdx, Map<Integer, Integer> mapping, Integer[] sampleIndexes, double[] values, double[] labels, double impurity) {
    Map<Boolean, List<Integer>> parts = Arrays.stream(sampleIndexes)
        .collect(Collectors.partitioningBy(idx -> {
            int rawCat = (int) values[idx];

            return leftCats.get(mapping.get(rawCat));
        }));

    List<Integer> leftSamples = parts.get(true);
    List<Integer> rightSamples = parts.get(false);

    double lImpurity = calc.apply(leftSamples.stream().mapToDouble(i -> labels[i]));
    double rImpurity = calc.apply(rightSamples.stream().mapToDouble(i -> labels[i]));

    int lCnt = leftSamples.size();
    int rCnt = rightSamples.size();
    int total = lCnt + rCnt;

    // Each part contributes its impurity weighted by its share of the samples.
    double lWeight = (double) lCnt / total;
    double rWeight = (double) rCnt / total;
    double gain = impurity - lWeight * lImpurity - rWeight * rImpurity;

    // Result of this call will be sent back to trainer node, we do not need vectors inside of sent data:
    // region infos are created with null category sets, cats can be computed on the last step.
    CategoricalRegionInfo lInfo = new CategoricalRegionInfo(lImpurity, null);
    CategoricalRegionInfo rInfo = new CategoricalRegionInfo(rImpurity, null);

    CategoricalSplitInfo<CategoricalRegionInfo> splitInfo = new CategoricalSplitInfo<>(intervalIdx, lInfo, rInfo, leftCats);

    splitInfo.setInfoGain(gain);

    return splitInfo;
}
Use of org.apache.ignite.ml.trees.CategoricalRegionInfo in the Apache Ignite project.
Class CategoricalFeatureProcessor, method performSplitGeneric.
/**
 * {@inheritDoc}
 *
 * <p>The left part size is the cardinality of {@code bs}; the right part gets the remaining
 * samples of {@code reg}. The samples are distributed with {@code splitByBitSet}, each part's
 * category set is recomputed from its own samples via {@code calculateCats}, and both child
 * projections are created one level deeper than {@code reg}. Child impurities are taken from
 * {@code leftData} and {@code rightData}.
 */
@Override
public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, double[] values, RegionProjection<CategoricalRegionInfo> reg, RegionInfo leftData, RegionInfo rightData) {
    int childDepth = reg.depth() + 1;

    int leftCnt = bs.cardinality();
    int rightCnt = reg.sampleIndexes().length - leftCnt;

    IgniteBiTuple<Integer[], Integer[]> parts = splitByBitSet(leftCnt, rightCnt, reg.sampleIndexes(), bs);

    Integer[] leftSamples = parts.get1();
    Integer[] rightSamples = parts.get2();

    CategoricalRegionInfo leftInfo = new CategoricalRegionInfo(leftData.impurity(), calculateCats(leftSamples, values));

    // TODO: IGNITE-5892 Check how it will work with sparse data.
    CategoricalRegionInfo rightInfo = new CategoricalRegionInfo(rightData.impurity(), calculateCats(rightSamples, values));

    RegionProjection<CategoricalRegionInfo> leftPrj = new RegionProjection<>(leftSamples, leftInfo, childDepth);
    RegionProjection<CategoricalRegionInfo> rightPrj = new RegionProjection<>(rightSamples, rightInfo, childDepth);

    return new IgniteBiTuple<>(leftPrj, rightPrj);
}
Aggregations