Search in sources :

Example 16 with EmptyContext

use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

the class LogisticRegressionSGDTrainer method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Fits a new perceptron when {@code mdl} is {@code null}; otherwise restores the perceptron state from
 * {@code mdl} and continues training it. The flat parameter array of the resulting perceptron is then
 * split: every element except the last one becomes the weights vector and the last element is passed as
 * the second argument of the {@link LogisticRegressionModel} constructor.</p>
 */
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    // Derives the network architecture from the dataset: the input size is the per-row feature count.
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> architectureFactory = ds -> {
        Integer featureCnt = ds.compute(part -> {
            if (part.getFeatures() != null)
                return part.getFeatures().length / part.getRows();
            // A partition without features cannot contribute a column count.
            return null;
        }, (left, right) -> {
            // The first non-null partial result wins; null survives only when both sides are null.
            if (left != null)
                return left;
            return right;
        });
        if (featureCnt == null)
            throw new IllegalStateException("Cannot train on empty dataset");
        // A single added layer of size 1 with the sigmoid activator.
        return new MLPArchitecture(featureCnt).withAddedLayer(1, true, Activators.SIGMOID);
    };
    MLPTrainer<?> mlpTrainer = new MLPTrainer<>(
        architectureFactory,
        LossFunctions.L2,
        updatesStgy,
        maxIterations,
        batchSize,
        locIterations,
        seed
    ).withEnvironmentBuilder(envBuilder);
    // Scalar labels are wrapped into one-element arrays before being handed to the perceptron trainer.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> labelWrapper = vec -> new LabeledVector<>(vec.features(), new double[] { vec.label() });
    PatchedPreprocessor<K, V, Double, double[]> arrayLbPreprocessor = new PatchedPreprocessor<>(labelWrapper, extractor);
    MultilayerPerceptron perceptron;
    if (mdl == null)
        perceptron = mlpTrainer.fit(datasetBuilder, arrayLbPreprocessor);
    else {
        perceptron = restoreMLPState(mdl);
        perceptron = mlpTrainer.update(perceptron, datasetBuilder, arrayLbPreprocessor);
    }
    double[] flatParams = perceptron.parameters().getStorage().data();
    int lastIdx = flatParams.length - 1;
    DenseVector weights = new DenseVector(Arrays.copyOf(flatParams, lastIdx));
    return new LogisticRegressionModel(weights, flatParams[lastIdx]);
}
Also used : SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Arrays(java.util.Arrays) Activators(org.apache.ignite.ml.nn.Activators) SimpleGDParameterUpdate(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate) UpdatesStrategy(org.apache.ignite.ml.nn.UpdatesStrategy) SimpleLabeledDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) LossFunctions(org.apache.ignite.ml.optimization.LossFunctions) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) NotNull(org.jetbrains.annotations.NotNull) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) MLPTrainer(org.apache.ignite.ml.nn.MLPTrainer) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) Dataset(org.apache.ignite.ml.dataset.Dataset) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 17 with EmptyContext

use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

the class Deltas method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Labels equal to {@code 0.0} are remapped to {@code -1.0} before training; all other labels are kept
 * as they are. The weight vector is either zero-initialized (when {@code mdl} is {@code null}) or taken
 * from the state of {@code mdl}, and is then advanced by the computed deltas once per iteration. Element
 * {@code 0} of the final vector is passed as the second constructor argument of the resulting model, the
 * remaining elements as the first one.</p>
 *
 * @throws RuntimeException Wrapping any exception raised inside the dataset block, which also covers
 *      whatever {@code getLastTrainedModelOrThrowEmptyDatasetException} throws, since it is called there.
 */
@Override
protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;
    // Maps the 0.0 label onto -1.0 and leaves every other label untouched.
    IgniteFunction<Double, Double> zeroToMinusOne = lbl -> {
        if (lbl != 0.0)
            return lbl;
        return -1.0;
    };
    IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> relabel = vec -> new LabeledVector<>(vec.features(), zeroToMinusOne.apply(vec.label()));
    PatchedPreprocessor<K, V, Double, Double> relabelingPreprocessor = new PatchedPreprocessor<>(relabel, preprocessor);
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> dataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(relabelingPreprocessor);
    Vector weights;
    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> data = datasetBuilder.build(
        envBuilder,
        (environment, entries, entriesCnt) -> new EmptyContext(),
        dataBuilder,
        learningEnvironment())
    ) {
        if (mdl != null)
            weights = getStateVector(mdl);
        else {
            final int cols = data.compute(org.apache.ignite.ml.structures.Dataset::colSize, (left, right) -> {
                // The right-hand result is preferred, then the left-hand one; zero when both are missing.
                if (right != null)
                    return right;
                if (left != null)
                    return left;
                return 0;
            });
            // One extra slot is reserved in addition to the feature columns.
            weights = initializeWeightsWithZeros(cols + 1);
        }
        int iter = 0;
        while (iter < this.getAmountOfIterations()) {
            Vector delta = calculateUpdates(weights, data);
            if (delta == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
            // plus() yields a new vector, so the previous weights instance is not modified here.
            weights = weights.plus(delta);
            iter++;
        }
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
    return new SVMLinearClassificationModel(
        weights.copyOfRange(1, weights.size()),
        weights.get(0)
    );
}
Also used : IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) Random(java.util.Random) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) NotNull(org.jetbrains.annotations.NotNull) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) Dataset(org.apache.ignite.ml.dataset.Dataset) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 18 with EmptyContext

use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

the class GDBOnTreesLearningStrategy method update.

/**
 * {@inheritDoc}
 *
 * <p>Starting from the learning state derived from {@code mdlToUpdate}, runs at most
 * {@code cntOfIterations} rounds over one dataset. A round ends the loop as soon as the convergence
 * checker reports convergence for the current composition; otherwise the partition labels are replaced
 * with the negated loss gradient of that composition and one more decision tree is fitted and appended
 * to the model list. Finally {@code compositionWeights} is resized to the number of models.</p>
 *
 * @throws RuntimeException Wrapping any exception raised inside the dataset block.
 */
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment learningEnv = envBuilder.buildForTrainer();
    learningEnv.initDeployingContext(vectorizer);

    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> baseTrainer = baseMdlTrainerBuilder.get();
    assert baseTrainer instanceof DecisionTreeTrainer;
    DecisionTreeTrainer treeTrainer = (DecisionTreeTrainer) baseTrainer;

    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
    ConvergenceChecker<K, V> convergenceChecker = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);

    try (Dataset<EmptyContext, DecisionTreeData> data = datasetBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new DecisionTreeDataBuilder<>(vectorizer, useIdx),
        learningEnv)
    ) {
        int round = 0;
        while (round < cntOfIterations) {
            // The composition of this round only sees the weights of the models trained so far.
            double[] roundWeights = Arrays.copyOf(compositionWeights, models.size());
            ModelsComposition composition = new ModelsComposition(models, new WeightedPredictionsAggregator(roundWeights, meanLbVal));
            if (convergenceChecker.isConverged(data, composition))
                break;
            data.compute(partition -> {
                // The original labels are copied once, because the working labels are overwritten below.
                if (partition.getCopiedOriginalLabels() == null)
                    partition.setCopiedOriginalLabels(Arrays.copyOf(partition.getLabels(), partition.getLabels().length));
                for (int row = 0; row < partition.getLabels().length; row++) {
                    double predicted = composition.predict(VectorUtils.of(partition.getFeatures()[row]));
                    double expected = externalLbToInternalMapping.apply(partition.getCopiedOriginalLabels()[row]);
                    partition.getLabels()[row] = -loss.gradient(sampleSize, expected, predicted);
                }
            });
            long fitStartedAt = System.currentTimeMillis();
            models.add(treeTrainer.fit(data));
            double fitSeconds = (System.currentTimeMillis() - fitStartedAt) / 1000.0;
            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", fitSeconds);
            round++;
        }
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
    compositionWeights = Arrays.copyOf(compositionWeights, models.size());
    return models;
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) DecisionTreeTrainer(org.apache.ignite.ml.tree.DecisionTreeTrainer) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) DecisionTreeData(org.apache.ignite.ml.tree.data.DecisionTreeData) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 19 with EmptyContext

use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

the class DataStreamGeneratorFillCacheTest method testCacheFilling.

/**
 * Fills a cache from a data stream generator, builds a cache-based dataset on top of it and checks that
 * the dataset reports the expected number of rows and a mean feature value close to zero.
 */
@Test
public void testCacheFilling() {
    IgniteConfiguration igniteCfg = new IgniteConfiguration()
        .setDiscoverySpi(new TcpDiscoverySpi()
            .setIpFinder(new TcpDiscoveryVmIpFinder()
                .setAddresses(Arrays.asList("127.0.0.1:47500..47509"))));
    String cacheName = "TEST_CACHE";
    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg = new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
        .setAffinity(new RendezvousAffinityFunction(false, 10));
    int datasetSize = 5000;
    try (Ignite node = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> filledCache = node.getOrCreateCache(cacheCfg);
        DataStreamGenerator stream = new GaussRandomProducer(0).vectorize(1).asDataStream();
        stream.fillCacheWithVecUUIDAsKey(datasetSize, filledCache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder = new CacheBasedDatasetBuilder<>(node, filledCache);
        // Per-partition statistic: the sum of all feature values together with the row count.
        IgniteFunction<SimpleDatasetData, StatPair> partitionStat = part -> new StatPair(DoubleStream.of(part.getFeatures()).sum(), part.getRows());
        LearningEnvironment learningEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        learningEnv.deployingContext().initByClientObject(partitionStat);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> data = datasetBuilder.build(
            LearningEnvironmentBuilder.defaultBuilder(),
            new EmptyContextBuilder<>(),
            new SimpleDatasetDataBuilder<>(vectorizer),
            learningEnv)
        ) {
            StatPair total = data.compute(partitionStat, StatPair::sum);
            assertEquals(datasetSize, total.cntOfRows);
            assertEquals(0.0, total.elementsSum / total.cntOfRows, 1e-2);
        }
        node.destroyCache(cacheName);
    }
}
Also used : Arrays(java.util.Arrays) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) SimpleDatasetDataBuilder(org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) Test(org.junit.Test) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) IgniteCache(org.apache.ignite.IgniteCache) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DoubleStream(java.util.stream.DoubleStream) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) CacheBasedDataset(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) 
LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Ignite(org.apache.ignite.Ignite) UUID(java.util.UUID) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Example 20 with EmptyContext

use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

the class ConvergenceChecker method isConverged.

/**
 * Checks convergence on a dataset that is built from the given dataset builder for the duration of the
 * call.
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Builder of the dataset the check is performed on.
 * @param currMdl Current model.
 * @return True if GDB is converged.
 * @throws RuntimeException Wrapping any exception raised while the dataset is built, checked or closed.
 */
public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
    LearningEnvironment trainerEnv = envBuilder.buildForTrainer();
    trainerEnv.initDeployingContext(preprocessor);

    boolean converged;
    try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> data = datasetBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(preprocessor),
        trainerEnv)
    ) {
        // Delegates to the dataset-based overload.
        converged = isConverged(data, currMdl);
    } catch (Exception ex) {
        throw new RuntimeException(ex);
    }
    return converged;
}
Also used : LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) FeatureMatrixWithLabelsOnHeapData(org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData)

Aggregations

EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)23 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)16 Dataset (org.apache.ignite.ml.dataset.Dataset)12 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)12 DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder)11 Preprocessor (org.apache.ignite.ml.preprocessing.Preprocessor)11 Arrays (java.util.Arrays)9 LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder)9 UpstreamEntry (org.apache.ignite.ml.dataset.UpstreamEntry)6 SingleLabelDatasetTrainer (org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer)6 LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment)5 ArrayList (java.util.ArrayList)4 Map (java.util.Map)4 PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder)4 FeatureMatrixWithLabelsOnHeapData (org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData)4 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)4 NotNull (org.jetbrains.annotations.NotNull)4 Serializable (java.io.Serializable)3 List (java.util.List)3 Optional (java.util.Optional)3