Search in sources:

Example 1 with CategoricalRegionInfo

Use of org.apache.ignite.ml.trees.CategoricalRegionInfo in the Apache Ignite project.

Class CategoricalFeatureProcessor, method createInitialRegion.

/**
 * Creates the root region projection for this categorical feature.
 * <p>
 * At the root every category is still reachable, so the category set has all {@code catsCnt} bits raised.
 * The region impurity is computed by {@code calc} over every entry of {@code labels}.
 *
 * @param sampleIndexes Indexes of the samples that form the region.
 * @param values Feature values (not used by this method).
 * @param labels Labels whose impurity becomes the region impurity.
 * @return Region projection over {@code sampleIndexes} with depth {@code 0}.
 */
@Override
public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] sampleIndexes, double[] values, double[] labels) {
    Double rootImpurity = calc.apply(Arrays.stream(labels));

    // Raise bits [0, catsCnt): no category has been excluded yet.
    BitSet allCats = new BitSet();
    allCats.set(0, catsCnt);

    CategoricalRegionInfo rootInfo = new CategoricalRegionInfo(rootImpurity, allCats);
    return new RegionProjection<>(sampleIndexes, rootInfo, 0);
}
Also used: CategoricalRegionInfo (org.apache.ignite.ml.trees.CategoricalRegionInfo), RegionProjection (org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection), FeatureVectorProcessorUtils.splitByBitSet (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet), BitSet (java.util.BitSet), SparseBitSet (com.zaxxer.sparsebits.SparseBitSet)

Example 2 with CategoricalRegionInfo

Use of org.apache.ignite.ml.trees.CategoricalRegionInfo in the Apache Ignite project.

Class CategoricalFeatureProcessor, method split.

/**
 * Evaluates one candidate categorical split of a region.
 * <p>
 * Each sample's feature value is cast to {@code int} and remapped through {@code mapping}. If the resulting
 * index is set in {@code leftCats`, the sample goes to the left child; otherwise it goes to the right child.
 * The impurity of each child is computed with {@code calc} over the labels of the samples it receives.
 * The information gain is the parent impurity minus the size-weighted child impurities.
 *
 * @param leftCats Category indexes (after remapping) routed to the left child.
 * @param intervalIdx Index passed through to the resulting split info.
 * @param mapping Mapping from the {@code int}-cast feature value to a bit index in {@code leftCats}.
 * @param sampleIndexes Indexes of the samples in the region.
 * @param values Feature values, indexed by sample index.
 * @param labels Labels, indexed by sample index.
 * @param impurity Impurity of the region before the split.
 * @return Split info holding both child impurities, {@code leftCats}, and the computed information gain.
 *     NOTE(review): if {@code sampleIndexes} is empty, the weights are 0/0 and the gain is NaN.
 */
private SplitInfo<CategoricalRegionInfo> split(BitSet leftCats, int intervalIdx, Map<Integer, Integer> mapping, Integer[] sampleIndexes, double[] values, double[] labels, double impurity) {
    Map<Boolean, List<Integer>> partition = Arrays.stream(sampleIndexes)
        .collect(Collectors.partitioningBy(idx -> {
            int cat = (int) values[idx];
            return leftCats.get(mapping.get(cat));
        }));

    List<Integer> lhs = partition.get(Boolean.TRUE);
    List<Integer> rhs = partition.get(Boolean.FALSE);

    double lhsImpurity = calc.apply(lhs.stream().mapToDouble(i -> labels[i]));
    double rhsImpurity = calc.apply(rhs.stream().mapToDouble(i -> labels[i]));

    int lhsCnt = lhs.size();
    int rhsCnt = rhs.size();
    int total = lhsCnt + rhsCnt;

    double lhsWeight = (double) lhsCnt / total;
    double rhsWeight = (double) rhsCnt / total;
    double gain = impurity - lhsWeight * lhsImpurity - rhsWeight * rhsImpurity;

    // The result is sent back to the trainer node, so child infos carry no category sets here;
    // the categories can be computed on the last step.
    CategoricalSplitInfo<CategoricalRegionInfo> res = new CategoricalSplitInfo<>(intervalIdx,
        new CategoricalRegionInfo(lhsImpurity, null),
        new CategoricalRegionInfo(rhsImpurity, null),
        leftCats);
    res.setInfoGain(gain);

    return res;
}
Also used: Arrays (java.util.Arrays), RegionProjection (org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection), Iterator (java.util.Iterator), IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction), HashMap (java.util.HashMap), FeatureVectorProcessorUtils.splitByBitSet (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet), Collectors (java.util.stream.Collectors), DoubleStream (java.util.stream.DoubleStream), IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple), List (java.util.List), RegionInfo (org.apache.ignite.ml.trees.RegionInfo), Stream (java.util.stream.Stream), CategoricalSplitInfo (org.apache.ignite.ml.trees.CategoricalSplitInfo), Map (java.util.Map), StreamSupport (java.util.stream.StreamSupport), BitSet (java.util.BitSet), Comparator (java.util.Comparator), CategoricalRegionInfo (org.apache.ignite.ml.trees.CategoricalRegionInfo), ColumnDecisionTreeTrainer (org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer), SparseBitSet (com.zaxxer.sparsebits.SparseBitSet)

Example 3 with CategoricalRegionInfo

Use of org.apache.ignite.ml.trees.CategoricalRegionInfo in the Apache Ignite project.

Class CategoricalFeatureProcessor, method performSplitGeneric.

/**
 * Splits a categorical region into a left and a right child projection.
 * <p>
 * The region's samples are partitioned by {@code splitByBitSet} according to {@code bs}. The cardinality of
 * {@code bs} is taken as the left child's size and the remainder as the right child's size. Each child gets:
 * <ul>
 *     <li>a category set recomputed (via {@code calculateCats}) from the samples it receives;</li>
 *     <li>the impurity reported by {@code leftData} or {@code rightData};</li>
 *     <li>a depth one greater than the parent region.</li>
 * </ul>
 *
 * @param bs Bit set that selects the samples of the left child.
 * @param values Feature values used to recompute the category sets.
 * @param reg Parent region projection.
 * @param leftData Region info whose impurity is assigned to the left child.
 * @param rightData Region info whose impurity is assigned to the right child.
 * @return Tuple of (left projection, right projection).
 */
@Override
public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, double[] values, RegionProjection<CategoricalRegionInfo> reg, RegionInfo leftData, RegionInfo rightData) {
    int childDepth = reg.depth() + 1;

    int leftCnt = bs.cardinality();
    int rightCnt = reg.sampleIndexes().length - leftCnt;

    IgniteBiTuple<Integer[], Integer[]> parts = splitByBitSet(leftCnt, rightCnt, reg.sampleIndexes(), bs);
    Integer[] leftSamples = parts.get1();
    Integer[] rightSamples = parts.get2();

    BitSet leftCats = calculateCats(leftSamples, values);
    // TODO: IGNITE-5892 Check how it will work with sparse data.
    BitSet rightCats = calculateCats(rightSamples, values);

    RegionProjection<CategoricalRegionInfo> leftPrj = new RegionProjection<>(leftSamples,
        new CategoricalRegionInfo(leftData.impurity(), leftCats), childDepth);
    RegionProjection<CategoricalRegionInfo> rightPrj = new RegionProjection<>(rightSamples,
        new CategoricalRegionInfo(rightData.impurity(), rightCats), childDepth);

    return new IgniteBiTuple<>(leftPrj, rightPrj);
}
Also used: CategoricalRegionInfo (org.apache.ignite.ml.trees.CategoricalRegionInfo), IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple), RegionProjection (org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection), FeatureVectorProcessorUtils.splitByBitSet (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet), BitSet (java.util.BitSet), SparseBitSet (com.zaxxer.sparsebits.SparseBitSet)

Aggregations

SparseBitSet (com.zaxxer.sparsebits.SparseBitSet): 3; BitSet (java.util.BitSet): 3; CategoricalRegionInfo (org.apache.ignite.ml.trees.CategoricalRegionInfo): 3; RegionProjection (org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection): 3; FeatureVectorProcessorUtils.splitByBitSet (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet): 3; IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple): 2; Arrays (java.util.Arrays): 1; Comparator (java.util.Comparator): 1; HashMap (java.util.HashMap): 1; Iterator (java.util.Iterator): 1; List (java.util.List): 1; Map (java.util.Map): 1; Collectors (java.util.stream.Collectors): 1; DoubleStream (java.util.stream.DoubleStream): 1; Stream (java.util.stream.Stream): 1; StreamSupport (java.util.stream.StreamSupport): 1; IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction): 1; CategoricalSplitInfo (org.apache.ignite.ml.trees.CategoricalSplitInfo): 1; RegionInfo (org.apache.ignite.ml.trees.RegionInfo): 1; ColumnDecisionTreeTrainer (org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer): 1