Search in sources :

Example 6 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

From the class ColumnDecisionTreeTrainerTest, method testByGen.

/**
 * Trains a column-based decision tree on generated points and checks its predictions.
 * <p>
 * The points produced by {@code gen} are shuffled, written row by row into a distributed matrix and grouped by
 * the integer key the generator attaches to each point. After training, the first point stored for every key is
 * passed to the model and the prediction is asserted to be equal to the label of that point.
 *
 * @param totalPts Number of points to generate; also used as the row count of the training matrix.
 * @param catsInfo Passed to {@code MatrixColumnDecisionTreeTrainerInput} together with the matrix.
 * @param gen Generator of the points.
 * @param calc Continuous split calculator provider handed to the trainer.
 * @param catImpCalc Second calculator provider handed to the trainer.
 * @param regCalc Function from a stream of doubles to a single value, handed to the trainer.
 * @param rnd Source of randomness used to shuffle the generated points.
 * @param <D> Type of the continuous region info used by the split calculator.
 */
private <D extends ContinuousRegionInfo> void testByGen(int totalPts, HashMap<Integer, Integer> catsInfo, SplitDataGenerator<DenseLocalOnHeapVector> gen, IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc, IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc, IgniteFunction<DoubleStream, Double> regCalc, Random rnd) {
    List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen.points(totalPts, (i, rn) -> i).collect(Collectors.toList());
    int featCnt = gen.featuresCnt();
    Collections.shuffle(lst, rnd);
    // One extra column on top of the feature columns.
    SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();
    int row = 0;
    for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) {
        // computeIfAbsent creates the list only for a new key and returns it, so no second lookup is needed.
        byRegion.computeIfAbsent(bt.get1(), k -> new LinkedList<>()).add(asLabeledVector(bt.get2().getStorage().data()));
        m.setRow(row, bt.get2().getStorage().data());
        row++;
    }
    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);
    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));
    // Only the first point of every group is verified.
    byRegion.values().forEach(pts -> {
        LabeledVectorDouble sp = pts.get(0);
        Tracer.showAscii(sp.features());
        X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred=" + mdl.apply(sp.features()) + "]");
        assert mdl.apply(sp.features()) == sp.doubleLabel();
    });
}
Also used : MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) ColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput) HashMap(java.util.HashMap) Random(java.util.Random) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) Collectors(java.util.stream.Collectors) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Map(java.util.Map) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) X(org.apache.ignite.internal.util.typedef.X) Tracer(org.apache.ignite.ml.math.Tracer) LinkedList(java.util.LinkedList) StorageConstants(org.apache.ignite.ml.math.StorageConstants) RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) HashMap(java.util.HashMap) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) List(java.util.List) LinkedList(java.util.LinkedList) 
DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Example 7 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

From the class ColumnDecisionTreeTrainer, method doTrain.

/**
 * Grows a decision tree for the given input by repeatedly applying the globally best split.
 * <p>
 * Every iteration asks {@code findBestSplitIndexForFeatures} for the best (feature, region) pair. While the
 * reported information gain exceeds {@code MIN_INFO_GAIN}, the split is computed through an affinity call on the
 * affinity key of the chosen feature, the region projections of all features are split accordingly and the split
 * cache is refreshed. Once no such split is left, a value is computed with {@code regCalc} for every region, the
 * remaining tree tips are closed with leaves holding these values and the cache entries created for this
 * training are cleared.
 *
 * @param input Training input; provides the feature count and the affinity keys of the features.
 * @param uuid UUID that is part of every cache key built here and under which the training context is stored.
 * @return Model built around the split stored in the root node.
 */
@NotNull
private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) {
    RootNode root = new RootNode();
    // List containing setters of leaves of the tree.
    List<TreeTip> tips = new LinkedList<>();
    tips.add(new TreeTip(root::setSplit, 0));
    int curDepth = 0;
    int regsCnt = 1;
    int featuresCnt = input.featuresCount();
    // Seed the split cache with a (0, 0.0) entry per feature before the first real update.
    IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid)).forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0)));
    updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
    // Keep splitting until the best available split no longer gives more than MIN_INFO_GAIN.
    while (true) {
        long before = System.currentTimeMillis();
        IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid);
        long findBestRegIdx = System.currentTimeMillis() - before;
        Integer bestFeatureIdx = b.get1();
        Integer regIdx = b.get2().get1();
        Double bestInfoGain = b.get2().get2();
        if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) {
            before = System.currentTimeMillis();
            SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, input.affinityKey(bestFeatureIdx, ignite), () -> {
                // NOTE(review): here 'ignite' still resolves to the member of the enclosing trainer, because the
                // local variable of the same name is declared only on the next line - confirm this is intended.
                TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid);
                Ignite ignite = Ignition.localIgnite();
                RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
                // Projections are stored in blocks of BLOCK_SIZE: pick the block first, then the slot inside it.
                RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
                return ctx.featureProcessor(bestFeatureIdx).findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx);
            });
            long findBestSplit = System.currentTimeMillis() - before;
            IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi);
            regsCnt++;
            if (log.isDebugEnabled())
                log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
            // Request bitset for split region.
            int ind = best.info.regionIndex();
            SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, input.affinityKey(bestFeatureIdx, ignite), () -> {
                Ignite ignite = Ignition.localIgnite();
                IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
                IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
                TrainingContext ctx = ctxCache.localPeek(uuid);
                double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite())));
                RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
                RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
                return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info);
            });
            // Attach the new split node to the tip of the split region; that tip then continues as the left
            // child and a new tip is appended for the right child.
            SplitNode sn = best.info.createSplitNode(best.featureIdx);
            TreeTip tipToSplit = tips.get(ind);
            tipToSplit.leafSetter.accept(sn);
            tipToSplit.leafSetter = sn::setLeft;
            // Post-increment: 'd' is the depth of the tip before the split.
            // NOTE(review): the right tip is created with the old depth while the left tip's depth is
            // incremented - confirm this asymmetry is intended.
            int d = tipToSplit.depth++;
            tips.add(new TreeTip(sn::setRight, d));
            if (d > curDepth) {
                curDepth = d;
                if (log.isDebugEnabled()) {
                    log.debug("Depth: " + curDepth);
                    log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
                }
            }
            before = System.currentTimeMillis();
            // Perform split on all feature vectors.
            IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid)).collect(Collectors.toSet());
            // Effectively final copy of the region count for use inside the lambdas below.
            int rc = regsCnt;
            // Perform split.
            CacheUtils.update(prjsCache.getName(), ignite, (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> {
                RegionKey k = e.getKey();
                List<RegionProjection> leftBlock = e.getValue();
                int fIdx = k.featureIdx();
                int idxInBlock = ind % BLOCK_SIZE;
                IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign);
                TrainingContext<D> ctx = ctxCache.get(uuid);
                RegionProjection targetRegProj = leftBlock.get(idxInBlock);
                IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx.performSplit(input, bs, fIdx, best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign);
                RegionProjection left = regs.get1();
                RegionProjection right = regs.get2();
                // The left part replaces the split region in place.
                leftBlock.set(idxInBlock, left);
                // The right part becomes the region with index rc - 1, i.e. the last one.
                RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid);
                IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign);
                List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey);
                if (rightBlock == null) {
                    List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE);
                    newBlock.add(right);
                    return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock));
                } else {
                    rightBlock.add(right);
                    // Compare the keys (as on the lookup above), not the block with the key: a list never equals
                    // a RegionKey, so the single-entry branch could not be taken before.
                    return rightKey.equals(k) ? Stream.of(new CacheEntryImpl<>(k, leftBlock)) : Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock));
                }
            }, bestRegsKeys);
            if (log.isDebugEnabled())
                log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));
            before = System.currentTimeMillis();
            updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
            if (log.isDebugEnabled())
                log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
        } else {
            if (log.isDebugEnabled())
                log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");
            break;
        }
    }
    int rc = regsCnt;
    // Supplies the locally peeked projection blocks of feature 0 for every region block index.
    IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> {
        IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite());
        return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1).mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid)).map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>) new CacheEntryImpl<>(k, projsCache.localPeek(k))).iterator();
    };
    // Region index -> value computed by regCalc over the labels of the samples in that region.
    Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite, (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> {
        int regBlockIdx = e.getKey().regionBlockIndex();
        // The value is null when localPeek found no block for the key.
        if (e.getValue() != null) {
            for (int i = 0; i < e.getValue().size(); i++) {
                int regIdx = regBlockIdx * BLOCK_SIZE + i;
                RegionProjection reg = e.getValue().get(i);
                Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s]));
                m.put(regIdx, res);
            }
        }
        return m;
    }, () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid), featZeroRegs, (infos, infos2) -> {
        Map<Integer, Double> res = new HashMap<>();
        res.putAll(infos);
        res.putAll(infos2);
        return res;
    }, HashMap::new);
    // The position of a tip in the list is used as the index of its region.
    int i = 0;
    for (TreeTip tip : tips) {
        tip.leafSetter.accept(new Leaf(vals.get(i)));
        i++;
    }
    // Remove the cache entries created for this training.
    ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite);
    ContextCache.getOrCreate(ignite).remove(uuid);
    FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
    SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
    return new DecisionTreeModel(root.s);
}
Also used : Leaf(org.apache.ignite.ml.trees.nodes.Leaf) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) SplitNode(org.apache.ignite.ml.trees.nodes.SplitNode) Functions(org.apache.ignite.ml.math.functions.Functions) RegionKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey) Vector(org.apache.ignite.ml.math.Vector) Map(java.util.Map) Cache(javax.cache.Cache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) CacheEntryImpl(org.apache.ignite.internal.processors.cache.CacheEntryImpl) FeatureKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey) CachePeekMode(org.apache.ignite.cache.CachePeekMode) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) Set(java.util.Set) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) SplitKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache.SplitKey) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Stream(java.util.stream.Stream) FeatureProcessor(org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) Optional(java.util.Optional) NotNull(org.jetbrains.annotations.NotNull) SplitInfo(org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) FeaturesCache.getFeatureCacheKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.getFeatureCacheKey) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) 
Affinity(org.apache.ignite.cache.affinity.Affinity) HashMap(java.util.HashMap) IgniteLogger(org.apache.ignite.IgniteLogger) ArrayList(java.util.ArrayList) Trainer(org.apache.ignite.ml.Trainer) ClusterNode(org.apache.ignite.cluster.ClusterNode) LinkedList(java.util.LinkedList) DecisionTreeNode(org.apache.ignite.ml.trees.nodes.DecisionTreeNode) Ignite(org.apache.ignite.Ignite) CacheUtils(org.apache.ignite.ml.math.distributed.CacheUtils) IgniteCurriedBiFunction(org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction) Consumer(java.util.function.Consumer) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) ContinuousSplitCalculator(org.apache.ignite.ml.trees.ContinuousSplitCalculator) Comparator(java.util.Comparator) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) Collections(java.util.Collections) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) SplitNode(org.apache.ignite.ml.trees.nodes.SplitNode) List(java.util.List) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) UUID(java.util.UUID) LinkedList(java.util.LinkedList) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) Cache(javax.cache.Cache) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) Set(java.util.Set) UUID(java.util.UUID) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) 
RegionKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey) SplitInfo(org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo) CacheEntryImpl(org.apache.ignite.internal.processors.cache.CacheEntryImpl) Ignite(org.apache.ignite.Ignite) Leaf(org.apache.ignite.ml.trees.nodes.Leaf) IgniteCache(org.apache.ignite.IgniteCache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) FeatureKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey) NotNull(org.jetbrains.annotations.NotNull)

Example 8 with DecisionTreeModel

Use of org.apache.ignite.ml.trees.models.DecisionTreeModel in the Apache Ignite project.

From the class ColumnDecisionTreeTrainerBenchmark, method tstF1.

/**
 * Decision tree regression run on generated vectors: trains a model, prints the training time
 * and the MSE obtained on a small set of generated test vectors.
 * To run this test rename this method so it starts from 'test'.
 */
public void tstF1() {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    final int sampleCnt = 10000;
    final int featCnt = 100;
    // Each row holds the features plus one more value.
    final int colCnt = featCnt + 1;

    // The first three features get an explicit range, the rest fall back to the default one.
    Map<Integer, double[]> ranges = new HashMap<>();
    for (int fIdx = 0; fIdx < 3; fIdx++)
        ranges.put(fIdx, new double[] { -100.0, 100.0 });
    double[] dfltRange = { -1.0, 1.0 };

    Vector[] trainSet = vecsFromRanges(ranges, featCnt, dfltRange, new Random(123L), sampleCnt, f1);

    SparseDistributedMatrix mtx = new SparseDistributedMatrix(sampleCnt, colCnt, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
    SparseDistributedMatrixStorage storage = (SparseDistributedMatrixStorage) mtx.getStorage();
    loadVectorsIntoSparseDistributedMatrixCache(storage.cache().getName(), storage.getUUID(), Arrays.stream(trainSet).iterator(), colCnt);

    // Region value is the mean of the labels, or 0.0 for an empty stream.
    IgniteFunction<DoubleStream, Double> meanCalc = labels -> labels.average().orElse(0.0);
    ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, meanCalc, ignite);

    X.println("Training started.");
    long startTs = System.currentTimeMillis();
    MatrixColumnDecisionTreeTrainerInput trainInput = new MatrixColumnDecisionTreeTrainerInput(mtx, new HashMap<>());
    DecisionTreeModel mdl = trainer.train(trainInput);
    X.println("Training finished in: " + (System.currentTimeMillis() - startTs) + " ms.");

    Vector[] testSet = vecsFromRanges(ranges, featCnt, dfltRange, new Random(123L), 20, f1);
    // Split every test vector into its first featCnt coordinates and the value stored right after them.
    Stream<IgniteBiTuple<Vector, Double>> labeledTestSet = Arrays.stream(testSet).map(vec -> new IgniteBiTuple<>(vec.viewPart(0, featCnt), vec.getX(featCnt)));
    IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> estimator = Estimators.MSE();
    Double mseVal = estimator.apply(mdl, labeledTestSet, Function.identity());
    X.println("MSE: " + mseVal);
}
Also used : CacheAtomicityMode(org.apache.ignite.cache.CacheAtomicityMode) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) IgniteTestResources(org.apache.ignite.testframework.junits.IgniteTestResources) Random(java.util.Random) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) VarianceSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator) Vector(org.apache.ignite.ml.math.Vector) Estimators(org.apache.ignite.ml.estimators.Estimators) Map(java.util.Map) X(org.apache.ignite.internal.util.typedef.X) Level(org.apache.log4j.Level) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) BaseDecisionTreeTest(org.apache.ignite.ml.trees.BaseDecisionTreeTest) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) UUID(java.util.UUID) StreamTransformer(org.apache.ignite.stream.StreamTransformer) Collectors(java.util.stream.Collectors) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Stream(java.util.stream.Stream) SparseMatrixKey(org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) 
RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Model(org.apache.ignite.ml.Model) HashMap(java.util.HashMap) Function(java.util.function.Function) GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) CacheWriteSynchronizationMode(org.apache.ignite.cache.CacheWriteSynchronizationMode) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) MnistUtils(org.apache.ignite.ml.util.MnistUtils) LinkedList(java.util.LinkedList) Properties(java.util.Properties) Iterator(java.util.Iterator) ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) IOException(java.io.IOException) SplitDataGenerator(org.apache.ignite.ml.trees.SplitDataGenerator) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) IgniteDataStreamer(org.apache.ignite.IgniteDataStreamer) Tracer(org.apache.ignite.ml.math.Tracer) StorageConstants(org.apache.ignite.ml.math.StorageConstants) Assert(org.junit.Assert) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) GridCacheProcessor(org.apache.ignite.internal.processors.cache.GridCacheProcessor) InputStream(java.io.InputStream) CacheMode(org.apache.ignite.cache.CacheMode) HashMap(java.util.HashMap) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) 
IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Function(java.util.function.Function) Random(java.util.Random) DoubleStream(java.util.stream.DoubleStream) Stream(java.util.stream.Stream) IntStream(java.util.stream.IntStream) InputStream(java.io.InputStream) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) Model(org.apache.ignite.ml.Model) DoubleStream(java.util.stream.DoubleStream)

Aggregations

HashMap (java.util.HashMap)8 List (java.util.List)8 Map (java.util.Map)8 Collectors (java.util.stream.Collectors)8 IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple)8 DecisionTreeModel (org.apache.ignite.ml.trees.models.DecisionTreeModel)8 Collections (java.util.Collections)7 LinkedList (java.util.LinkedList)7 Random (java.util.Random)7 DoubleStream (java.util.stream.DoubleStream)7 Stream (java.util.stream.Stream)7 Vector (org.apache.ignite.ml.math.Vector)7 IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction)7 DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)7 ColumnDecisionTreeTrainer (org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)7 Arrays (java.util.Arrays)6 IntStream (java.util.stream.IntStream)6 IgniteCache (org.apache.ignite.IgniteCache)6 Ignition (org.apache.ignite.Ignition)6 IgniteUtils (org.apache.ignite.internal.util.IgniteUtils)6