Search in sources :

Example 1 with FeatureKey

use of org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey in project ignite by apache.

the class ColumnDecisionTreeTrainer method train.

/**
 * {@inheritDoc}
 * <p>
 * Stores a fresh {@code TrainingContext} under a random training id, then broadcasts a job which, for every
 * initial projection key mapped to the executing node, puts the feature values into the features cache and
 * a one-element block holding the initial region into the projections cache. Finally delegates to
 * {@code doTrain}.
 */
@Override
public DecisionTreeModel train(ColumnDecisionTreeTrainerInput i) {
    prjsCache = ProjectionsCache.getOrCreate(ignite);
    IgniteCache<UUID, TrainingContext<D>> contexts = ContextCache.getOrCreate(ignite);
    SplitCache.getOrCreate(ignite);

    UUID trainingUUID = UUID.randomUUID();
    TrainingContext<D> trainingCtx = new TrainingContext<>(i, continuousCalculatorProvider.apply(i), categoricalCalculatorProvider.apply(i), trainingUUID, ignite);
    contexts.put(trainingUUID, trainingCtx);

    CacheUtils.bcast(prjsCache.getName(), ignite, () -> {
        // Everything inside the broadcast job goes through the Ignite instance of the executing node.
        Ignite locIgnite = Ignition.localIgnite();
        IgniteCache<RegionKey, List<RegionProjection>> projections = ProjectionsCache.getOrCreate(locIgnite);
        IgniteCache<FeatureKey, double[]> features = FeaturesCache.getOrCreate(locIgnite);
        Affinity<RegionKey> projAffinity = locIgnite.affinity(ProjectionsCache.CACHE_NAME);
        ClusterNode self = locIgnite.cluster().localNode();

        // Entries are accumulated locally and written with one putAll per cache at the end.
        Map<FeatureKey, double[]> featureBatch = new ConcurrentHashMap<>();
        Map<RegionKey, List<RegionProjection>> projectionBatch = new ConcurrentHashMap<>();

        // One key per feature, all pointing at region block 0.
        Set<RegionKey> initialKeys = IntStream.range(0, i.featuresCount())
            .mapToObj(idx -> ProjectionsCache.key(idx, 0, i.affinityKey(idx, locIgnite), trainingUUID))
            .collect(Collectors.toSet());

        // Only the keys that affinity maps to this node are processed here.
        for (RegionKey regKey : projAffinity.mapKeysToNodes(initialKeys).getOrDefault(self, Collections.emptyList())) {
            int fIdx = regKey.featureIdx();

            IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(locIgnite);
            TrainingContext ctx = ctxCache.get(trainingUUID);

            // Dense array of feature values addressed by the index carried in each input tuple.
            double[] featureVals = new double[ctx.labels().length];
            FeatureProcessor processor = ctx.featureProcessor(fIdx);
            i.values(fIdx).forEach(sample -> featureVals[sample.get1()] = sample.get2());

            featureBatch.put(getFeatureCacheKey(fIdx, trainingUUID, i.affinityKey(fIdx, locIgnite)), featureVals);

            // The first block of a feature starts out with exactly one region.
            List<RegionProjection> firstBlock = new ArrayList<>(BLOCK_SIZE);
            firstBlock.add(processor.createInitialRegion(getSamples(i.values(fIdx), ctx.labels().length), featureVals, ctx.labels()));

            projectionBatch.put(regKey, firstBlock);
        }

        features.putAll(featureBatch);
        projections.putAll(projectionBatch);

        return null;
    });

    return doTrain(i, trainingUUID);
}
Also used : ClusterNode(org.apache.ignite.cluster.ClusterNode) UUID(java.util.UUID) ArrayList(java.util.ArrayList) RegionKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey) FeatureProcessor(org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor) Ignite(org.apache.ignite.Ignite) List(java.util.List) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) UUID(java.util.UUID) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) FeatureKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey)

Example 2 with FeatureKey

use of org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey in project ignite by apache.

the class ColumnDecisionTreeTrainer method doTrain.

/**
 * Grows the decision tree for the training session identified by {@code uuid} and releases the session's
 * cache entries afterwards.
 * <p>
 * The method seeds the split cache with one entry per feature and then repeatedly:
 * <ol>
 *     <li>asks {@code findBestSplitIndexForFeatures} for the best (feature, region, info gain) triple;</li>
 *     <li>stops if the region index is negative or the gain does not exceed {@code MIN_INFO_GAIN};</li>
 *     <li>otherwise computes the split and its ownership bitset via affinity calls keyed by the best
 *     feature, attaches a new split node to the corresponding tree tip, splits the region in the
 *     projections of all features and refreshes the split cache.</li>
 * </ol>
 * After the loop, leaf values are computed by applying {@code regCalc} to the labels of every region of
 * feature 0 and assigned to the tips in order.
 *
 * @param input Trainer input; used here for the feature count and for affinity keys.
 * @param uuid Training session id that is part of every cache key built by this method.
 * @return Model wrapping the split stored in the root node.
 */
@NotNull
private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) {
    RootNode root = new RootNode();
    // List containing setters of leaves of the tree. Tips are looked up by region index on every split,
    // so an array-backed list is used for constant-time access.
    List<TreeTip> tips = new ArrayList<>();
    tips.add(new TreeTip(root::setSplit, 0));
    int curDepth = 0;
    int regsCnt = 1;
    int featuresCnt = input.featuresCount();
    // Seed the split cache with a (0, 0.0) entry for every feature.
    IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid)).forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0)));
    updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
    // Keep splitting until no region offers an information gain above MIN_INFO_GAIN.
    while (true) {
        long before = System.currentTimeMillis();
        IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid);
        long findBestRegIdx = System.currentTimeMillis() - before;
        Integer bestFeatureIdx = b.get1();
        Integer regIdx = b.get2().get1();
        Double bestInfoGain = b.get2().get2();
        if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) {
            before = System.currentTimeMillis();
            SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, input.affinityKey(bestFeatureIdx, ignite), () -> {
                // Resolve the node-local Ignite first so that every lookup in this closure goes through it
                // rather than through the trainer's own field.
                Ignite ignite = Ignition.localIgnite();
                TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid);
                RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
                RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
                return ctx.featureProcessor(bestFeatureIdx).findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx);
            });
            long findBestSplit = System.currentTimeMillis() - before;
            IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi);
            regsCnt++;
            if (log.isDebugEnabled())
                log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
            // Request bitset for split region.
            int ind = best.info.regionIndex();
            SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME, input.affinityKey(bestFeatureIdx, ignite), () -> {
                Ignite ignite = Ignition.localIgnite();
                IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
                IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
                TrainingContext ctx = ctxCache.localPeek(uuid);
                double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite())));
                RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE, input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
                RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
                return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info);
            });
            // Replace the tip of the split region with a split node: the existing tip becomes the setter
            // of the left child, a new tip appended to the list becomes the setter of the right child.
            SplitNode sn = best.info.createSplitNode(best.featureIdx);
            TreeTip tipToSplit = tips.get(ind);
            tipToSplit.leafSetter.accept(sn);
            tipToSplit.leafSetter = sn::setLeft;
            // NOTE(review): the post-increment gives the new right-hand tip the previous depth while the
            // left-hand tip moves one level deeper — confirm this asymmetry is intended.
            int d = tipToSplit.depth++;
            tips.add(new TreeTip(sn::setRight, d));
            if (d > curDepth) {
                curDepth = d;
                if (log.isDebugEnabled()) {
                    log.debug("Depth: " + curDepth);
                    log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
                }
            }
            before = System.currentTimeMillis();
            // Perform split on all feature vectors: keys of the block holding the split region, one per feature.
            IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid)).collect(Collectors.toSet());
            // regsCnt is reassigned in the loop, so closures capture this per-iteration copy instead.
            int rc = regsCnt;
            // Perform split.
            CacheUtils.update(prjsCache.getName(), ignite, (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> {
                RegionKey k = e.getKey();
                List<RegionProjection> leftBlock = e.getValue();
                int fIdx = k.featureIdx();
                int idxInBlock = ind % BLOCK_SIZE;
                IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign);
                TrainingContext<D> ctx = ctxCache.get(uuid);
                RegionProjection targetRegProj = leftBlock.get(idxInBlock);
                IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx.performSplit(input, bs, fIdx, best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign);
                RegionProjection left = regs.get1();
                RegionProjection right = regs.get2();
                // The left part takes the place of the split region; the right part is appended as the
                // newest region, i.e. into the block that holds region index (rc - 1).
                leftBlock.set(idxInBlock, left);
                RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid);
                IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign);
                List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey);
                if (rightBlock == null) {
                    List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE);
                    newBlock.add(right);
                    return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock));
                } else {
                    rightBlock.add(right);
                    // Compare keys (not the block with the key): when both parts live in the same block a
                    // single entry is enough.
                    return rightKey.equals(k) ? Stream.of(new CacheEntryImpl<>(k, leftBlock)) : Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock));
                }
            }, bestRegsKeys);
            if (log.isDebugEnabled())
                log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));
            before = System.currentTimeMillis();
            updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
            if (log.isDebugEnabled())
                log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
        } else {
            if (log.isDebugEnabled())
                log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");
            break;
        }
    }
    int rc = regsCnt;
    // Supplies the locally peeked region blocks of feature 0, one entry per block index in [0, (rc - 1) / BLOCK_SIZE].
    IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> {
        IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite());
        return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1).mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid)).map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>) new CacheEntryImpl<>(k, projsCache.localPeek(k))).iterator();
    };
    // Map from global region index to the value regCalc produces for the labels of that region's samples.
    Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite, (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> {
        int regBlockIdx = e.getKey().regionBlockIndex();
        // A block that was peeked as null is skipped.
        if (e.getValue() != null) {
            for (int i = 0; i < e.getValue().size(); i++) {
                int regIdx = regBlockIdx * BLOCK_SIZE + i;
                RegionProjection reg = e.getValue().get(i);
                Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s]));
                m.put(regIdx, res);
            }
        }
        return m;
    }, () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid), featZeroRegs, (infos, infos2) -> {
        Map<Integer, Double> res = new HashMap<>();
        res.putAll(infos);
        res.putAll(infos2);
        return res;
    }, HashMap::new);
    // The position of a tip in the list is the index of the region it represents.
    int i = 0;
    for (TreeTip tip : tips) {
        tip.leafSetter.accept(new Leaf(vals.get(i)));
        i++;
    }
    // Remove everything this training session put into the caches.
    ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite);
    ContextCache.getOrCreate(ignite).remove(uuid);
    FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
    SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
    return new DecisionTreeModel(root.s);
}
Also used : Leaf(org.apache.ignite.ml.trees.nodes.Leaf) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) SplitNode(org.apache.ignite.ml.trees.nodes.SplitNode) Functions(org.apache.ignite.ml.math.functions.Functions) RegionKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey) Vector(org.apache.ignite.ml.math.Vector) Map(java.util.Map) Cache(javax.cache.Cache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) CacheEntryImpl(org.apache.ignite.internal.processors.cache.CacheEntryImpl) FeatureKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey) CachePeekMode(org.apache.ignite.cache.CachePeekMode) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) Set(java.util.Set) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) SplitKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache.SplitKey) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Stream(java.util.stream.Stream) FeatureProcessor(org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) Optional(java.util.Optional) NotNull(org.jetbrains.annotations.NotNull) SplitInfo(org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) FeaturesCache.getFeatureCacheKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.getFeatureCacheKey) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) 
Affinity(org.apache.ignite.cache.affinity.Affinity) HashMap(java.util.HashMap) IgniteLogger(org.apache.ignite.IgniteLogger) ArrayList(java.util.ArrayList) Trainer(org.apache.ignite.ml.Trainer) ClusterNode(org.apache.ignite.cluster.ClusterNode) LinkedList(java.util.LinkedList) DecisionTreeNode(org.apache.ignite.ml.trees.nodes.DecisionTreeNode) Ignite(org.apache.ignite.Ignite) CacheUtils(org.apache.ignite.ml.math.distributed.CacheUtils) IgniteCurriedBiFunction(org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction) Consumer(java.util.function.Consumer) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) ContinuousSplitCalculator(org.apache.ignite.ml.trees.ContinuousSplitCalculator) Comparator(java.util.Comparator) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) Collections(java.util.Collections) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) SplitNode(org.apache.ignite.ml.trees.nodes.SplitNode) List(java.util.List) ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) UUID(java.util.UUID) LinkedList(java.util.LinkedList) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) Cache(javax.cache.Cache) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) Set(java.util.Set) UUID(java.util.UUID) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) 
RegionKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey) SplitInfo(org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo) CacheEntryImpl(org.apache.ignite.internal.processors.cache.CacheEntryImpl) Ignite(org.apache.ignite.Ignite) Leaf(org.apache.ignite.ml.trees.nodes.Leaf) IgniteCache(org.apache.ignite.IgniteCache) SparseBitSet(com.zaxxer.sparsebits.SparseBitSet) FeatureKey(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey) NotNull(org.jetbrains.annotations.NotNull)

Aggregations

ArrayList (java.util.ArrayList)2 LinkedList (java.util.LinkedList)2 List (java.util.List)2 UUID (java.util.UUID)2 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)2 Ignite (org.apache.ignite.Ignite)2 ClusterNode (org.apache.ignite.cluster.ClusterNode)2 FeatureKey (org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey)2 RegionKey (org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey)2 FeatureProcessor (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor)2 SparseBitSet (com.zaxxer.sparsebits.SparseBitSet)1 Arrays (java.util.Arrays)1 Collections (java.util.Collections)1 Comparator (java.util.Comparator)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 Optional (java.util.Optional)1 Set (java.util.Set)1 Consumer (java.util.function.Consumer)1 Collectors (java.util.stream.Collectors)1