Search in sources :

Example 1 with FeatureProcessor

Use of org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor in the Apache Ignite project.

The train method of the class ColumnDecisionTreeTrainer.

/**
 * {@inheritDoc}
 *
 * <p>Generates a fresh training UUID, stores a new {@code TrainingContext} for it in the context
 * cache, and then broadcasts a closure through {@code CacheUtils.bcast} keyed by the projections
 * cache name. The closure builds one projection key per feature index, asks the projections-cache
 * affinity which of those keys map to the local node, and for each such key:
 * <ul>
 *     <li>writes the feature's (index, value) pairs into an array sized by the label count;</li>
 *     <li>queues that array for the features cache;</li>
 *     <li>queues a list holding the initial region projection produced by the feature's
 *     {@code FeatureProcessor} for the projections cache.</li>
 * </ul>
 * Both batches are written with {@code putAll} once all local keys are processed. Training itself
 * is then delegated to {@code doTrain}.
 */
@Override
public DecisionTreeModel train(ColumnDecisionTreeTrainerInput i) {
    prjsCache = ProjectionsCache.getOrCreate(ignite);
    IgniteCache<UUID, TrainingContext<D>> contexts = ContextCache.getOrCreate(ignite);
    // The returned cache handle is not used in this method.
    SplitCache.getOrCreate(ignite);

    UUID trainingUUID = UUID.randomUUID();
    TrainingContext<D> trainingCtx = new TrainingContext<>(i, continuousCalculatorProvider.apply(i), categoricalCalculatorProvider.apply(i), trainingUUID, ignite);
    contexts.put(trainingUUID, trainingCtx);

    CacheUtils.bcast(prjsCache.getName(), ignite, () -> {
        // Inside the closure every Ignite call goes through the instance obtained here,
        // not through the trainer's own field.
        Ignite locIgnite = Ignition.localIgnite();
        IgniteCache<RegionKey, List<RegionProjection>> projectionsStore = ProjectionsCache.getOrCreate(locIgnite);
        IgniteCache<FeatureKey, double[]> featuresStore = FeaturesCache.getOrCreate(locIgnite);
        Affinity<RegionKey> projectionsAffinity = locIgnite.affinity(ProjectionsCache.CACHE_NAME);
        ClusterNode thisNode = locIgnite.cluster().localNode();

        // Collected locally first so that each cache is updated with a single putAll.
        Map<FeatureKey, double[]> featuresBatch = new ConcurrentHashMap<>();
        Map<RegionKey, List<RegionProjection>> projectionsBatch = new ConcurrentHashMap<>();

        // Only keys that the affinity assigns to this node are handled; none assigned -> empty list.
        for (RegionKey regionKey : projectionsAffinity
            .mapKeysToNodes(IntStream.range(0, i.featuresCount())
                .mapToObj(idx -> ProjectionsCache.key(idx, 0, i.affinityKey(idx, locIgnite), trainingUUID))
                .collect(Collectors.toSet()))
            .getOrDefault(thisNode, Collections.emptyList())) {
            int featureIdx = regionKey.featureIdx();

            IgniteCache<UUID, TrainingContext<D>> ctxStore = ContextCache.getOrCreate(locIgnite);
            TrainingContext ctx = ctxStore.get(trainingUUID);

            double[] column = new double[ctx.labels().length];
            FeatureProcessor processor = ctx.featureProcessor(featureIdx);

            // Each tuple carries (position, value); place the value at its position.
            i.values(featureIdx).forEach(t -> column[t.get1()] = t.get2());
            featuresBatch.put(getFeatureCacheKey(featureIdx, trainingUUID, i.affinityKey(featureIdx, locIgnite)), column);

            List<RegionProjection> regions = new ArrayList<>(BLOCK_SIZE);
            regions.add(processor.createInitialRegion(getSamples(i.values(featureIdx), ctx.labels().length), column, ctx.labels()));
            projectionsBatch.put(regionKey, regions);
        }

        featuresStore.putAll(featuresBatch);
        projectionsStore.putAll(projectionsBatch);

        return null;
    });

    return doTrain(i, trainingUUID);
}
Also used: ClusterNode (org.apache.ignite.cluster.ClusterNode), UUID (java.util.UUID), ArrayList (java.util.ArrayList), RegionKey (org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey), FeatureProcessor (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor), Ignite (org.apache.ignite.Ignite), List (java.util.List), ArrayList (java.util.ArrayList), LinkedList (java.util.LinkedList), UUID (java.util.UUID), ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap), FeatureKey (org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey)

Aggregations

ArrayList (java.util.ArrayList): 1, LinkedList (java.util.LinkedList): 1, List (java.util.List): 1, UUID (java.util.UUID): 1, ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 1, Ignite (org.apache.ignite.Ignite): 1, ClusterNode (org.apache.ignite.cluster.ClusterNode): 1, FeatureKey (org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey): 1, RegionKey (org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey): 1, FeatureProcessor (org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor): 1