Search in sources:

Example 1 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

Class SplitDataGenerator, method testByGen.

/**
 * Generates a data set from this generator's regions, trains a column-based decision tree on it and
 * evaluates the trained model once on a sample point of every region.
 * <p>
 * The generated points are shuffled with {@code rnd} and written row by row into a column-stored
 * {@link SparseDistributedMatrix} with {@code featCnt + 1} columns (presumably the features followed by
 * the label — TODO confirm against {@code asLabeledVector}).
 *
 * @param totalPts Point count passed to {@code points(...)}. NOTE(review): whether {@code points} treats this
 *      as a total or a per-region count is not visible here, so the matrix is sized from the generated list.
 * @param calc Factory of the continuous split calculator.
 * @param catImpCalc Factory of the impurity calculator used for categorical features.
 * @param regCalc Function computing a region's value from the labels that fall into it.
 * @param ignite Ignite instance used by the trainer.
 * @param <D> Type of the continuous region information.
 */
<D extends ContinuousRegionInfo> void testByGen(int totalPts,
    IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc,
    IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc,
    IgniteFunction<DoubleStream, Double> regCalc, Ignite ignite) {
    List<IgniteBiTuple<Integer, V>> lst = points(totalPts, (i, rn) -> i).collect(Collectors.toList());

    Collections.shuffle(lst, rnd);

    // Size the matrix from the points actually generated, so every row index written below is in range.
    SparseDistributedMatrix m = new SparseDistributedMatrix(lst.size(), featCnt + 1,
        StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);

    // Region index -> labeled points generated for that region.
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();

    int row = 0;

    for (IgniteBiTuple<Integer, V> bt : lst) {
        double[] data = bt.get2().getStorage().data();

        byRegion.computeIfAbsent(bt.get1(), k -> new LinkedList<>()).add(asLabeledVector(data));

        m.setRow(row++, data);
    }

    // NOTE(review): the first constructor argument is presumably the maximal tree depth — confirm.
    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);

    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catFeaturesInfo));

    // Smoke check only: the model must be able to evaluate one sample of every region.
    // The predictions are not compared with the labels here.
    byRegion.values().forEach(pts -> mdl.apply(pts.get(0).features()));
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) BiFunction(java.util.function.BiFunction) ColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput) HashMap(java.util.HashMap) Random(java.util.Random) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) Function(java.util.function.Function) Supplier(java.util.function.Supplier) Vector(org.apache.ignite.ml.math.Vector) Map(java.util.Map) LinkedList(java.util.LinkedList) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) Ignite(org.apache.ignite.Ignite) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Stream(java.util.stream.Stream) MathIllegalArgumentException(org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException) Utils(org.apache.ignite.ml.util.Utils) ContinuousSplitCalculator(org.apache.ignite.ml.trees.ContinuousSplitCalculator) BitSet(java.util.BitSet) StorageConstants(org.apache.ignite.ml.math.StorageConstants) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) HashMap(java.util.HashMap) 
MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) LinkedList(java.util.LinkedList) List(java.util.List) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Example 2 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

Class ColumnDecisionTreeTrainerBenchmark, method tstMNISTSparseDistributedMatrix.

/**
 * Run decision tree classifier on MNIST using sparse distributed matrix as a storage for dataset.
 * To run this test rename this method so it starts from 'test'.
 * <p>
 * Loads 30 000 MNIST training vectors into a column-stored sparse distributed matrix, trains a Gini-based
 * column decision tree on them and prints the training time (ms) and the error percentage on 10 000 test
 * vectors. Finally it checks that the split, features, context and projections caches are empty after training.
 *
 * @throws IOException In case of loading MNIST dataset errors.
 */
public void tstMNISTSparseDistributedMatrix() throws IOException {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    int ptsCnt = 30_000;
    int featCnt = 28 * 28;

    Properties props = loadMNISTProperties();

    Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES),
        props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt);
    Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES),
        props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000);

    // One extra column: each vector holds featCnt pixel features followed by its label (see the test mapping below).
    SparseDistributedMatrix m = new SparseDistributedMatrix(ptsCnt, featCnt + 1,
        StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);

    SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage();

    loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(),
        trainingMnistStream.iterator(), featCnt + 1);

    ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10,
        ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);

    X.println("Training started");

    // nanoTime is monotonic, so the measured duration is not affected by wall-clock adjustments.
    long before = System.nanoTime();

    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>()));

    X.println("Training finished in " + ((System.nanoTime() - before) / 1_000_000));

    // Judging by its name and the printed label, this estimator yields an error percentage (not MSE or accuracy).
    IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double>
        errsEstimator = Estimators.errorsPercentage();

    Double errsPercentage = errsEstimator.apply(mdl,
        testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))),
        Function.identity());

    X.println("Errors percentage: " + errsPercentage);

    // Training must leave none of its intermediate caches populated.
    Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size());
    Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size());
    Assert.assertEquals(0, ContextCache.getOrCreate(ignite).size());
    Assert.assertEquals(0, ProjectionsCache.getOrCreate(ignite).size());
}
Also used : CacheAtomicityMode(org.apache.ignite.cache.CacheAtomicityMode) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) IgniteTestResources(org.apache.ignite.testframework.junits.IgniteTestResources) Random(java.util.Random) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) VarianceSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator) Vector(org.apache.ignite.ml.math.Vector) Estimators(org.apache.ignite.ml.estimators.Estimators) Map(java.util.Map) X(org.apache.ignite.internal.util.typedef.X) Level(org.apache.log4j.Level) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) BaseDecisionTreeTest(org.apache.ignite.ml.trees.BaseDecisionTreeTest) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) UUID(java.util.UUID) StreamTransformer(org.apache.ignite.stream.StreamTransformer) Collectors(java.util.stream.Collectors) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Stream(java.util.stream.Stream) SparseMatrixKey(org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) 
RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Model(org.apache.ignite.ml.Model) HashMap(java.util.HashMap) Function(java.util.function.Function) GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) CacheWriteSynchronizationMode(org.apache.ignite.cache.CacheWriteSynchronizationMode) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) MnistUtils(org.apache.ignite.ml.util.MnistUtils) LinkedList(java.util.LinkedList) Properties(java.util.Properties) Iterator(java.util.Iterator) ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) IOException(java.io.IOException) SplitDataGenerator(org.apache.ignite.ml.trees.SplitDataGenerator) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) IgniteDataStreamer(org.apache.ignite.IgniteDataStreamer) Tracer(org.apache.ignite.ml.math.Tracer) StorageConstants(org.apache.ignite.ml.math.StorageConstants) Assert(org.junit.Assert) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) GridCacheProcessor(org.apache.ignite.internal.processors.cache.GridCacheProcessor) InputStream(java.io.InputStream) CacheMode(org.apache.ignite.cache.CacheMode) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) 
HashMap(java.util.HashMap) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) Properties(java.util.Properties) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Function(java.util.function.Function) Random(java.util.Random) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) Model(org.apache.ignite.ml.Model) DoubleStream(java.util.stream.DoubleStream) Stream(java.util.stream.Stream) IntStream(java.util.stream.IntStream) InputStream(java.io.InputStream) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Example 3 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

Class ColumnDecisionTreeTrainerBenchmark, method tstMNISTBiIndexedCache.

/**
 * Run decision tree classifier on MNIST using bi-indexed cache as a storage for dataset.
 * To run this test rename this method so it starts from 'test'.
 * <p>
 * Loads 40 000 MNIST training vectors into a bi-indexed cache, trains a Gini-based column decision tree on them
 * and prints the training time (ms) and the error percentage on 10 000 test vectors. Finally it checks that the
 * split, features, context and projections caches are empty after training.
 *
 * @throws IOException In case of loading MNIST dataset errors.
 */
public void tstMNISTBiIndexedCache() throws IOException {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    int ptsCnt = 40_000;
    int featCnt = 28 * 28;

    Properties props = loadMNISTProperties();

    Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES),
        props.getProperty(PROP_TRAINING_LABELS), new Random(123L), ptsCnt);
    Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES),
        props.getProperty(PROP_TEST_LABELS), new Random(123L), 10_000);

    IgniteCache<BiIndex, Double> cache = createBiIndexedCache();

    // One extra column: each vector holds featCnt pixel features followed by its label (see the test mapping below).
    loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1);

    ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10,
        ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);

    X.println("Training started.");

    // nanoTime is monotonic, so the measured duration is not affected by wall-clock adjustments.
    long before = System.nanoTime();

    DecisionTreeModel mdl = trainer.train(
        new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));

    X.println("Training finished in " + ((System.nanoTime() - before) / 1_000_000));

    // Judging by its name and the printed label, this estimator yields an error percentage (not MSE or accuracy).
    IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double>
        errsEstimator = Estimators.errorsPercentage();

    Double errsPercentage = errsEstimator.apply(mdl,
        testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))),
        Function.identity());

    X.println("Errors percentage: " + errsPercentage);

    // Training must leave none of its intermediate caches populated.
    Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size());
    Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size());
    Assert.assertEquals(0, ContextCache.getOrCreate(ignite).size());
    Assert.assertEquals(0, ProjectionsCache.getOrCreate(ignite).size());
}
Also used : CacheAtomicityMode(org.apache.ignite.cache.CacheAtomicityMode) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) IgniteTestResources(org.apache.ignite.testframework.junits.IgniteTestResources) Random(java.util.Random) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) VarianceSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator) Vector(org.apache.ignite.ml.math.Vector) Estimators(org.apache.ignite.ml.estimators.Estimators) Map(java.util.Map) X(org.apache.ignite.internal.util.typedef.X) Level(org.apache.log4j.Level) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) BaseDecisionTreeTest(org.apache.ignite.ml.trees.BaseDecisionTreeTest) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) UUID(java.util.UUID) StreamTransformer(org.apache.ignite.stream.StreamTransformer) Collectors(java.util.stream.Collectors) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Stream(java.util.stream.Stream) SparseMatrixKey(org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) 
RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Model(org.apache.ignite.ml.Model) HashMap(java.util.HashMap) Function(java.util.function.Function) GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) CacheWriteSynchronizationMode(org.apache.ignite.cache.CacheWriteSynchronizationMode) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) MnistUtils(org.apache.ignite.ml.util.MnistUtils) LinkedList(java.util.LinkedList) Properties(java.util.Properties) Iterator(java.util.Iterator) ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) IOException(java.io.IOException) SplitDataGenerator(org.apache.ignite.ml.trees.SplitDataGenerator) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) IgniteDataStreamer(org.apache.ignite.IgniteDataStreamer) Tracer(org.apache.ignite.ml.math.Tracer) StorageConstants(org.apache.ignite.ml.math.StorageConstants) Assert(org.junit.Assert) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) GridCacheProcessor(org.apache.ignite.internal.processors.cache.GridCacheProcessor) InputStream(java.io.InputStream) CacheMode(org.apache.ignite.cache.CacheMode) HashMap(java.util.HashMap) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) 
BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) Properties(java.util.Properties) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Function(java.util.function.Function) Random(java.util.Random) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) Model(org.apache.ignite.ml.Model) DoubleStream(java.util.stream.DoubleStream) Stream(java.util.stream.Stream) IntStream(java.util.stream.IntStream) InputStream(java.io.InputStream) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Example 4 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

Class ColumnDecisionTreeTrainerBenchmark, method testByGenStreamerLoad.

/**
 * Generates points with {@code gen} and bulk-loads them into a column-stored sparse distributed matrix. It then
 * trains a variance-based column decision tree and checks that the model reproduces the label of one sample
 * point from every region.
 *
 * @param ptsPerReg Point count passed to {@code gen.points(...)} (per region, judging by how it is used).
 * @param catsInfo Categorical features info passed to the trainer input.
 * @param gen Data generator.
 * @param rnd Random used to shuffle the generated points.
 */
private void testByGenStreamerLoad(int ptsPerReg, HashMap<Integer, Integer> catsInfo,
    SplitDataGenerator<DenseLocalOnHeapVector> gen, Random rnd) {
    List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen.points(ptsPerReg, (i, rn) -> i)
        .collect(Collectors.toList());

    int featCnt = gen.featuresCnt();

    Collections.shuffle(lst, rnd);

    // Size the matrix from the generated points: exactly these points are loaded into it below.
    SparseDistributedMatrix m = new SparseDistributedMatrix(lst.size(), featCnt + 1,
        StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);

    // Region value = mean of the labels that fall into the region (0.0 for an empty region).
    IgniteFunction<DoubleStream, Double> regCalc = s -> s.average().orElse(0.0);

    SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage();

    long before = System.nanoTime();

    X.println("Batch loading started...");

    // Load the very points kept in 'lst' rather than generating a second, independent batch. The samples checked
    // below are then part of the training data, and the row count matches the matrix size.
    loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(),
        lst.stream().map(IgniteBiTuple::get2).iterator(), featCnt + 1);

    X.println("Batch loading took " + ((System.nanoTime() - before) / 1_000_000) + " ms.");

    // Region index -> labeled points generated for that region.
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();

    for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) {
        byRegion.computeIfAbsent(bt.get1(), k -> new LinkedList<>())
            .add(asLabeledVector(bt.get2().getStorage().data()));
    }

    ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = new ColumnDecisionTreeTrainer<>(2,
        ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite);

    before = System.nanoTime();

    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));

    X.println("Training took: " + ((System.nanoTime() - before) / 1_000_000) + " ms.");

    byRegion.values().forEach(pts -> {
        LabeledVectorDouble sp = pts.get(0);

        Tracer.showAscii(sp.features());

        // Unbox into primitives so the check compares values; '==' on two Double objects compares references.
        double pred = mdl.apply(sp.features());
        double lbl = sp.doubleLabel();

        X.println("Predicted value and label [pred=" + pred + ", label=" + lbl + "]");

        assert pred == lbl : "Wrong prediction [pred=" + pred + ", label=" + lbl + "]";
    });
}
Also used : CacheAtomicityMode(org.apache.ignite.cache.CacheAtomicityMode) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) IgniteTestResources(org.apache.ignite.testframework.junits.IgniteTestResources) Random(java.util.Random) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) VarianceSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator) Vector(org.apache.ignite.ml.math.Vector) Estimators(org.apache.ignite.ml.estimators.Estimators) Map(java.util.Map) X(org.apache.ignite.internal.util.typedef.X) Level(org.apache.log4j.Level) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) BaseDecisionTreeTest(org.apache.ignite.ml.trees.BaseDecisionTreeTest) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) UUID(java.util.UUID) StreamTransformer(org.apache.ignite.stream.StreamTransformer) Collectors(java.util.stream.Collectors) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Stream(java.util.stream.Stream) SparseMatrixKey(org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) 
RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Model(org.apache.ignite.ml.Model) HashMap(java.util.HashMap) Function(java.util.function.Function) GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) CacheWriteSynchronizationMode(org.apache.ignite.cache.CacheWriteSynchronizationMode) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) MnistUtils(org.apache.ignite.ml.util.MnistUtils) LinkedList(java.util.LinkedList) Properties(java.util.Properties) Iterator(java.util.Iterator) ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) IOException(java.io.IOException) SplitDataGenerator(org.apache.ignite.ml.trees.SplitDataGenerator) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) IgniteDataStreamer(org.apache.ignite.IgniteDataStreamer) Tracer(org.apache.ignite.ml.math.Tracer) StorageConstants(org.apache.ignite.ml.math.StorageConstants) Assert(org.junit.Assert) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) GridCacheProcessor(org.apache.ignite.internal.processors.cache.GridCacheProcessor) InputStream(java.io.InputStream) CacheMode(org.apache.ignite.cache.CacheMode) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) 
HashMap(java.util.HashMap) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LinkedList(java.util.LinkedList) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Example 5 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

Class DecisionTreesExample, method main.

/**
 * Launches example.
 * <p>
 * The example runs in the following steps:
 * <ul>
 *     <li>It parses the command line. Execution is skipped in 'unattended' mode.</li>
 *     <li>It requires the MNIST files to be available via {@code getMNIST}.</li>
 *     <li>It starts an Ignite node from the configured (or default) configuration path.</li>
 *     <li>It loads 60 000 MNIST training vectors into a bi-indexed cache and trains a Gini-based column decision
 *     tree on them.</li>
 *     <li>It prints the error percentage on 10 000 test vectors.</li>
 * </ul>
 *
 * @param args Program arguments.
 * @throws IOException Declared for MNIST loading. NOTE(review): IOExceptions thrown while the node is running are
 *      caught and printed below rather than propagated.
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision trees example started.");

    CommandLineParser parser = new BasicParser();

    // Key -> file name of the MNIST dataset parts, relative to MNIST_DIR.
    Map<String, String> mnistPaths = new HashMap<>();

    mnistPaths.put(MNIST_TRAIN_IMAGES, "train-images-idx3-ubyte");
    mnistPaths.put(MNIST_TRAIN_LABELS, "train-labels-idx1-ubyte");
    mnistPaths.put(MNIST_TEST_IMAGES, "t10k-images-idx3-ubyte");
    mnistPaths.put(MNIST_TEST_LABELS, "t10k-labels-idx1-ubyte");

    String igniteCfgPath;

    try {
        // Parse the command line arguments.
        CommandLine line = parser.parse(buildOptions(), args);

        if (line.hasOption(MLExamplesCommonArgs.UNATTENDED)) {
            System.out.println(">>> Skipped example execution because 'unattended' mode is used.");
            System.out.println(">>> Decision trees example finished.");

            return;
        }

        igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG);
    }
    catch (ParseException e) {
        e.printStackTrace();

        return;
    }

    if (!getMNIST(mnistPaths.values())) {
        System.out.println(">>> You should have MNIST dataset in " + MNIST_DIR + " to run this example.");

        return;
    }

    // Resolves the MNIST file registered under the given key, failing with the offending path if it is missing.
    Function<String, String> resolveMnist = key -> {
        String relPath = MNIST_DIR + "/" + mnistPaths.get(key);

        return Objects.requireNonNull(IgniteUtils.resolveIgnitePath(relPath),
            "Failed to resolve MNIST file: " + relPath).getPath();
    };

    String trainingImagesPath = resolveMnist.apply(MNIST_TRAIN_IMAGES);
    String trainingLabelsPath = resolveMnist.apply(MNIST_TRAIN_LABELS);
    String testImagesPath = resolveMnist.apply(MNIST_TEST_IMAGES);
    String testLabelsPath = resolveMnist.apply(MNIST_TEST_LABELS);

    try (Ignite ignite = Ignition.start(igniteCfgPath)) {
        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

        int ptsCnt = 60000;
        int featCnt = 28 * 28;

        Stream<DenseLocalOnHeapVector> trainingMnistStream =
            MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt);
        Stream<DenseLocalOnHeapVector> testMnistStream =
            MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000);

        IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite);

        // One extra column: each vector holds featCnt pixel features followed by its label (see the test mapping below).
        loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite);

        ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10,
            ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON,
            ignite);

        System.out.println(">>> Training started");

        // nanoTime is monotonic, so the measured duration is not affected by wall-clock adjustments.
        long before = System.nanoTime();

        DecisionTreeModel mdl = trainer.train(
            new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));

        System.out.println(">>> Training finished in " + ((System.nanoTime() - before) / 1_000_000));

        // Judging by its name and the printed label, this estimator yields an error percentage (not MSE or accuracy).
        IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double>
            errsEstimator = Estimators.errorsPercentage();

        Double errsPercentage = errsEstimator.apply(mdl,
            testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))),
            Function.identity());

        System.out.println(">>> Errs percentage: " + errsPercentage);
    }
    catch (IOException e) {
        e.printStackTrace();
    }

    System.out.println(">>> Decision trees example finished.");
}
Also used : GZIPInputStream(java.util.zip.GZIPInputStream) URL(java.net.URL) Scanner(java.util.Scanner) Random(java.util.Random) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) ExampleNodeStartup(org.apache.ignite.examples.ExampleNodeStartup) Vector(org.apache.ignite.ml.math.Vector) Estimators(org.apache.ignite.ml.estimators.Estimators) Map(java.util.Map) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Collection(java.util.Collection) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) Collectors(java.util.stream.Collectors) IgniteCache(org.apache.ignite.IgniteCache) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) Objects(java.util.Objects) List(java.util.List) Stream(java.util.stream.Stream) ParseException(org.apache.commons.cli.ParseException) RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) NotNull(org.jetbrains.annotations.NotNull) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) Model(org.apache.ignite.ml.Model) Options(org.apache.commons.cli.Options) HashMap(java.util.HashMap) Function(java.util.function.Function) GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) OptionBuilder(org.apache.commons.cli.OptionBuilder) CacheWriteSynchronizationMode(org.apache.ignite.cache.CacheWriteSynchronizationMode) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) BasicParser(org.apache.commons.cli.BasicParser) CommandLine(org.apache.commons.cli.CommandLine) MnistUtils(org.apache.ignite.ml.util.MnistUtils) Option(org.apache.commons.cli.Option) ReadableByteChannel(java.nio.channels.ReadableByteChannel) Iterator(java.util.Iterator) 
ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) CommandLineParser(org.apache.commons.cli.CommandLineParser) Channels(java.nio.channels.Channels) FileOutputStream(java.io.FileOutputStream) IOException(java.io.IOException) FileInputStream(java.io.FileInputStream) Ignite(org.apache.ignite.Ignite) MLExamplesCommonArgs(org.apache.ignite.examples.ml.MLExamplesCommonArgs) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) IgniteDataStreamer(org.apache.ignite.IgniteDataStreamer) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) HashMap(java.util.HashMap) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) Function(java.util.function.Function) Random(java.util.Random) Ignite(org.apache.ignite.Ignite) GZIPInputStream(java.util.zip.GZIPInputStream) Stream(java.util.stream.Stream) FileOutputStream(java.io.FileOutputStream) FileInputStream(java.io.FileInputStream) CommandLineParser(org.apache.commons.cli.CommandLineParser) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) IOException(java.io.IOException) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) BasicParser(org.apache.commons.cli.BasicParser) CommandLine(org.apache.commons.cli.CommandLine) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) Model(org.apache.ignite.ml.Model) ParseException(org.apache.commons.cli.ParseException) 
DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)

Aggregations

HashMap (java.util.HashMap): 8, List (java.util.List): 8, Map (java.util.Map): 8, Collectors (java.util.stream.Collectors): 8, IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple): 8, DecisionTreeModel (org.apache.ignite.ml.trees.models.DecisionTreeModel): 8, Collections (java.util.Collections): 7, LinkedList (java.util.LinkedList): 7, Random (java.util.Random): 7, DoubleStream (java.util.stream.DoubleStream): 7, Stream (java.util.stream.Stream): 7, Vector (org.apache.ignite.ml.math.Vector): 7, IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction): 7, DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector): 7, ColumnDecisionTreeTrainer (org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer): 7, Arrays (java.util.Arrays): 6, IntStream (java.util.stream.IntStream): 6, IgniteCache (org.apache.ignite.IgniteCache): 6, Ignition (org.apache.ignite.Ignition): 6, IgniteUtils (org.apache.ignite.internal.util.IgniteUtils): 6