Search in sources :

Example 1 with SplitNode

Use of org.apache.ignite.ml.trees.nodes.SplitNode in the Apache Ignite project.

From the class ColumnDecisionTreeTrainer, method doTrain.

/**
 * Builds a decision tree by repeatedly splitting the region that currently offers the best information gain.
 * <p>
 * The method starts with a single region (the root), then loops: it asks for the globally best
 * (feature, region) pair, computes the concrete split on the node owning that feature, applies the split to
 * the region projections of every feature, and refreshes the split cache. The loop stops as soon as no
 * region index is returned or the best information gain is not greater than {@code MIN_INFO_GAIN}.
 * Afterwards a value is computed for every remaining region with {@code regCalc} and installed as a
 * {@link Leaf}; finally the projections, context, features and split caches are cleaned up for {@code uuid}.
 *
 * @param input Training input; used here for the feature count and for affinity keys of features.
 * @param uuid Identifier of this training run; it is part of every cache key used by this method.
 * @return Model wrapping the split stored in the root node.
 */
@NotNull
private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) {
    RootNode root = new RootNode();
    // List containing setters of leaves of the tree. The position of a tip in this list is the index of the
    // region it represents, so it is accessed by index on every split (hence an array-backed list).
    List<TreeTip> tips = new ArrayList<>();
    tips.add(new TreeTip(root::setSplit, 0));
    int curDepth = 0;
    int regsCnt = 1;
    int featuresCnt = input.featuresCount();
    // Seed the split cache with a (region 0, gain 0.0) entry for every feature.
    IntStream.range(0, featuresCnt)
        .mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid))
        .forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0)));
    updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
    // Split greedily until no region yields an information gain above MIN_INFO_GAIN.
    // NOTE(review): the comment originally found here was truncated ("...regions cannot be split more and
    // split only those that can"); it presumably described a planned depth-limit refinement - confirm upstream.
    while (true) {
        long before = System.currentTimeMillis();
        // b = (best feature index, (region index, information gain)).
        IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid);
        long findBestRegIdx = System.currentTimeMillis() - before;
        Integer bestFeatureIdx = b.get1();
        Integer regIdx = b.get2().get1();
        Double bestInfoGain = b.get2().get2();
        if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) {
            before = System.currentTimeMillis();
            // Compute the concrete split on the node that is affine to the best feature.
            SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, input.affinityKey(bestFeatureIdx, ignite), () -> {
                // Resolve the local Ignite first so that the closure does not reference the enclosing
                // instance's 'ignite' field (which would capture 'this' into the remotely executed closure).
                Ignite ignite = Ignition.localIgnite();
                TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid);
                RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
                // Regions are stored in blocks of BLOCK_SIZE: the key addresses the block, the remainder the slot.
                RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
                return ctx.featureProcessor(bestFeatureIdx).findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx);
            });
            long findBestSplit = System.currentTimeMillis() - before;
            IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi);
            regsCnt++;
            if (log.isDebugEnabled())
                log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
            // Request bitset for split region.
            int ind = best.info.regionIndex();
            SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, input.affinityKey(bestFeatureIdx, ignite), () -> {
                Ignite ignite = Ignition.localIgnite();
                IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
                IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
                TrainingContext ctx = ctxCache.localPeek(uuid);
                double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite())));
                RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
                RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
                return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info);
            });
            // Replace the tip of the split region with a split node: the existing tip becomes the left
            // child, and a new tip appended at the end of the list becomes the right child.
            SplitNode sn = best.info.createSplitNode(best.featureIdx);
            TreeTip tipToSplit = tips.get(ind);
            tipToSplit.leafSetter.accept(sn);
            tipToSplit.leafSetter = sn::setLeft;
            // Both children live one level below the node that was split, so increment first and give the
            // same depth to the right child (post-increment would leave the right child at the parent depth).
            int d = ++tipToSplit.depth;
            tips.add(new TreeTip(sn::setRight, d));
            if (d > curDepth) {
                curDepth = d;
                if (log.isDebugEnabled()) {
                    log.debug("Depth: " + curDepth);
                    log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
                }
            }
            before = System.currentTimeMillis();
            // Perform split on all feature vectors.
            // Keys of the blocks (one per feature) that contain the region being split.
            IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt)
                .mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid))
                .collect(Collectors.toSet());
            // Effectively final copy of the region count for use inside lambdas.
            int rc = regsCnt;
            // Perform split.
            CacheUtils.update(prjsCache.getName(), ignite, (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> {
                RegionKey k = e.getKey();
                List<RegionProjection> leftBlock = e.getValue();
                int fIdx = k.featureIdx();
                int idxInBlock = ind % BLOCK_SIZE;
                IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign);
                TrainingContext<D> ctx = ctxCache.get(uuid);
                RegionProjection targetRegProj = leftBlock.get(idxInBlock);
                IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx.performSplit(input, bs, fIdx, best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign);
                RegionProjection left = regs.get1();
                RegionProjection right = regs.get2();
                // The left part stays in the slot of the split region.
                leftBlock.set(idxInBlock, left);
                // The right part becomes the newest region (index rc - 1), which may fall into another block.
                RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid);
                IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign);
                List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey);
                if (rightBlock == null) {
                    // The target block does not exist yet: create it with the right part as its first element.
                    List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE);
                    newBlock.add(right);
                    return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock));
                } else {
                    rightBlock.add(right);
                    // When both parts live in the same block only one entry has to be written back.
                    // The keys must be compared here; comparing the block list with the key is always false.
                    return rightKey.equals(k) ? Stream.of(new CacheEntryImpl<>(k, leftBlock)) : Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock));
                }
            }, bestRegsKeys);
            if (log.isDebugEnabled())
                log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));
            before = System.currentTimeMillis();
            updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
            if (log.isDebugEnabled())
                log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
        } else {
            if (log.isDebugEnabled())
                log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");
            break;
        }
    }
    int rc = regsCnt;
    // Supplies the locally peeked region blocks of feature 0; they are used below to compute leaf values.
    IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> {
        IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite());
        return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1)
            .mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid))
            .map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>) new CacheEntryImpl<>(k, projsCache.localPeek(k)))
            .iterator();
    };
    // Map: region index -> value computed by regCalc over the labels of the samples in that region.
    Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite, (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> {
        int regBlockIdx = e.getKey().regionBlockIndex();
        // A block may be absent locally (localPeek returned null); such entries contribute nothing.
        if (e.getValue() != null) {
            for (int i = 0; i < e.getValue().size(); i++) {
                int regIdx = regBlockIdx * BLOCK_SIZE + i;
                RegionProjection reg = e.getValue().get(i);
                Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s]));
                m.put(regIdx, res);
            }
        }
        return m;
    }, () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid), featZeroRegs, (infos, infos2) -> {
        Map<Integer, Double> res = new HashMap<>();
        res.putAll(infos);
        res.putAll(infos2);
        return res;
    }, HashMap::new);
    // Turn every remaining tip into a leaf; the i-th tip corresponds to the region with index i.
    int i = 0;
    for (TreeTip tip : tips) {
        tip.leafSetter.accept(new Leaf(vals.get(i)));
        i++;
    }
    // Remove everything this training run stored in the caches.
    ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite);
    ContextCache.getOrCreate(ignite).remove(uuid);
    FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
    SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
    return new DecisionTreeModel(root.s);
}
Also used : Leaf(org.apache.ignite.ml.trees.nodes.Leaf) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) SplitNode(org.apache.ignite.ml.trees.nodes.SplitNode) Functions(org.apache.ignite.ml.math.functions.Functions) RegionKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey) Vector(org.apache.ignite.ml.math.Vector) Map(java.util.Map) Cache(javax.cache.Cache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) CacheEntryImpl(org.apache.ignite.internal.processors.cache.CacheEntryImpl) FeatureKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey) CachePeekMode(org.apache.ignite.cache.CachePeekMode) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) Set(java.util.Set) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) SplitKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache.SplitKey) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Stream(java.util.stream.Stream) FeatureProcessor(org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) Optional(java.util.Optional) NotNull(org.jetbrains.annotations.NotNull) SplitInfo(org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) FeaturesCache.getFeatureCacheKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.getFeatureCacheKey) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) 
Affinity(org.apache.ignite.cache.affinity.Affinity) HashMap(java.util.HashMap) IgniteLogger(org.apache.ignite.IgniteLogger) ArrayList(java.util.ArrayList) Trainer(org.apache.ignite.ml.Trainer) ClusterNode(org.apache.ignite.cluster.ClusterNode) LinkedList(java.util.LinkedList) DecisionTreeNode(org.apache.ignite.ml.trees.nodes.DecisionTreeNode) Ignite(org.apache.ignite.Ignite) CacheUtils(org.apache.ignite.ml.math.distributed.CacheUtils) IgniteCurriedBiFunction(org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction) Consumer(java.util.function.Consumer) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) ContinuousSplitCalculator(org.apache.ignite.ml.trees.ContinuousSplitCalculator) Comparator(java.util.Comparator) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) Collections(java.util.Collections) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) SplitNode(org.apache.ignite.ml.trees.nodes.SplitNode) List(java.util.List) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) UUID(java.util.UUID) LinkedList(java.util.LinkedList) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) Cache(javax.cache.Cache) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) Set(java.util.Set) UUID(java.util.UUID) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) 
RegionKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey) SplitInfo(org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo) CacheEntryImpl(org.apache.ignite.internal.processors.cache.CacheEntryImpl) Ignite(org.apache.ignite.Ignite) Leaf(org.apache.ignite.ml.trees.nodes.Leaf) IgniteCache(org.apache.ignite.IgniteCache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) FeatureKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey) NotNull(org.jetbrains.annotations.NotNull)

Aggregations

SparseBitSet (com.zaxxer.sparsebits.SparseBitSet)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 Collections (java.util.Collections)1 Comparator (java.util.Comparator)1 HashMap (java.util.HashMap)1 LinkedList (java.util.LinkedList)1 List (java.util.List)1 Map (java.util.Map)1 Optional (java.util.Optional)1 Set (java.util.Set)1 UUID (java.util.UUID)1 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)1 Consumer (java.util.function.Consumer)1 Collectors (java.util.stream.Collectors)1 DoubleStream (java.util.stream.DoubleStream)1 IntStream (java.util.stream.IntStream)1 Stream (java.util.stream.Stream)1 Cache (javax.cache.Cache)1 Ignite (org.apache.ignite.Ignite)1