Search in sources:

Example 1 with IgniteFunction

use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.

the class GroupTrainer method train.

/**
 * {@inheritDoc}
 *
 * <p>Creates the local context and the group training context for a freshly generated training UUID,
 * calls {@code init}, and then runs the computation chain: distributed initialization over the initial
 * keys, local processing of the init data, the training loop (while {@code shouldContinue} holds),
 * distributed extraction of the final results and their local mapping to the resulting model.
 *
 * <p>{@code cleanup} is invoked for the local context once {@code init} has completed, regardless of
 * whether the chain finishes normally or throws, so a failed training run does not skip cleanup.
 *
 * @param data Training input.
 * @return Model produced by the last (local) step of the chain.
 */
@Override
public final M train(T data) {
    UUID trainingUUID = UUID.randomUUID();
    LC locCtx = initialLocalContext(data, trainingUUID);
    GroupTrainingContext<K, V, LC> ctx = new GroupTrainingContext<>(locCtx, cache, ignite);

    // Identity chain used as the starting point for the steps appended below.
    ComputationsChain<LC, K, V, T, T> chain = (i, c) -> i;
    IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<IN>> distributedInitializer = distributedInitializer(data);

    init(data, trainingUUID);

    try {
        M res = chain
            .thenDistributedForKeys(distributedInitializer, (t, lc) -> data.initialKeys(trainingUUID), reduceDistributedInitData())
            .thenLocally(this::locallyProcessInitData)
            .thenWhile(this::shouldContinue, trainingLoopStep())
            .thenDistributedForEntries(this::extractContextForFinalResultCreation, finalResultsExtractor(), this::finalResultKeys, finalResultsReducer())
            .thenLocally(this::mapFinalResult)
            .process(data, ctx);

        return res;
    }
    finally {
        // Run cleanup even when a chain step fails; previously an exception bypassed it.
        cleanup(locCtx);
    }
}
Also used : IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Model(org.apache.ignite.ml.Model) ComputationsChain(org.apache.ignite.ml.trainers.group.chain.ComputationsChain) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) IgniteCache(org.apache.ignite.IgniteCache) Serializable(java.io.Serializable) Trainer(org.apache.ignite.ml.trainers.Trainer) EntryAndContext(org.apache.ignite.ml.trainers.group.chain.EntryAndContext) List(java.util.List) Stream(java.util.stream.Stream) HasTrainingUUID(org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID) UUID(java.util.UUID) HasTrainingUUID(org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID)

Example 2 with IgniteFunction

use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.

the class SplitDataGenerator method testByGen.

/**
 * Generates points, loads them in shuffled order into a column-stored sparse distributed matrix,
 * trains a {@link ColumnDecisionTreeTrainer} on that matrix and applies the resulting model to the
 * first sample collected for every region.
 *
 * @param totalPts Value passed to {@code points}; also used as the row count of the training matrix.
 * @param calc Continuous split calculator factory handed to the trainer.
 * @param catImpCalc Factory of the calculator for categorical features handed to the trainer.
 * @param regCalc Region value calculator handed to the trainer.
 * @param ignite Ignite instance handed to the trainer.
 */
<D extends ContinuousRegionInfo> void testByGen(int totalPts, IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc, IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc, IgniteFunction<DoubleStream, Double> regCalc, Ignite ignite) {
    List<IgniteBiTuple<Integer, V>> shuffled = points(totalPts, (i, rn) -> i).collect(Collectors.toList());
    Collections.shuffle(shuffled, rnd);

    // One extra column on top of the features (presumably the label; confirm against the generator).
    SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();

    int rowIdx = 0;
    for (IgniteBiTuple<Integer, V> pt : shuffled) {
        V vec = pt.get2();

        // Group samples by the first tuple component and mirror each sample into the matrix row.
        byRegion.computeIfAbsent(pt.get1(), region -> new LinkedList<>()).add(asLabeledVector(vec.getStorage().data()));
        m.setRow(rowIdx, vec.getStorage().data());
        rowIdx++;
    }

    // NOTE(review): the meaning of the literal 3 is defined by the trainer's constructor; confirm there.
    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);
    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catFeaturesInfo));

    // The model is only applied to the first sample of each region; the predictions are not checked here.
    for (List<LabeledVectorDouble> samples : byRegion.values())
        mdl.apply(samples.get(0).features());
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) BiFunction(java.util.function.BiFunction) ColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput) HashMap(java.util.HashMap) Random(java.util.Random) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) Function(java.util.function.Function) Supplier(java.util.function.Supplier) Vector(org.apache.ignite.ml.math.Vector) Map(java.util.Map) LinkedList(java.util.LinkedList) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) Ignite(org.apache.ignite.Ignite) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) Stream(java.util.stream.Stream) MathIllegalArgumentException(org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException) Utils(org.apache.ignite.ml.util.Utils) ContinuousSplitCalculator(org.apache.ignite.ml.trees.ContinuousSplitCalculator) BitSet(java.util.BitSet) StorageConstants(org.apache.ignite.ml.math.StorageConstants) ContinuousRegionInfo(org.apache.ignite.ml.trees.ContinuousRegionInfo) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) HashMap(java.util.HashMap) 
MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) LinkedList(java.util.LinkedList) List(java.util.List) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Example 3 with IgniteFunction

use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.

the class GradientDescent method getLossGradientFunction.

/**
 * Curries the loss gradient function by fixing the data matrix, leaving the weights vector as the
 * only argument.
 *
 * <p>For a {@link SparseDistributedMatrix} stored in row mode the returned function delegates to
 * {@code calculateDistributedGradient}. In every other case the inputs and the ground truth are
 * extracted from the matrix once and the returned function calls {@code lossGradient.compute} on them.
 *
 * @param data Data matrix to bind into the gradient function.
 * @return Function mapping a weights vector to the loss gradient.
 */
private IgniteFunction<Vector, Vector> getLossGradientFunction(Matrix data) {
    boolean rowStoredDistributed = data instanceof SparseDistributedMatrix
        && ((SparseDistributedMatrix) data).getStorage().storageMode() == StorageConstants.ROW_STORAGE_MODE;

    if (rowStoredDistributed) {
        SparseDistributedMatrix distributed = (SparseDistributedMatrix) data;

        return w -> calculateDistributedGradient(distributed, w);
    }

    // Extract once, outside of the lambda, so every gradient evaluation reuses the same objects.
    Matrix inputs = extractInputs(data);
    Vector groundTruth = extractGroundTruth(data);

    return w -> lossGradient.compute(inputs, groundTruth, w);
}
Also used : FunctionVector(org.apache.ignite.ml.math.impls.vector.FunctionVector) SparseDistributedMatrixMapReducer(org.apache.ignite.ml.optimization.util.SparseDistributedMatrixMapReducer) Vector(org.apache.ignite.ml.math.Vector) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Matrix(org.apache.ignite.ml.math.Matrix) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) StorageConstants(org.apache.ignite.ml.math.StorageConstants) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) Matrix(org.apache.ignite.ml.math.Matrix) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) FunctionVector(org.apache.ignite.ml.math.impls.vector.FunctionVector) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)

Example 4 with IgniteFunction

use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.

the class DistributedWorkersChainTest method testDistributed.

/**
 * Runs a chain consisting of a single distributed-for-entries step over four cache entries and
 * checks that the reduced result equals the maximum of the original values and that the cache
 * afterwards matches the original map with every value incremented by one.
 */
public void testDistributed() {
    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
    int init = 1;
    UUID trainingUUID = UUID.randomUUID();
    TestLocalContext locCtx = new TestLocalContext(0, trainingUUID);

    Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>();
    m.put(new GroupTrainerCacheKey<>(0L, 1.0, trainingUUID), 1);
    m.put(new GroupTrainerCacheKey<>(1L, 2.0, trainingUUID), 2);
    m.put(new GroupTrainerCacheKey<>(2L, 3.0, trainingUUID), 3);
    m.put(new GroupTrainerCacheKey<>(3L, 4.0, trainingUUID), 4);
    cache.putAll(m);

    // Build the key stream inside the supplier: a Stream can be consumed only once, so handing out a
    // pre-built instance would fail if the supplier were invoked more than once.
    IgniteBiFunction<Integer, TestLocalContext, IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>>> function = (o, l) -> () -> m.keySet().stream();
    IgniteFunction<List<Integer>, Integer> max = ints -> ints.stream().mapToInt(x -> x).max().orElse(Integer.MIN_VALUE);

    Integer res = chain
        .thenDistributedForEntries((integer, context) -> () -> null, this::readAndIncrement, function, max)
        .process(init, new GroupTrainingContext<>(locCtx, cache, ignite));

    int localMax = m.values().stream().max(Comparator.comparingInt(i -> i)).orElse(Integer.MIN_VALUE);
    assertEquals((long) localMax, (long) res);

    // Mirror the increment expected from the distributed step in the local map.
    m.replaceAll((k, v) -> v + 1);
    assertMapEqualsCache(m, cache);
}
Also used : GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Chains(org.apache.ignite.ml.trainers.group.chain.Chains) ComputationsChain(org.apache.ignite.ml.trainers.group.chain.ComputationsChain) HashMap(java.util.HashMap) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) IgniteCache(org.apache.ignite.IgniteCache) EntryAndContext(org.apache.ignite.ml.trainers.group.chain.EntryAndContext) List(java.util.List) Stream(java.util.stream.Stream) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) Map(java.util.Map) Comparator(java.util.Comparator) Assert(org.junit.Assert) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) HashMap(java.util.HashMap) List(java.util.List) UUID(java.util.UUID)

Example 5 with IgniteFunction

use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.

the class ColumnDecisionTreeTrainerBenchmark method testByGenStreamerLoad.

/**
 * Loads generated vectors into the cache backing a column-stored sparse distributed matrix, trains
 * a {@link ColumnDecisionTreeTrainer} with variance-based calculators on it, and for the first
 * sample of every region prints the prediction and asserts that it equals the sample label.
 * Loading and training durations are printed in milliseconds.
 *
 * @param ptsPerReg Value passed to {@code gen.points}; multiplied by the regions count to size the matrix.
 * @param catsInfo Categorical features info handed to the trainer input.
 * @param gen Generator of the data points.
 * @param rnd Random used to shuffle the generated points.
 */
private void testByGenStreamerLoad(int ptsPerReg, HashMap<Integer, Integer> catsInfo, SplitDataGenerator<DenseLocalOnHeapVector> gen, Random rnd) {
    List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> shuffled = gen.points(ptsPerReg, (i, rn) -> i).collect(Collectors.toList());
    int featCnt = gen.featuresCnt();
    Collections.shuffle(shuffled, rnd);
    int numRegs = gen.regsCount();

    SparseDistributedMatrix m = new SparseDistributedMatrix(numRegs * ptsPerReg, featCnt + 1, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
    IgniteFunction<DoubleStream, Double> regCalc = vals -> vals.average().orElse(0.0);
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();
    SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage) m.getStorage();

    long loadStart = System.currentTimeMillis();
    X.println("Batch loading started...");

    // NOTE(review): the vectors loaded here come from a second gen.points call, not from the shuffled
    // list used for byRegion below; confirm that both calls yield the same points.
    loadVectorsIntoSparseDistributedMatrixCache(
        sto.cache().getName(),
        sto.getUUID(),
        gen.points(ptsPerReg, (i, rn) -> i).map(IgniteBiTuple::get2).iterator(),
        featCnt + 1);

    X.println("Batch loading took " + (System.currentTimeMillis() - loadStart) + " ms.");

    for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> pt : shuffled)
        byRegion.computeIfAbsent(pt.get1(), region -> new LinkedList<>()).add(asLabeledVector(pt.get2().getStorage().data()));

    ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = new ColumnDecisionTreeTrainer<>(2, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite);

    long trainStart = System.currentTimeMillis();
    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));
    X.println("Training took: " + (System.currentTimeMillis() - trainStart) + " ms.");

    for (List<LabeledVectorDouble> samples : byRegion.values()) {
        LabeledVectorDouble sp = samples.get(0);

        Tracer.showAscii(sp.features());
        X.println("Predicted value and label [pred=" + mdl.apply(sp.features()) + ", label=" + sp.doubleLabel() + "]");

        // Plain Java assert: only evaluated when the JVM runs with assertions enabled.
        assert mdl.apply(sp.features()) == sp.doubleLabel();
    }
}
Also used : CacheAtomicityMode(org.apache.ignite.cache.CacheAtomicityMode) Arrays(java.util.Arrays) FeaturesCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache) IgniteTestResources(org.apache.ignite.testframework.junits.IgniteTestResources) Random(java.util.Random) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) VarianceSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.VarianceSplitCalculator) Vector(org.apache.ignite.ml.math.Vector) Estimators(org.apache.ignite.ml.estimators.Estimators) Map(java.util.Map) X(org.apache.ignite.internal.util.typedef.X) Level(org.apache.log4j.Level) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) BaseDecisionTreeTest(org.apache.ignite.ml.trees.BaseDecisionTreeTest) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) ProjectionsCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache) UUID(java.util.UUID) StreamTransformer(org.apache.ignite.stream.StreamTransformer) Collectors(java.util.stream.Collectors) IgniteCache(org.apache.ignite.IgniteCache) ContextCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache) DoubleStream(java.util.stream.DoubleStream) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Stream(java.util.stream.Stream) SparseMatrixKey(org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey) SplitCache(org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache) 
RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) IntStream(java.util.stream.IntStream) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Model(org.apache.ignite.ml.Model) HashMap(java.util.HashMap) Function(java.util.function.Function) GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) CacheWriteSynchronizationMode(org.apache.ignite.cache.CacheWriteSynchronizationMode) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) MnistUtils(org.apache.ignite.ml.util.MnistUtils) LinkedList(java.util.LinkedList) Properties(java.util.Properties) Iterator(java.util.Iterator) ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) IOException(java.io.IOException) SplitDataGenerator(org.apache.ignite.ml.trees.SplitDataGenerator) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) IgniteDataStreamer(org.apache.ignite.IgniteDataStreamer) Tracer(org.apache.ignite.ml.math.Tracer) StorageConstants(org.apache.ignite.ml.math.StorageConstants) Assert(org.junit.Assert) Collections(java.util.Collections) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) GridCacheProcessor(org.apache.ignite.internal.processors.cache.GridCacheProcessor) InputStream(java.io.InputStream) CacheMode(org.apache.ignite.cache.CacheMode) SparseDistributedMatrix(org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) 
HashMap(java.util.HashMap) Int2DoubleOpenHashMap(it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap) MatrixColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.MatrixColumnDecisionTreeTrainerInput) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) SparseDistributedMatrixStorage(org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage) LabeledVectorDouble(org.apache.ignite.ml.structures.LabeledVectorDouble) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) LinkedList(java.util.LinkedList) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer)

Aggregations

IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction)15 IgniteCache (org.apache.ignite.IgniteCache)8 Arrays (java.util.Arrays)7 List (java.util.List)7 Ignite (org.apache.ignite.Ignite)7 Ignition (org.apache.ignite.Ignition)7 Random (java.util.Random)6 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)6 HashMap (java.util.HashMap)5 Map (java.util.Map)5 UUID (java.util.UUID)5 DoubleStream (java.util.stream.DoubleStream)5 StorageConstants (org.apache.ignite.ml.math.StorageConstants)5 SparseDistributedMatrix (org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix)5 DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)5 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)5 IOException (java.io.IOException)4 Serializable (java.io.Serializable)4 Collections (java.util.Collections)4 LinkedList (java.util.LinkedList)4