Usage of org.apache.ignite.ml.trees.nodes.SplitNode in the Apache Ignite project.
Class: ColumnDecisionTreeTrainer, method: doTrain.
/**
 * Grows a decision tree over the column-partitioned training data described by {@code input}.
 * <p>
 * Training starts with a single region (region 0) attached to a {@link RootNode}. Each iteration:
 * <ol>
 *     <li>asks {@code findBestSplitIndexForFeatures} for the globally best (feature, region, info gain);</li>
 *     <li>if the region index is valid and the gain exceeds {@code MIN_INFO_GAIN}, computes the split info and
 *     the ownership bit set of that region through affinity calls keyed by the best feature's affinity key;</li>
 *     <li>splits the region's projection for every feature: the left part replaces the region at its current
 *     index and the right part is appended as region {@code regsCnt - 1};</li>
 *     <li>refreshes the split cache for the affected regions.</li>
 * </ol>
 * When no split passes the threshold, every remaining tip of the tree is closed with a {@link Leaf} whose value
 * is {@code regCalc} applied to the labels of that region's samples. The projections, context, features and split
 * cache entries of this training session are cleared before returning.
 *
 * @param input Training input: number of features and affinity keys of the feature columns.
 * @param uuid Identifier of this training session; part of every cache key used here.
 * @return Trained decision tree model.
 */
@NotNull
private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) {
    RootNode root = new RootNode();

    // Setters of the current leaves of the tree. The position of a tip in this list is the index of the region
    // it corresponds to; tips are looked up by region index on every split, hence a random-access list.
    List<TreeTip> tips = new ArrayList<>();
    tips.add(new TreeTip(root::setSplit, 0));

    int curDepth = 0;
    int regsCnt = 1;

    int featuresCnt = input.featuresCount();

    IntStream.range(0, featuresCnt)
        .mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid))
        .forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0)));

    updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);

    // Repeatedly split the single globally best region until no region yields enough information gain.
    while (true) {
        long before = System.currentTimeMillis();

        IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b =
            findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid);

        long findBestRegIdx = System.currentTimeMillis() - before;

        Integer bestFeatureIdx = b.get1();
        Integer regIdx = b.get2().get1();
        Double bestInfoGain = b.get2().get2();

        if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) {
            before = System.currentTimeMillis();

            SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
                input.affinityKey(bestFeatureIdx, ignite),
                () -> {
                    // Resolve the node-local instance before touching any cache so that the closure does not
                    // reach for the trainer's 'ignite' field (consistent with the bit set closure below).
                    Ignite ignite = Ignition.localIgnite();
                    TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid);

                    RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE,
                        input.affinityKey(bestFeatureIdx, ignite), uuid);
                    RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);

                    return ctx.featureProcessor(bestFeatureIdx)
                        .findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx);
                });

            long findBestSplit = System.currentTimeMillis() - before;

            IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi);

            regsCnt++;

            if (log.isDebugEnabled())
                log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);

            // Request bitset for split region.
            int ind = best.info.regionIndex();

            SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
                input.affinityKey(bestFeatureIdx, ignite),
                () -> {
                    Ignite ignite = Ignition.localIgnite();
                    IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
                    IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
                    TrainingContext ctx = ctxCache.localPeek(uuid);

                    double[] values = featuresCache.localPeek(
                        getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite())));

                    RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE,
                        input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
                    RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);

                    return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info);
                });

            SplitNode sn = best.info.createSplitNode(best.featureIdx);

            // The tip being split now hangs the split node; it is reused for the left child and a new tip is
            // appended (at index regsCnt - 1, matching the new right region) for the right child.
            TreeTip tipToSplit = tips.get(ind);
            tipToSplit.leafSetter.accept(sn);
            tipToSplit.leafSetter = sn::setLeft;

            // Both children sit one level below the split node, so increment before reading.
            int d = ++tipToSplit.depth;
            tips.add(new TreeTip(sn::setRight, d));

            if (d > curDepth) {
                curDepth = d;

                if (log.isDebugEnabled()) {
                    log.debug("Depth: " + curDepth);
                    log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
                }
            }

            before = System.currentTimeMillis();

            // Keys of the block holding the split region, for every feature.
            IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt)
                .mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid))
                .collect(Collectors.toSet());

            int rc = regsCnt;

            // Perform the split on the projections of every feature.
            CacheUtils.update(prjsCache.getName(), ignite, (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> {
                RegionKey k = e.getKey();

                List<RegionProjection> leftBlock = e.getValue();

                int fIdx = k.featureIdx();
                int idxInBlock = ind % BLOCK_SIZE;

                IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign);
                TrainingContext<D> ctx = ctxCache.get(uuid);

                RegionProjection targetRegProj = leftBlock.get(idxInBlock);

                IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx.performSplit(input, bs, fIdx,
                    best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign);

                RegionProjection left = regs.get1();
                RegionProjection right = regs.get2();

                // Left part keeps the index of the split region.
                leftBlock.set(idxInBlock, left);

                // Right part becomes the newest region, index rc - 1.
                RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid);

                IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign);

                List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey);

                if (rightBlock == null) {
                    List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE);
                    newBlock.add(right);
                    return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock));
                }
                else {
                    rightBlock.add(right);

                    // When the right region lands in the same block as the split one, rightBlock is leftBlock and
                    // a single entry carries both changes (compare keys, not the list against a key).
                    return rightKey.equals(k)
                        ? Stream.of(new CacheEntryImpl<>(k, leftBlock))
                        : Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock));
                }
            }, bestRegsKeys);

            if (log.isDebugEnabled())
                log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));

            before = System.currentTimeMillis();

            updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);

            if (log.isDebugEnabled())
                log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
        }
        else {
            if (log.isDebugEnabled())
                log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");

            break;
        }
    }

    int rc = regsCnt;

    // Region blocks of feature 0; every feature holds the same samples per region, so one feature suffices
    // to compute the leaf values.
    IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> {
        IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite());

        return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1)
            .mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid))
            .map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>)new CacheEntryImpl<>(k, projsCache.localPeek(k)))
            .iterator();
    };

    // Region index -> leaf value computed by regCalc over the labels of the region's samples.
    Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite,
        (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> {
            int regBlockIdx = e.getKey().regionBlockIndex();

            if (e.getValue() != null) {
                for (int i = 0; i < e.getValue().size(); i++) {
                    int regIdx = regBlockIdx * BLOCK_SIZE + i;
                    RegionProjection reg = e.getValue().get(i);

                    Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s]));
                    m.put(regIdx, res);
                }
            }

            return m;
        },
        () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid),
        featZeroRegs,
        (infos, infos2) -> {
            Map<Integer, Double> res = new HashMap<>();
            res.putAll(infos);
            res.putAll(infos2);
            return res;
        },
        HashMap::new);

    // Tip at list position i corresponds to region i.
    int i = 0;
    for (TreeTip tip : tips) {
        tip.leafSetter.accept(new Leaf(vals.get(i)));
        i++;
    }

    ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite);
    ContextCache.getOrCreate(ignite).remove(uuid);
    FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
    SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite);

    return new DecisionTreeModel(root.s);
}
Aggregations