Search in sources :

Example 1 with ContinuousSplitCalculator

Use of org.apache.ignite.ml.trees.ContinuousSplitCalculator in the Apache Ignite project.

From class SplitDataGenerator, method testByGen.

/**
 * Trains a {@link ColumnDecisionTreeTrainer} (max depth {@code 3}) on points produced by {@code points(...)}
 * and checks that the resulting model reproduces the label of one sample point from every region.
 *
 * <p>The generated points are shuffled with {@code rnd} before being written row by row into a
 * {@link SparseDistributedMatrix} of {@code totalPts} rows and {@code featCnt + 1} columns, so row order
 * does not follow region order. Points are also grouped by region key ({@code bt.get1()}) so that one
 * representative per region can be checked after training.
 *
 * @param totalPts Number of points to generate (also the number of matrix rows).
 * @param calc Factory producing the continuous split calculator passed to the trainer.
 * @param catImpCalc Factory producing the categorical impurity calculator passed to the trainer.
 * @param regCalc Function from {@code DoubleStream} to {@code Double} passed to the trainer as its region
 *      calculator.
 * @param ignite Ignite instance the trainer runs on.
 * @param <D> Type of continuous region information used by {@code calc}.
 */
<D extends ContinuousRegionInfo> void testByGen(int totalPts,
    IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc,
    IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc,
    IgniteFunction<DoubleStream, Double> regCalc, Ignite ignite) {
    List<IgniteBiTuple<Integer, V>> lst = points(totalPts, (i, rn) -> i).collect(Collectors.toList());

    Collections.shuffle(lst, rnd);

    SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1,
        StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);

    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();

    int row = 0;

    for (IgniteBiTuple<Integer, V> bt : lst) {
        // computeIfAbsent allocates a list only the first time a region key is seen.
        byRegion.computeIfAbsent(bt.get1(), k -> new LinkedList<>())
            .add(asLabeledVector(bt.get2().getStorage().data()));

        m.setRow(row, bt.get2().getStorage().data());

        row++;
    }

    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);

    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catFeaturesInfo));

    // Tolerance for comparing a prediction with a label. Region values come from regCalc, and
    // NOTE(review): its arithmetic (e.g. averaging) may not be bit-exact, so exact == could be flaky.
    final double eps = 1e-9;

    // Check one representative per region. Without this check the predictions would be computed
    // and discarded, and the helper would verify nothing about the trained model.
    for (List<LabeledVectorDouble> regionPts : byRegion.values()) {
        LabeledVectorDouble sample = regionPts.get(0);

        double pred = mdl.apply(sample.features());
        double lbl = sample.doubleLabel();

        assert Math.abs(pred - lbl) < eps : "Wrong prediction [pred=" + pred + ", label=" + lbl + ']';
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) BiFunction(java.util.function.BiFunction) ColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput) HashMap(java.util.HashMap) Random(java.util.Random) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) Function(java.util.function.Function) Supplier(java.util.function.Supplier) Vector(org.apache.ignite.ml.math.Vector) Map(java.util.Map) LinkedList(java.util.LinkedList) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) Ignite(org.apache.ignite.Ignite) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Stream(java.util.stream.Stream) MathIllegalArgumentException(org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException) Utils(org.apache.ignite.ml.util.Utils) ContinuousSplitCalculator(org.apache.ignite.ml.trees.ContinuousSplitCalculator) BitSet(java.util.BitSet) StorageConstants(org.apache.ignite.ml.math.StorageConstants) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) HashMap(java.util.HashMap) 
MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) LinkedList(java.util.LinkedList) List(java.util.List) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Aggregations

Serializable (java.io.Serializable): 1; Arrays (java.util.Arrays): 1; BitSet (java.util.BitSet): 1; Collections (java.util.Collections): 1; HashMap (java.util.HashMap): 1; LinkedList (java.util.LinkedList): 1; List (java.util.List): 1; Map (java.util.Map): 1; Random (java.util.Random): 1; BiFunction (java.util.function.BiFunction): 1; Function (java.util.function.Function): 1; Supplier (java.util.function.Supplier): 1; Collectors (java.util.stream.Collectors): 1; DoubleStream (java.util.stream.DoubleStream): 1; IntStream (java.util.stream.IntStream): 1; Stream (java.util.stream.Stream): 1; Ignite (org.apache.ignite.Ignite): 1; IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple): 1; StorageConstants (org.apache.ignite.ml.math.StorageConstants): 1; Vector (org.apache.ignite.ml.math.Vector): 1