Search in sources :

Example 1 with ColumnDecisionTreeTrainerInput

use of org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput in project ignite by apache.

the class SplitDataGenerator method testByGen.

/**
 */
<D extends ContinuousRegionInfo> void testByGen(int totalPts, IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc, IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc, IgniteFunction<DoubleStream, Double> regCalc, Ignite ignite) {
    List<IgniteBiTuple<Integer, V>> lst = points(totalPts, (i, rn) -> i).collect(Collectors.toList());
    Collections.shuffle(lst, rnd);
    SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();
    int i = 0;
    for (IgniteBiTuple<Integer, V> bt : lst) {
        byRegion.putIfAbsent(bt.get1(), new LinkedList<>());
        byRegion.get(bt.get1()).add(asLabeledVector(bt.get2().getStorage().data()));
        m.setRow(i, bt.get2().getStorage().data());
        i++;
    }
    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);
    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catFeaturesInfo));
    byRegion.keySet().forEach(k -> mdl.apply(byRegion.get(k).get(0).features()));
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) BiFunction(java.util.function.BiFunction) ColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput) HashMap(java.util.HashMap) Random(java.util.Random) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) Function(java.util.function.Function) Supplier(java.util.function.Supplier) Vector(org.apache.ignite.ml.math.Vector) Map(java.util.Map) LinkedList(java.util.LinkedList) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) Ignite(org.apache.ignite.Ignite) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Stream(java.util.stream.Stream) MathIllegalArgumentException(org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException) Utils(org.apache.ignite.ml.util.Utils) ContinuousSplitCalculator(org.apache.ignite.ml.trees.ContinuousSplitCalculator) BitSet(java.util.BitSet) StorageConstants(org.apache.ignite.ml.math.StorageConstants) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) HashMap(java.util.HashMap) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) LinkedList(java.util.LinkedList) List(java.util.List) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Example 2 with ColumnDecisionTreeTrainerInput

use of org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput in project ignite by apache.

the class ColumnDecisionTreeTrainerTest method testByGen.

/**
 */
private <D extends ContinuousRegionInfo> void testByGen(int totalPts, HashMap<Integer, Integer> catsInfo, SplitDataGenerator<DenseLocalOnHeapVector> gen, IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc, IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc, IgniteFunction<DoubleStream, Double> regCalc, Random rnd) {
    List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen.points(totalPts, (i, rn) -> i).collect(Collectors.toList());
    int featCnt = gen.featuresCnt();
    Collections.shuffle(lst, rnd);
    SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();
    int i = 0;
    for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) {
        byRegion.putIfAbsent(bt.get1(), new LinkedList<>());
        byRegion.get(bt.get1()).add(asLabeledVector(bt.get2().getStorage().data()));
        m.setRow(i, bt.get2().getStorage().data());
        i++;
    }
    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);
    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));
    byRegion.keySet().forEach(k -> {
        LabeledVectorDouble sp = byRegion.get(k).get(0);
        Tracer.showAscii(sp.features());
        X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred=" + mdl.apply(sp.features()) + "]");
        assert mdl.apply(sp.features()) == sp.doubleLabel();
    });
}
Also used : MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) ColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput) HashMap(java.util.HashMap) Random(java.util.Random) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Map(java.util.Map) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) X(org.apache.ignite.internal.util.typedef.X) Tracer(org.apache.ignite.ml.math.Tracer) LinkedList(java.util.LinkedList) StorageConstants(org.apache.ignite.ml.math.StorageConstants) RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) HashMap(java.util.HashMap) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) List(java.util.List) LinkedList(java.util.LinkedList) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Aggregations

Collections (java.util.Collections)2 HashMap (java.util.HashMap)2 LinkedList (java.util.LinkedList)2 List (java.util.List)2 Map (java.util.Map)2 Random (java.util.Random)2 Collectors (java.util.stream.Collectors)2 DoubleStream (java.util.stream.DoubleStream)2 IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple)2 StorageConstants (org.apache.ignite.ml.math.StorageConstants)2 IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction)2 SparseDistributedMatrix (org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix)2 DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)2 LabeledVectorDouble (org.apache.ignite.ml.structures.LabeledVectorDouble)2 DecisionTreeModel (org.apache.ignite.ml.trees.models.DecisionTreeModel)2 ColumnDecisionTreeTrainer (org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)2 ColumnDecisionTreeTrainerInput (org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput)2 MatrixColumnDecisionTreeTrainerInput (org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput)2 Serializable (java.io.Serializable)1 Arrays (java.util.Arrays)1