Search in sources:

Example 1 with CategoricalSplitInfo

Use of org.apache.ignite.ml.trees.CategoricalSplitInfo in the Apache Ignite project.

In class CategoricalFeatureProcessor, method split:

/**
 * Evaluates one candidate categorical split and reports its information gain.
 *
 * <p>Each sample is routed to the left side when the bit for its category is set in {@code leftCats}.
 * The category's bit position is looked up through {@code mapping} using the sample's value cast to
 * {@code int}. Otherwise the sample is routed to the right side. The impurity of each side is computed
 * by applying {@code calc} to that side's labels. The gain is the parent {@code impurity} minus the
 * size-weighted impurities of both sides.
 *
 * <p>Every {@code (int) values[i]} must be a key of {@code mapping}. A missing key yields {@code null},
 * which throws a {@link NullPointerException} when unboxed. If {@code sampleIndexes} is empty, the
 * weights are {@code 0.0 / 0} and the resulting gain is {@code NaN}.
 *
 * @param leftCats Bits marking which (mapped) categories go to the left side.
 * @param intervalIdx Passed through unchanged to the resulting {@link CategoricalSplitInfo}.
 * @param mapping Maps a category value (as {@code int}) to its bit position in {@code leftCats}.
 * @param sampleIndexes Indexes into {@code values} and {@code labels} of the samples being split.
 * @param values Feature values, indexed by sample index.
 * @param labels Labels, indexed by sample index.
 * @param impurity Impurity of the region before the split.
 * @return Split description carrying the per-side impurities and the computed information gain.
 */
private SplitInfo<CategoricalRegionInfo> split(BitSet leftCats, int intervalIdx, Map<Integer, Integer> mapping, Integer[] sampleIndexes, double[] values, double[] labels, double impurity) {
    // partitioningBy always yields both the TRUE (left) and FALSE (right) keys, even if one side is empty.
    Map<Boolean, List<Integer>> sides = Arrays.stream(sampleIndexes)
        .collect(Collectors.partitioningBy(idx -> leftCats.get(mapping.get((int)values[idx]))));

    List<Integer> goLeft = sides.get(Boolean.TRUE);
    List<Integer> goRight = sides.get(Boolean.FALSE);

    // Left impurity is evaluated before right impurity, matching the original call order on calc.
    double impLeft = calc.apply(goLeft.stream().mapToDouble(idx -> labels[idx]));
    double impRight = calc.apply(goRight.stream().mapToDouble(idx -> labels[idx]));

    int cntLeft = goLeft.size();
    int cntRight = goRight.size();
    int cntAll = cntLeft + cntRight;

    double weightLeft = (double)cntLeft / cntAll;
    double weightRight = (double)cntRight / cntAll;

    // The result is sent back to the trainer node, so the region infos are built without vectors
    // (the null second argument). The original notes that the categories can be computed on the last step.
    CategoricalRegionInfo leftRegion = new CategoricalRegionInfo(impLeft, null);
    CategoricalRegionInfo rightRegion = new CategoricalRegionInfo(impRight, null);

    CategoricalSplitInfo<CategoricalRegionInfo> info =
        new CategoricalSplitInfo<>(intervalIdx, leftRegion, rightRegion, leftCats);

    info.setInfoGain(impurity - weightLeft * impLeft - weightRight * impRight);

    return info;
}
Also used: Arrays(java.util.Arrays), RegionProjection(org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection), Iterator(java.util.Iterator), IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction), HashMap(java.util.HashMap), FeatureVectorProcessorUtils.splitByBitSet(org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet), Collectors(java.util.stream.Collectors), DoubleStream(java.util.stream.DoubleStream), IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple), List(java.util.List), RegionInfo(org.apache.ignite.ml.trees.RegionInfo), Stream(java.util.stream.Stream), CategoricalSplitInfo(org.apache.ignite.ml.trees.CategoricalSplitInfo), Map(java.util.Map), StreamSupport(java.util.stream.StreamSupport), BitSet(java.util.BitSet), Comparator(java.util.Comparator), CategoricalRegionInfo(org.apache.ignite.ml.trees.CategoricalRegionInfo), ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer), SparseBitSet(com.zaxxer.sparsebits.SparseBitSet)

Aggregations

SparseBitSet (com.zaxxer.sparsebits.SparseBitSet): 1, Arrays (java.util.Arrays): 1, BitSet (java.util.BitSet): 1, Comparator (java.util.Comparator): 1, HashMap (java.util.HashMap): 1, Iterator (java.util.Iterator): 1, List (java.util.List): 1, Map (java.util.Map): 1, Collectors (java.util.stream.Collectors): 1, DoubleStream (java.util.stream.DoubleStream): 1, Stream (java.util.stream.Stream): 1, StreamSupport (java.util.stream.StreamSupport): 1, IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple): 1, IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction): 1, CategoricalRegionInfo (org.apache.ignite.ml.trees.CategoricalRegionInfo): 1, CategoricalSplitInfo (org.apache.ignite.ml.trees.CategoricalSplitInfo): 1, RegionInfo (org.apache.ignite.ml.trees.RegionInfo): 1, ColumnDecisionTreeTrainer (org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer): 1, RegionProjection (org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection): 1, FeatureVectorProcessorUtils.splitByBitSet (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet): 1