Search in sources:

Example 6 with LearningEnvironment

Use of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.

Class LocalDataset, method compute.

/**
 * {@inheritDoc}
 *
 * <p>Partitions are visited in index order. {@code null} partitions are skipped; every other partition is
 * mapped together with the learning environment stored at the same index, and the partial result is folded
 * into the accumulator with {@code reduce}, starting from {@code identity}.
 */
@Override
public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) {
    R acc = identity;
    for (int idx = 0; idx < data.size(); idx++) {
        D partition = data.get(idx);
        // Looked up before the null check, exactly like the partition itself.
        LearningEnvironment partEnv = envs.get(idx);
        if (partition == null)
            continue;
        acc = reduce.apply(acc, map.apply(partition, partEnv));
    }
    return acc;
}
Also used: LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment)

Example 7 with LearningEnvironment

Use of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.

Class LocalDatasetBuilder, method build.

/**
 * {@inheritDoc}
 *
 * <p>Builds an in-memory dataset. Upstream entries accepted by {@code filter} are collected into a list.
 * The list is cut into {@code partitions} consecutive chunks of {@code partSize} entries each, and the last
 * chunk takes whatever remains. For each chunk:
 * <ul>
 *   <li>a worker learning environment is built for the partition index;</li>
 *   <li>the upstream transformer is applied;</li>
 *   <li>the partition context and partition data are built from the transformed entries.</li>
 * </ul>
 * Chunks with no entries get {@code null} context and {@code null} data.
 *
 * <p>NOTE(review): the {@code learningEnvironment} parameter is not used by this implementation.
 */
@Override
public <C extends Serializable, D extends AutoCloseable> LocalDataset<C, D> build(LearningEnvironmentBuilder envBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder, LearningEnvironment learningEnvironment) {
    List<C> ctxList = new ArrayList<>();
    List<D> dataList = new ArrayList<>();
    List<UpstreamEntry<K, V>> entriesList = new ArrayList<>();
    // Snapshot of the filtered upstream, so it can be walked several times.
    upstreamMap.entrySet().stream().filter(en -> filter.apply(en.getKey(), en.getValue())).map(en -> new UpstreamEntry<>(en.getKey(), en.getValue())).forEach(entriesList::add);
    // At least one entry per partition: when there are fewer entries than partitions, the leading
    // partitions get one entry each and the trailing ones get none.
    int partSize = Math.max(1, entriesList.size() / partitions);
    // Three independent cursors over the same entries, one per pass below:
    // - thirdKeysIter counts the transformed entries;
    // - firstKeysIter feeds the context builder;
    // - secondKeysIter feeds the data builder.
    // NOTE(review): partitions stay aligned only if each pass consumes exactly cntBeforeTransform entries
    // from its cursor (assumed to be IteratorWindow's behaviour) — confirm.
    Iterator<UpstreamEntry<K, V>> firstKeysIter = entriesList.iterator();
    Iterator<UpstreamEntry<K, V>> secondKeysIter = entriesList.iterator();
    Iterator<UpstreamEntry<K, V>> thirdKeysIter = entriesList.iterator();
    // Index of the first entry that has not yet been assigned to a partition.
    int ptr = 0;
    // One worker environment per partition index.
    List<LearningEnvironment> envs = IntStream.range(0, partitions).boxed().map(envBuilder::buildForWorker).collect(Collectors.toList());
    for (int part = 0; part < partitions; part++) {
        // The last partition absorbs all remaining entries; the others take up to partSize.
        int cntBeforeTransform = part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
        LearningEnvironment env = envs.get(part);
        // Separate transformer instances (copies of the same one) for each of the three passes.
        // Presumably this is so that a stateful transformer yields the same output on every pass — confirm.
        UpstreamTransformer transformer1 = upstreamTransformerBuilder.build(env);
        UpstreamTransformer transformer2 = Utils.copy(transformer1);
        UpstreamTransformer transformer3 = Utils.copy(transformer1);
        // Entry count after transformation; this value (not cntBeforeTransform) is handed to both builders.
        int cnt = (int) transformer1.transform(Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count();
        Iterator<UpstreamEntry> iter = transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform)).map(x -> (UpstreamEntry) x)).iterator();
        Iterator<UpstreamEntry<K, V>> convertedBack = Utils.asStream(iter).map(x -> (UpstreamEntry<K, V>) x).iterator();
        // Empty partitions get no context.
        C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, convertedBack, cnt) : null;
        Iterator<UpstreamEntry> iter1 = transformer3.transform(Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator();
        Iterator<UpstreamEntry<K, V>> convertedBack1 = Utils.asStream(iter1).map(x -> (UpstreamEntry<K, V>) x).iterator();
        // Empty partitions get no data.
        D data = cntBeforeTransform > 0 ? partDataBuilder.build(env, convertedBack1, cnt, ctx) : null;
        ctxList.add(ctx);
        dataList.add(data);
        ptr += cntBeforeTransform;
    }
    return new LocalDataset<>(envs, ctxList, dataList);
}
Also used: IntStream(java.util.stream.IntStream) UpstreamTransformer(org.apache.ignite.ml.dataset.UpstreamTransformer) IgniteBiPredicate(org.apache.ignite.lang.IgniteBiPredicate) Iterator(java.util.Iterator) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) ArrayList(java.util.ArrayList) List(java.util.List) Utils(org.apache.ignite.ml.util.Utils) PartitionContextBuilder(org.apache.ignite.ml.dataset.PartitionContextBuilder) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) Map(java.util.Map) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) UpstreamTransformerBuilder(org.apache.ignite.ml.dataset.UpstreamTransformerBuilder) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry)

Example 8 with LearningEnvironment

Use of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.

Class GDBOnTreesLearningStrategy, method update.

/**
 * {@inheritDoc}
 *
 * <p>Runs up to {@code cntOfIterations} boosting iterations on decision trees. Each iteration:
 * <ol>
 *   <li>builds the current composition from the models so far, weighted by {@code compositionWeights} and
 *       {@code meanLbVal}, and stops early if the convergence checker reports convergence;</li>
 *   <li>otherwise, in every partition, saves a copy of the original labels (first time only) and replaces
 *       each label with the negated loss gradient at the composition's current prediction;</li>
 *   <li>fits a new decision tree on the updated dataset and appends it to the model list.</li>
 * </ol>
 * On return, {@code compositionWeights} has been resized to the final number of models.
 *
 * @param mdlToUpdate Model whose learning state is restored via {@code initLearningState}.
 * @param datasetBuilder Builder of the training dataset (also handed to the convergence checker).
 * @param vectorizer Upstream vectorizer used to build decision tree data.
 * @return The list of models in the composition, including the newly trained trees.
 * @throws RuntimeException wrapping any checked exception raised while building, using or closing the
 *     dataset; unchecked exceptions are propagated unchanged.
 */
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);
    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    // Trees are fitted directly on the DecisionTreeData dataset below, so the base trainer must be a tree trainer.
    assert trainer instanceof DecisionTreeTrainer;
    DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer) trainer;
    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);
    try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
        for (int i = 0; i < cntOfIterations; i++) {
            double[] weights = Arrays.copyOf(compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
            ModelsComposition currComposition = new ModelsComposition(models, aggregator);
            if (convCheck.isConverged(dataset, currComposition))
                break;
            dataset.compute(part -> {
                // Keep the original labels once; getLabels() is overwritten below on every iteration.
                if (part.getCopiedOriginalLabels() == null)
                    part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
                for (int j = 0; j < part.getLabels().length; j++) {
                    double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
                    double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
                    // The next tree is fitted on these values: the negated loss gradient at the current answer.
                    part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                }
            });
            // nanoTime is monotonic, so the measured duration is not affected by wall-clock adjustments.
            long startNanos = System.nanoTime();
            models.add(decisionTreeTrainer.fit(dataset));
            double learningTime = (System.nanoTime() - startNanos) / 1e9;
            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }
    } catch (RuntimeException e) {
        // Unchecked failures keep their original type instead of being wrapped a second time.
        throw e;
    } catch (Exception e) {
        // Checked exceptions (e.g. from closing the dataset) are surfaced as unchecked.
        throw new RuntimeException(e);
    }
    compositionWeights = Arrays.copyOf(compositionWeights, models.size());
    return models;
}
Also used: EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) DecisionTreeTrainer(org.apache.ignite.ml.tree.DecisionTreeTrainer) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) DecisionTreeData(org.apache.ignite.ml.tree.data.DecisionTreeData) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 9 with LearningEnvironment

Use of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.

Class DataStreamGeneratorFillCacheTest, method testCacheFilling.

/**
 * Starts a node and fills a 10-partition cache with 5000 vectors from a {@code GaussRandomProducer}-backed
 * data stream. It then builds a cache-based dataset over the cache and checks two things:
 * <ul>
 *   <li>every row is seen;</li>
 *   <li>the mean feature value is within {@code 1e-2} of zero.</li>
 * </ul>
 */
@Test
public void testCacheFilling() {
    IgniteConfiguration nodeCfg = new IgniteConfiguration().setDiscoverySpi(
        new TcpDiscoverySpi().setIpFinder(
            new TcpDiscoveryVmIpFinder().setAddresses(Arrays.asList("127.0.0.1:47500..47509"))));
    final String cacheName = "TEST_CACHE";
    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));
    final int rowsCnt = 5000;

    try (Ignite node = Ignition.start(nodeCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = node.getOrCreateCache(cacheCfg);
        new GaussRandomProducer(0).vectorize(1).asDataStream().fillCacheWithVecUUIDAsKey(rowsCnt, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> builder = new CacheBasedDatasetBuilder<>(node, cache);

        // Per-partition statistics: sum of all feature values and number of rows.
        IgniteFunction<SimpleDatasetData, StatPair> statsMapper =
            part -> new StatPair(DoubleStream.of(part.getFeatures()).sum(), part.getRows());
        LearningEnvironment trainerEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        trainerEnv.deployingContext().initByClientObject(statsMapper);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
                 builder.build(LearningEnvironmentBuilder.defaultBuilder(), new EmptyContextBuilder<>(),
                     new SimpleDatasetDataBuilder<>(vectorizer), trainerEnv)) {
            StatPair totals = dataset.compute(statsMapper, StatPair::sum);
            assertEquals(rowsCnt, totals.cntOfRows);
            assertEquals(0.0, totals.elementsSum / totals.cntOfRows, 1e-2);
        }

        node.destroyCache(cacheName);
    }
}
Also used: Arrays(java.util.Arrays) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) SimpleDatasetDataBuilder(org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) Test(org.junit.Test) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) IgniteCache(org.apache.ignite.IgniteCache) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DoubleStream(java.util.stream.DoubleStream) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) CacheBasedDataset(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi)

Example 10 with LearningEnvironment

Use of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.

Class SparkModelParser, method parse.

/**
 * Loads a model that Spark saved in Parquet format (a directory).
 *
 * <p>Extraction runs with a trainer learning environment built from the default
 * {@code LearningEnvironmentBuilder}.
 *
 * @param pathToMdl Path to the directory with the saved model.
 * @param parsedSparkMdl Parsed Spark model type.
 * @return The model produced by {@code extractModel} for the given directory and model type.
 * @throws IllegalArgumentException Declared by this method; presumably raised by {@code extractModel} for
 *     invalid input (TODO confirm).
 */
public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl) throws IllegalArgumentException {
    LearningEnvironment trainerEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
    return extractModel(pathToMdl, parsedSparkMdl, trainerEnv);
}
Also used: LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder)

Aggregations

LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment)18 LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder)6 UUID (java.util.UUID)5 Serializable (java.io.Serializable)4 PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder)4 IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction)4 ArrayList (java.util.ArrayList)3 Arrays (java.util.Arrays)3 Iterator (java.util.Iterator)3 Map (java.util.Map)3 Ignite (org.apache.ignite.Ignite)3 IgniteCache (org.apache.ignite.IgniteCache)3 Ignition (org.apache.ignite.Ignition)3 IgniteBiPredicate (org.apache.ignite.lang.IgniteBiPredicate)3 PartitionContextBuilder (org.apache.ignite.ml.dataset.PartitionContextBuilder)3 UpstreamTransformer (org.apache.ignite.ml.dataset.UpstreamTransformer)3 UpstreamTransformerBuilder (org.apache.ignite.ml.dataset.UpstreamTransformerBuilder)3 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)3 BitSet (java.util.BitSet)2 Collection (java.util.Collection)2