
Example 21 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

From the class MedianOfMedianConvergenceCheckerTest, method testConvergenceChecking:

/** Tests error computation and convergence detection of the median-of-medians convergence checker. */
@Test
public void testConvergenceChecking() {
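    // Add an extra row with a very large label value to the test data.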
    data.put(666, VectorUtils.of(10, 11).labeled(100000.0));
    LocalDatasetBuilder<Integer, LabeledVector<Double>> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
    ConvergenceChecker<Integer, LabeledVector<Double>> checker = createChecker(new MedianOfMedianConvergenceCheckerFactory(0.1), datasetBuilder);
    double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
    Assert.assertEquals(1.9, error, 0.01);
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder();
    Assert.assertFalse(checker.isConverged(envBuilder, datasetBuilder, notConvergedMdl));
    Assert.assertTrue(checker.isConverged(envBuilder, datasetBuilder, convergedMdl));
    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer),
        TestUtils.testEnvBuilder().buildForTrainer())) {
        double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
        Assert.assertEquals(1.6, onDSError, 0.01);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used: EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext), FeatureMatrixWithLabelsOnHeapData (org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData), LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder), LabeledVector (org.apache.ignite.ml.structures.LabeledVector), LocalDatasetBuilder (org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder), ConvergenceCheckerTest (org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest), Test (org.junit.Test).
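
The examples on this page share the same core pattern: build a dataset whose partition context is an EmptyContext (no replicated per-partition state) and whose partition data holds whatever the computation needs, then reduce over partitions with compute. The sketch below restates that pattern in isolation against a LocalDatasetBuilder. It is a minimal, self-contained illustration: the class names EmptyContextPatternSketch and CountData and the toy data are made up for this example, and only the build/compute signatures visible in these snippets are assumed.

import java.util.HashMap;
import java.util.Map;

import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

public class EmptyContextPatternSketch {
    /** Hypothetical partition data: just the number of upstream rows in the partition. */
    private static class CountData implements AutoCloseable {
        private final long cnt;

        CountData(long cnt) {
            this.cnt = cnt;
        }

        /** {@inheritDoc} */
        @Override public void close() {
            // Nothing to release.
        }
    }

    public static void main(String[] args) throws Exception {
        Map<Integer, Vector> data = new HashMap<>();
        for (int i = 0; i < 100; i++)
            data.put(i, VectorUtils.of(i, 2.0 * i));

        LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();
        LocalDatasetBuilder<Integer, Vector> datasetBuilder = new LocalDatasetBuilder<>(data, 4);

        // Partition data builder: ignores the (empty) context and stores the partition size.
        PartitionDataBuilder<Integer, Vector, EmptyContext, CountData> dataBuilder =
            (env, upstream, upstreamSize, ctx) -> new CountData(upstreamSize);

        try (LocalDataset<EmptyContext, CountData> dataset = datasetBuilder.build(
            envBuilder,
            new EmptyContextBuilder<>(),
            dataBuilder,
            envBuilder.buildForTrainer())) {
            // Map each partition to its row count and sum the counts across partitions.
            Long total = dataset.compute(
                (cntData, env) -> cntData.cnt,
                (a, b) -> a == null ? b : b == null ? a : a + b);

            System.out.println("Rows across partitions: " + total);
        }
    }
}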

Example 22 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

From the class ImputerTrainer, method fit:

/**
 * {@inheritDoc}
 */
@Override
public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> builder = (env, upstream, upstreamSize) -> new EmptyContext();
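    // Build per-partition statistics (sums, counts, min/max, or value frequencies) for the chosen strategy.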
    try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(envBuilder, builder, (env, upstream, upstreamSize, ctx) -> {
        double[] sums = null;
        int[] counts = null;
        double[] maxs = null;
        double[] mins = null;
        Map<Double, Integer>[] valuesByFreq = null;
        while (upstream.hasNext()) {
            UpstreamEntry<K, V> entity = upstream.next();
            LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());
            switch(imputingStgy) {
                case MEAN:
                    sums = updateTheSums(row, sums);
                    counts = updateTheCounts(row, counts);
                    break;
                case MOST_FREQUENT:
                    valuesByFreq = updateFrequenciesByGivenRow(row, valuesByFreq);
                    break;
                case LEAST_FREQUENT:
                    valuesByFreq = updateFrequenciesByGivenRow(row, valuesByFreq);
                    break;
                case MAX:
                    maxs = updateTheMaxs(row, maxs);
                    break;
                case MIN:
                    mins = updateTheMins(row, mins);
                    break;
                case COUNT:
                    counts = updateTheCounts(row, counts);
                    break;
                default:
                    throw new UnsupportedOperationException("The chosen strategy is not supported");
            }
        }
        ImputerPartitionData partData;
        switch(imputingStgy) {
            case MEAN:
                partData = new ImputerPartitionData().withSums(sums).withCounts(counts);
                break;
            case MOST_FREQUENT:
                partData = new ImputerPartitionData().withValuesByFrequency(valuesByFreq);
                break;
            case LEAST_FREQUENT:
                partData = new ImputerPartitionData().withValuesByFrequency(valuesByFreq);
                break;
            case MAX:
                partData = new ImputerPartitionData().withMaxs(maxs);
                break;
            case MIN:
                partData = new ImputerPartitionData().withMins(mins);
                break;
            case COUNT:
                partData = new ImputerPartitionData().withCounts(counts);
                break;
            default:
                throw new UnsupportedOperationException("The chosen strategy is not supported");
        }
        return partData;
    }, learningEnvironment(basePreprocessor))) {
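        // Reduce the per-partition statistics into a single vector of imputing values.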
        Vector imputingValues;
        switch(imputingStgy) {
            case MEAN:
                imputingValues = VectorUtils.of(calculateImputingValuesBySumsAndCounts(dataset));
                break;
            case MOST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheMostFrequentValues(dataset));
                break;
            case LEAST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheLeastFrequentValues(dataset));
                break;
            case MAX:
                imputingValues = VectorUtils.of(calculateImputingValuesByMaxValues(dataset));
                break;
            case MIN:
                imputingValues = VectorUtils.of(calculateImputingValuesByMinValues(dataset));
                break;
            case COUNT:
                imputingValues = VectorUtils.of(calculateImputingValuesByCounts(dataset));
                break;
            default:
                throw new UnsupportedOperationException("The chosen strategy is not supported");
        }
        return new ImputerPreprocessor<>(imputingValues, basePreprocessor);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used: Arrays (java.util.Arrays), Vector (org.apache.ignite.ml.math.primitives.vector.Vector), Preprocessor (org.apache.ignite.ml.preprocessing.Preprocessor), HashMap (java.util.HashMap), DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder), PreprocessingTrainer (org.apache.ignite.ml.preprocessing.PreprocessingTrainer), LabeledVector (org.apache.ignite.ml.structures.LabeledVector), PartitionContextBuilder (org.apache.ignite.ml.dataset.PartitionContextBuilder), VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils), Dataset (org.apache.ignite.ml.dataset.Dataset), Map (java.util.Map), Optional (java.util.Optional), Comparator (java.util.Comparator), LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder), UpstreamEntry (org.apache.ignite.ml.dataset.UpstreamEntry), EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext).
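
The fit method above can also be exercised directly against a LocalDatasetBuilder. Below is a minimal, hedged sketch: the withImputingStrategy setter on ImputerTrainer and the DummyVectorizer base preprocessor do not appear in the snippet above and are assumed here; the class name and toy data are made up for illustration. The three-argument fit overload is the one defined in this class.

import java.util.HashMap;
import java.util.Map;

import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputingStrategy;

public class ImputerFitSketch {
    public static void main(String[] args) {
        // Toy upstream data with missing coordinates encoded as Double.NaN.
        Map<Integer, Vector> data = new HashMap<>();
        data.put(1, VectorUtils.of(Double.NaN, 2.0));
        data.put(2, VectorUtils.of(4.0, 6.0));
        data.put(3, VectorUtils.of(8.0, Double.NaN));

        LocalDatasetBuilder<Integer, Vector> datasetBuilder = new LocalDatasetBuilder<>(data, 2);

        // Fit gathers per-partition statistics (with an EmptyContext) and reduces them
        // into a vector of imputing values, here the per-feature means.
        Preprocessor<Integer, Vector> imputer = new ImputerTrainer<Integer, Vector>()
            .withImputingStrategy(ImputingStrategy.MEAN)
            .fit(LearningEnvironmentBuilder.defaultBuilder(), datasetBuilder, new DummyVectorizer<Integer>());

        // At apply time, missing coordinates are replaced by the learned values.
        System.out.println(imputer.apply(1, VectorUtils.of(Double.NaN, 2.0)));
    }
}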

Example 23 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

From the class LearningEnvironmentTest, method testRandomNumbersGenerator:

/**
 * Tests the random number generator provided by {@link LearningEnvironment}.
 * We check that:
 * 1. The correct random generator is returned for each partition.
 * 2. Its state is preserved between compute calls (we run several compute iterations to verify this).
 */
@Test
public void testRandomNumbersGenerator() {
    // Use an environment builder whose per-partition random number generator returns partition index * iteration from nextInt().
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
    int partitions = 10;
    int iterations = 2;
    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {

        /**
         * {@inheritDoc}
         */
        @Override
        public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(
                envBuilder,
                new EmptyContextBuilder<>(),
                (PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>>)
                    (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition()),
                envBuilder.buildForTrainer());
            Vector v = null;
            for (int iter = 0; iter < iterations; iter++) {
                v = ds.compute(
                    (dw, env) -> VectorUtils.fill(-1, partitions).set(env.partition(), env.randomNumbersGenerator().nextInt()),
                    (v1, v2) -> zipOverridingEmpty(v1, v2, -1));
            }
            return constantModel(v);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            return false;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return null;
        }
    };
    trainer.withEnvironmentBuilder(envBuilder);
    IgniteModel<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null);
    Vector exp = VectorUtils.zeroes(partitions);
    for (int i = 0; i < partitions; i++) exp.set(i, i * iterations);
    Vector res = mdl.predict(null);
    assertEquals(exp, res);
}
Also used: IntStream (java.util.stream.IntStream), TestUtils (org.apache.ignite.ml.TestUtils), Vector (org.apache.ignite.ml.math.primitives.vector.Vector), Preprocessor (org.apache.ignite.ml.preprocessing.Preprocessor), Random (java.util.Random), DatasetTrainer (org.apache.ignite.ml.trainers.DatasetTrainer), ParallelismStrategy (org.apache.ignite.ml.environment.parallelism.ParallelismStrategy), FeatureMeta (org.apache.ignite.ml.dataset.feature.FeatureMeta), RandomForestRegressionTrainer (org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer), Map (java.util.Map), EmptyContextBuilder (org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder), MLLogger (org.apache.ignite.ml.environment.logging.MLLogger), PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder), EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext), ConsoleLogger (org.apache.ignite.ml.environment.logging.ConsoleLogger), Test (org.junit.Test), FeaturesCountSelectionStrategies (org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies), IgniteModel (org.apache.ignite.ml.IgniteModel), DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder), Collectors (java.util.stream.Collectors), VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils), Dataset (org.apache.ignite.ml.dataset.Dataset), DefaultParallelismStrategy (org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy), TestUtils.constantModel (org.apache.ignite.ml.TestUtils.constantModel), Assert.assertEquals (org.junit.Assert.assertEquals).
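
Examples 21 and 23 pass an EmptyContextBuilder to build, while Example 22 constructs the context with an inline lambda. The two styles are interchangeable, since EmptyContextBuilder is a PartitionContextBuilder that returns a new EmptyContext for each partition. A small sketch follows; the key/value types and class name are arbitrary choices for illustration.

import org.apache.ignite.ml.dataset.PartitionContextBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.vector.Vector;

public class ContextBuilderStyles {
    public static void main(String[] args) {
        // Style used in ImputerTrainer.fit (Example 22): an inline lambda.
        PartitionContextBuilder<Integer, Vector, EmptyContext> inline =
            (env, upstream, upstreamSize) -> new EmptyContext();

        // Style used in the tests (Examples 21 and 23): the reusable builder.
        PartitionContextBuilder<Integer, Vector, EmptyContext> reusable = new EmptyContextBuilder<>();

        // Either value can be passed as the context builder argument of DatasetBuilder.build(...).
        System.out.println(inline.getClass() + " vs " + reusable.getClass());
    }
}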

Aggregations

EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext): 23
LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 16
Dataset (org.apache.ignite.ml.dataset.Dataset): 12
Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 12
DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder): 11
Preprocessor (org.apache.ignite.ml.preprocessing.Preprocessor): 11
Arrays (java.util.Arrays): 9
LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder): 9
UpstreamEntry (org.apache.ignite.ml.dataset.UpstreamEntry): 6
SingleLabelDatasetTrainer (org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer): 6
LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment): 5
ArrayList (java.util.ArrayList): 4
Map (java.util.Map): 4
PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder): 4
FeatureMatrixWithLabelsOnHeapData (org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData): 4
DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector): 4
NotNull (org.jetbrains.annotations.NotNull): 4
Serializable (java.io.Serializable): 3
List (java.util.List): 3
Optional (java.util.Optional): 3