use of org.apache.ignite.ml.math.functions.IgniteSupplier in project ignite by apache.
the class CacheUtils method sparseFold.
/**
 * Sparse version of fold. This method is also applicable to sparse zeroes.
 *
 * @param cacheName Cache name.
 * @param folder Folding function applied to each cache entry.
 * @param keyFilter Key filter.
 * @param accumulator Accumulator combining partial results.
 * @param zeroValSupp Supplier of the zero (neutral) value.
 * @param defVal Default value of sparse entries.
 * @param defKey Default key.
 * @param defValCnt Count of entries holding the default value.
 * @param isNilpotent Flag indicating that folding a default entry does not change the result.
 * @return Fold result.
 */
private static <K, V, A> A sparseFold(String cacheName, IgniteBiFunction<Cache.Entry<K, V>, A, A> folder,
    IgnitePredicate<K> keyFilter, BinaryOperator<A> accumulator, IgniteSupplier<A> zeroValSupp,
    V defVal, K defKey, long defValCnt, boolean isNilpotent) {
A defRes = zeroValSupp.get();
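// Unless folding a default entry is a no-op (isNilpotent), pre-fold the default value defValCnt times.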
if (!isNilpotent)
    for (int i = 0; i < defValCnt; i++)
        defRes = folder.apply(new CacheEntryImpl<>(defKey, defVal), defRes);
Collection<A> totalRes = bcast(cacheName, () -> {
Ignite ignite = Ignition.localIgnite();
IgniteCache<K, V> cache = ignite.getOrCreateCache(cacheName);
int partsCnt = ignite.affinity(cacheName).partitions();
// Use affinity in the ScanQuery filter; otherwise every node would process every partition and entries would be consumed more than once.
Affinity affinity = ignite.affinity(cacheName);
ClusterNode locNode = ignite.cluster().localNode();
A a = zeroValSupp.get();
// Iterate over all partitions. Some of them will be stored on that local node.
for (int part = 0; part < partsCnt; part++) {
int p = part;
// Query returns an empty cursor if this partition is not stored on this node.
for (Cache.Entry<K, V> entry : cache.query(new ScanQuery<K, V>(part,
    (k, v) -> affinity.mapPartitionToNode(p) == locNode && (keyFilter == null || keyFilter.apply(k)))))
    a = folder.apply(entry, a);
}
return a;
});
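// Reduce per-node partial results, starting from the pre-folded default part.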
return totalRes.stream().reduce(defRes, accumulator);
}
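For intuition, here is a minimal, self-contained sketch of the default-value pre-fold above (plain Java, no Ignite; all names are illustrative). For a counting fold every default entry still contributes, while for a sum over sparse zeroes the defaults are no-ops and can be skipped via isNilpotent:

import java.util.function.BinaryOperator;

public class SparseFoldSketch {
    // Fold the default value defValCnt times, mirroring the loop in sparseFold.
    static double foldDefaults(double zero, double defVal, long defValCnt,
        BinaryOperator<Double> folder, boolean isNilpotent) {
        double res = zero;
        if (!isNilpotent)
            for (long i = 0; i < defValCnt; i++)
                res = folder.apply(defVal, res);
        return res;
    }

    public static void main(String[] args) {
        // Summing sparse zeroes: defaults are nilpotent for '+', so they are skipped.
        System.out.println(foldDefaults(0.0, 0.0, 5, Double::sum, true)); // 0.0
        // Counting entries: every default entry still counts, so defaults are folded.
        System.out.println(foldDefaults(0.0, 0.0, 5, (v, acc) -> acc + 1, false)); // 5.0
    }
}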
use of org.apache.ignite.ml.math.functions.IgniteSupplier in project ignite by apache.
the class DistributedWorkersChainTest method testDistributed.
/**
 * Tests a distributed chain of computations.
 */
public void testDistributed() {
ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
int init = 1;
UUID trainingUUID = UUID.randomUUID();
TestLocalContext locCtx = new TestLocalContext(0, trainingUUID);
Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>();
m.put(new GroupTrainerCacheKey<>(0L, 1.0, trainingUUID), 1);
m.put(new GroupTrainerCacheKey<>(1L, 2.0, trainingUUID), 2);
m.put(new GroupTrainerCacheKey<>(2L, 3.0, trainingUUID), 3);
m.put(new GroupTrainerCacheKey<>(3L, 4.0, trainingUUID), 4);
Stream<GroupTrainerCacheKey<Double>> keys = m.keySet().stream();
cache.putAll(m);
IgniteBiFunction<Integer, TestLocalContext, IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>>> function =
    (o, l) -> () -> keys;
IgniteFunction<List<Integer>, Integer> max =
    ints -> ints.stream().mapToInt(x -> x).max().orElse(Integer.MIN_VALUE);
Integer res = chain
    .thenDistributedForEntries((integer, context) -> () -> null, this::readAndIncrement, function, max)
    .process(init, new GroupTrainingContext<>(locCtx, cache, ignite));
int localMax = m.values().stream().max(Comparator.comparingInt(i -> i)).orElse(Integer.MIN_VALUE);
assertEquals((long) localMax, (long) res);
for (GroupTrainerCacheKey<Double> key : m.keySet())
    m.compute(key, (k, v) -> v + 1);
assertMapEqualsCache(m, cache);
}
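readAndIncrement is defined elsewhere in the test class; judging by the assertions above, it reads each entry and writes back the value plus one. A hypothetical sketch of such a worker (the real method in DistributedWorkersChainTest operates on the group-trainer wrapper types, so its signature differs):

// Hypothetical sketch, not the actual signature used in the test.
private Integer readAndIncrement(Cache.Entry<GroupTrainerCacheKey<Double>, Integer> e) {
    int incremented = e.getValue() + 1;
    TestGroupTrainingCache.getOrCreate(ignite).put(e.getKey(), incremented);
    return incremented;
}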
use of org.apache.ignite.ml.math.functions.IgniteSupplier in project ignite by apache.
the class ColumnDecisionTreeTrainer method doTrain.
/**
 * Trains a decision tree model on the given input.
 *
 * @param input Training input.
 * @param uuid UUID of the training session.
 * @return Trained decision tree model.
 */
@NotNull
private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) {
RootNode root = new RootNode();
// List containing setters of leaves of the tree.
List<TreeTip> tips = new LinkedList<>();
tips.add(new TreeTip(root::setSplit, 0));
int curDepth = 0;
int regsCnt = 1;
int featuresCnt = input.featuresCount();
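// Seed the split cache with a zero entry for every feature of the input.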
IntStream.range(0, featuresCnt)
    .mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid))
    .forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0)));
updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
// Continue until regions cannot be split any more; only regions that can still be split are considered.
while (true) {
long before = System.currentTimeMillis();
IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid);
long findBestRegIdx = System.currentTimeMillis() - before;
Integer bestFeatureIdx = b.get1();
Integer regIdx = b.get2().get1();
Double bestInfoGain = b.get2().get2();
if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) {
before = System.currentTimeMillis();
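// Calculate the best split of the winning feature on the node that owns its projections.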
SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
    input.affinityKey(bestFeatureIdx, ignite), () -> {
TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid);
Ignite ignite = Ignition.localIgnite();
RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE,
    input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
return ctx.featureProcessor(bestFeatureIdx)
    .findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx);
});
long findBestSplit = System.currentTimeMillis() - before;
IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi);
regsCnt++;
if (log.isDebugEnabled())
log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
// Request bitset for split region.
int ind = best.info.regionIndex();
SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
    input.affinityKey(bestFeatureIdx, ignite), () -> {
Ignite ignite = Ignition.localIgnite();
IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
TrainingContext ctx = ctxCache.localPeek(uuid);
double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid,
    input.affinityKey(bestFeatureIdx, Ignition.localIgnite())));
RegionKey key = ProjectionsCache.key(bestFeatureIdx, regIdx / BLOCK_SIZE,
    input.affinityKey(bestFeatureIdx, Ignition.localIgnite()), uuid);
RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info);
});
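// Wire the new split node into the tree: the tip that was split now tracks the left child, and a new tip is added for the right child.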
SplitNode sn = best.info.createSplitNode(best.featureIdx);
TreeTip tipToSplit = tips.get(ind);
tipToSplit.leafSetter.accept(sn);
tipToSplit.leafSetter = sn::setLeft;
int d = tipToSplit.depth++;
tips.add(new TreeTip(sn::setRight, d));
if (d > curDepth) {
curDepth = d;
if (log.isDebugEnabled()) {
log.debug("Depth: " + curDepth);
log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
}
}
before = System.currentTimeMillis();
// Perform split on all feature vectors.
IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt)
    .mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE,
        input.affinityKey(fIdx, Ignition.localIgnite()), uuid))
    .collect(Collectors.toSet());
int rc = regsCnt;
// Perform split.
CacheUtils.update(prjsCache.getName(), ignite,
    (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> {
RegionKey k = e.getKey();
List<RegionProjection> leftBlock = e.getValue();
int fIdx = k.featureIdx();
int idxInBlock = ind % BLOCK_SIZE;
IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign);
TrainingContext<D> ctx = ctxCache.get(uuid);
RegionProjection targetRegProj = leftBlock.get(idxInBlock);
IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx.performSplit(input, bs, fIdx,
    best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign);
RegionProjection left = regs.get1();
RegionProjection right = regs.get2();
leftBlock.set(idxInBlock, left);
RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid);
IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign);
List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey);
if (rightBlock == null) {
List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE);
newBlock.add(right);
return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock));
} else {
rightBlock.add(right);
return rightKey.equals(k) ? Stream.of(new CacheEntryImpl<>(k, leftBlock)) :
    Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock));
}
}, bestRegsKeys);
if (log.isDebugEnabled())
log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));
before = System.currentTimeMillis();
updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
if (log.isDebugEnabled())
log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
} else {
if (log.isDebugEnabled())
log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");
break;
}
}
int rc = regsCnt;
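// Every feature shares the same region structure, so the projections of feature 0 are enough to enumerate all regions.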
IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> {
IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite());
return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1)
    .mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid))
    .map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>)new CacheEntryImpl<>(k, projsCache.localPeek(k)))
    .iterator();
};
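// Compute the prediction of each region by applying regCalc to the labels of the region's samples.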
Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite,
    (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> {
int regBlockIdx = e.getKey().regionBlockIndex();
if (e.getValue() != null) {
for (int i = 0; i < e.getValue().size(); i++) {
int regIdx = regBlockIdx * BLOCK_SIZE + i;
RegionProjection reg = e.getValue().get(i);
Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s]));
m.put(regIdx, res);
}
}
return m;
}, () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid), featZeroRegs, (infos, infos2) -> {
Map<Integer, Double> res = new HashMap<>();
res.putAll(infos);
res.putAll(infos2);
return res;
}, HashMap::new);
int i = 0;
for (TreeTip tip : tips) {
tip.leafSetter.accept(new Leaf(vals.get(i)));
i++;
}
ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite);
ContextCache.getOrCreate(ignite).remove(uuid);
FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
return new DecisionTreeModel(root.s);
}
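doTrain is reached through the trainer's public train method. For context, a hedged usage sketch based on the ignite-ml decision-tree example of this era (the calculator names and MatrixColumnDecisionTreeTrainerInput are assumptions and may differ between versions):

// Hedged sketch: calculator names and the input wrapper are assumptions.
ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
    new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite),
        RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);

// 'data' is a matrix whose last column holds labels; the empty map marks no categorical features.
DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(data, new HashMap<>()));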
use of org.apache.ignite.ml.math.functions.IgniteSupplier in project ignite by apache.
the class CacheUtils method update.
/**
 * @param cacheName Cache name.
 * @param ignite Ignite instance.
 * @param fun Operation that accepts a cache entry and returns a stream of entries to write back.
 * @param keysGen Supplier of the keys to process.
 * @param <K> Cache key object type.
 * @param <V> Cache value object type.
 */
public static <K, V> void update(String cacheName, Ignite ignite,
    IgniteBiFunction<Ignite, Cache.Entry<K, V>, Stream<Cache.Entry<K, V>>> fun, IgniteSupplier<Set<K>> keysGen) {
bcast(cacheName, ignite, () -> {
Ignite ig = Ignition.localIgnite();
IgniteCache<K, V> cache = ig.getOrCreateCache(cacheName);
Affinity<K> affinity = ig.affinity(cacheName);
ClusterNode locNode = ig.cluster().localNode();
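// Process only the keys for which this node is primary.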
Collection<K> ks = affinity.mapKeysToNodes(keysGen.get()).get(locNode);
if (ks == null)
return;
Map<K, V> m = new ConcurrentHashMap<>();
ks.parallelStream().forEach(k -> {
V v = cache.localPeek(k);
if (v != null)
    fun.apply(ignite, new CacheEntryImpl<>(k, v))
        .forEach(ent -> m.put(ent.getKey(), ent.getValue()));
});
cache.putAll(m);
});
}
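A minimal usage sketch of update: increment a known set of Integer counters in place (the cache name and key set are illustrative):

// Illustrative cache name and keys; CacheEntryImpl is the same internal entry class used by update itself.
Set<String> keys = new HashSet<>(Arrays.asList("a", "b", "c"));

CacheUtils.update("counters", ignite,
    (Ignite ign, Cache.Entry<String, Integer> e) ->
        Stream.of(new CacheEntryImpl<>(e.getKey(), e.getValue() + 1)),
    () -> keys);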