Example use of org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor in the Apache Ignite project.
The example is the train method of the class ColumnDecisionTreeTrainer.
/**
 * {@inheritDoc}
 *
 * <p>Stores a new {@code TrainingContext} in the context cache under a freshly generated training
 * UUID, then passes a closure to {@code CacheUtils.bcast} which fills the features and projections
 * caches for the region keys that the projections-cache affinity maps to the executing node, and
 * finally delegates to {@code doTrain}.
 */
@Override
public DecisionTreeModel train(ColumnDecisionTreeTrainerInput i) {
    prjsCache = ProjectionsCache.getOrCreate(ignite);
    IgniteCache<UUID, TrainingContext<D>> contexts = ContextCache.getOrCreate(ignite);
    // The returned cache reference is not needed here; only the call itself matters.
    SplitCache.getOrCreate(ignite);

    UUID trainingUUID = UUID.randomUUID();
    TrainingContext<D> trainingCtx = new TrainingContext<>(
        i,
        continuousCalculatorProvider.apply(i),
        categoricalCalculatorProvider.apply(i),
        trainingUUID,
        ignite);
    contexts.put(trainingUUID, trainingCtx);

    CacheUtils.bcast(prjsCache.getName(), ignite, () -> {
        // Inside the closure the Ignite instance is looked up locally instead of using the trainer's field.
        Ignite locIgnite = Ignition.localIgnite();

        IgniteCache<RegionKey, List<RegionProjection>> regionsCache = ProjectionsCache.getOrCreate(locIgnite);
        IgniteCache<FeatureKey, double[]> featsCache = FeaturesCache.getOrCreate(locIgnite);
        Affinity<RegionKey> regionAffinity = locIgnite.affinity(ProjectionsCache.CACHE_NAME);
        ClusterNode thisNode = locIgnite.cluster().localNode();

        // Entries are collected locally first and written to the caches in bulk afterwards.
        Map<FeatureKey, double[]> featureBatch = new ConcurrentHashMap<>();
        Map<RegionKey, List<RegionProjection>> regionBatch = new ConcurrentHashMap<>();

        // One initial-region key per feature; only the keys mapped to this node are processed.
        for (RegionKey regionKey : regionAffinity
            .mapKeysToNodes(
                IntStream.range(0, i.featuresCount())
                    .mapToObj(idx -> ProjectionsCache.key(idx, 0, i.affinityKey(idx, locIgnite), trainingUUID))
                    .collect(Collectors.toSet()))
            .getOrDefault(thisNode, Collections.emptyList())) {
            int featureIdx = regionKey.featureIdx();

            IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(locIgnite);
            TrainingContext ctx = ctxCache.get(trainingUUID);

            // Dense array of this feature's values, one slot per label.
            double[] featureVals = new double[ctx.labels().length];
            FeatureProcessor processor = ctx.featureProcessor(featureIdx);

            i.values(featureIdx).forEach(t -> featureVals[t.get1()] = t.get2());

            featureBatch.put(
                getFeatureCacheKey(featureIdx, trainingUUID, i.affinityKey(featureIdx, locIgnite)),
                featureVals);

            List<RegionProjection> regions = new ArrayList<>(BLOCK_SIZE);
            regions.add(processor.createInitialRegion(
                getSamples(i.values(featureIdx), ctx.labels().length),
                featureVals,
                ctx.labels()));

            regionBatch.put(regionKey, regions);
        }

        featsCache.putAll(featureBatch);
        regionsCache.putAll(regionBatch);

        return null;
    });

    return doTrain(i, trainingUUID);
}
Aggregations