Search in sources :

Example 1 with LocalDatasetBuilder

use of org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder in project ignite by apache.

the class MeanAbsValueConvergenceCheckerTest method testConvergenceChecking.

/**
 * Exercises a checker created from {@code MeanAbsValueConvergenceCheckerFactory} with precision 0.1:
 * checks the error computed for a single point, the convergence verdict for the two fixture models,
 * and the mean error computed over a dataset built from the fixture data in one partition.
 */
@Test
public void testConvergenceChecking() {
    LocalDatasetBuilder<Integer, LabeledVector<Double>> builder = new LocalDatasetBuilder<>(data, 1);
    MeanAbsValueConvergenceCheckerFactory factory = new MeanAbsValueConvergenceCheckerFactory(0.1);
    ConvergenceChecker<Integer, LabeledVector<Double>> convergenceChecker = createChecker(factory, builder);

    // Error of the not-converged model on one hand-picked point with label 4.0.
    double pointError = convergenceChecker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
    LearningEnvironmentBuilder environment = TestUtils.testEnvBuilder();
    Assert.assertEquals(1.9, pointError, 0.01);

    // Only the converged fixture model may be reported as converged.
    Assert.assertFalse(convergenceChecker.isConverged(environment, builder, notConvergedMdl));
    Assert.assertTrue(convergenceChecker.isConverged(environment, builder, convergedMdl));

    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> localDataset = builder.build(
        environment,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer),
        environment.buildForTrainer())) {
        double datasetError = convergenceChecker.computeMeanErrorOnDataset(localDataset, notConvergedMdl);
        Assert.assertEquals(1.55, datasetError, 0.01);
    } catch (Exception e) {
        // Rethrow unchecked so the test method needs no throws clause.
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) FeatureMatrixWithLabelsOnHeapData(org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LocalDatasetBuilder(org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder) ConvergenceCheckerTest(org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest) Test(org.junit.Test)

Example 2 with LocalDatasetBuilder

use of org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder in project ignite by apache.

the class RegressionEvaluatorTest method testEvaluatorWithFilter.

/**
 * Checks {@code Evaluator} together with a filtered dataset builder.
 *
 * <p>A KNN regression model (k = 3, Euclidean distance) is fitted on the entries accepted by
 * {@code split.getTestFilter()} and then scored with the {@code Rss} metric on the entries accepted by
 * {@code split.getTrainFilter()}. The label is taken from the first coordinate of every vector.
 */
@Test
public void testEvaluatorWithFilter() {
    // One row per sample; the row index doubles as the map key.
    final double[][] rows = {
        { 60323, 83.0, 234289, 2356, 1590, 107608, 1947 },
        { 61122, 88.5, 259426, 2325, 1456, 108632, 1948 },
        { 60171, 88.2, 258054, 3682, 1616, 109773, 1949 },
        { 61187, 89.5, 284599, 3351, 1650, 110929, 1950 },
        { 63221, 96.2, 328975, 2099, 3099, 112075, 1951 },
        { 63639, 98.1, 346999, 1932, 3594, 113270, 1952 },
        { 64989, 99.0, 365385, 1870, 3547, 115094, 1953 },
        { 63761, 100.0, 363112, 3578, 3350, 116219, 1954 },
        { 66019, 101.2, 397469, 2904, 3048, 117388, 1955 },
        { 68169, 108.4, 442769, 2936, 2798, 120445, 1957 },
        { 66513, 110.8, 444546, 4681, 2637, 121950, 1958 },
        { 68655, 112.6, 482704, 3813, 2552, 123366, 1959 },
        { 69564, 114.2, 502601, 3931, 2514, 125368, 1960 },
        { 69331, 115.7, 518173, 4806, 2572, 127852, 1961 },
        { 70551, 116.9, 554894, 4007, 2827, 130081, 1962 }
    };
    Map<Integer, Vector> data = new HashMap<>();
    for (int key = 0; key < rows.length; key++) {
        data.put(key, VectorUtils.of(rows[key]));
    }

    KNNRegressionTrainer trainer = new KNNRegressionTrainer()
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance());

    // Fixed seed keeps the 50/50 split, and therefore the expected score, reproducible.
    TrainTestSplit<Integer, Vector> split =
        new TrainTestDatasetSplitter<Integer, Vector>(new SHA256UniformMapper<>(new Random(0))).split(0.5);

    Vectorizer<Integer, Vector, Integer, Double> vectorizer =
        new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    KNNRegressionModel mdl = trainer.fit(data, split.getTestFilter(), parts, vectorizer);

    LocalDatasetBuilder<Integer, Vector> evaluationData =
        new LocalDatasetBuilder<>(data, split.getTrainFilter(), parts);
    double score = Evaluator.evaluate(evaluationData, mdl, vectorizer, new Rss()).getSingle();

    assertEquals(4800164.444444457, score, 1e-4);
}
Also used : KNNRegressionTrainer(org.apache.ignite.ml.knn.regression.KNNRegressionTrainer) SHA256UniformMapper(org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper) HashMap(java.util.HashMap) KNNRegressionModel(org.apache.ignite.ml.knn.regression.KNNRegressionModel) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) Rss(org.apache.ignite.ml.selection.scoring.metric.regression.Rss) Random(java.util.Random) LocalDatasetBuilder(org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 3 with LocalDatasetBuilder

use of org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder in project ignite by apache.

the class EncoderTrainerTest method testFitWithUnknownStringValueInTheGivenData.

/**
 * Checks that a string-encoder preprocessor rejects category values it did not see during {@code fit()}.
 *
 * <p>The encoder is fitted on numeric two-feature vectors and then applied to a vector holding the
 * strings {@code "Monday"} and {@code "September"}, neither of which occurs in the training data.
 * The test passes only if {@code apply(...)} throws {@link UnknownCategorialValueException}.
 */
@Test
public void testFitWithUnknownStringValueInTheGivenData() {
    Map<Integer, Vector> data = new HashMap<>();
    data.put(1, VectorUtils.of(3.0, 0.0));
    data.put(2, VectorUtils.of(3.0, 12.0));
    data.put(3, VectorUtils.of(3.0, 12.0));
    data.put(4, VectorUtils.of(2.0, 45.0));
    data.put(5, VectorUtils.of(2.0, 45.0));
    data.put(6, VectorUtils.of(14.0, 12.0));
    final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1);
    DatasetBuilder<Integer, Vector> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
    EncoderTrainer<Integer, Vector> strEncoderTrainer = new EncoderTrainer<Integer, Vector>()
        .withEncoderType(EncoderType.STRING_ENCODER)
        .withEncodedFeature(0)
        .withEncodedFeature(1);
    EncoderPreprocessor<Integer, Vector> preprocessor = strEncoderTrainer.fit(TestUtils.testEnvBuilder(), datasetBuilder, vectorizer);
    try {
        preprocessor.apply(7, new DenseVector(new Serializable[] { "Monday", "September" })).features().asArray();
        // Reached only when no exception was thrown; fail() raises AssertionError, which the catch
        // below does not intercept, so the failure propagates to the test runner.
        fail("UnknownCategorialFeatureValue");
    } catch (UnknownCategorialValueException expected) {
        // Expected outcome: the unseen category values were rejected.
    }
}
Also used : Serializable(java.io.Serializable) HashMap(java.util.HashMap) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) UnknownCategorialValueException(org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException) LocalDatasetBuilder(org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 4 with LocalDatasetBuilder

use of org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder in project ignite by apache.

the class EncoderTrainerTest method testFitOnStringCategorialFeaturesWithReversedOrder.

/**
 * Fits a string encoder configured with {@code EncoderSortingStrategy.FREQUENCY_ASC} on features 0 and 1
 * and checks the encoding of a vector whose values are present in the training data.
 *
 * <p>In the training data {@code "Monday"} occurs three times in feature 0 and {@code "September"} once
 * in feature 1; the test expects them to be encoded as {@code 2.0} and {@code 0.0} respectively.
 */
@Test
public void testFitOnStringCategorialFeaturesWithReversedOrder() {
    // Training rows; they are stored under keys 1..6 in the order listed here.
    final Serializable[][] rows = {
        { "Monday", "September" },
        { "Monday", "August" },
        { "Monday", "August" },
        { "Friday", "June" },
        { "Friday", "June" },
        { "Sunday", "August" }
    };
    Map<Integer, Vector> data = new HashMap<>();
    for (int i = 0; i < rows.length; i++) {
        data.put(i + 1, new DenseVector(rows[i]));
    }

    final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<>(0, 1);
    DatasetBuilder<Integer, Vector> datasetBuilder = new LocalDatasetBuilder<>(data, parts);

    EncoderTrainer<Integer, Vector> strEncoderTrainer = new EncoderTrainer<Integer, Vector>()
        .withEncoderType(EncoderType.STRING_ENCODER)
        .withEncoderIndexingStrategy(EncoderSortingStrategy.FREQUENCY_ASC)
        .withEncodedFeature(0)
        .withEncodedFeature(1);
    EncoderPreprocessor<Integer, Vector> preprocessor =
        strEncoderTrainer.fit(TestUtils.testEnvBuilder(), datasetBuilder, vectorizer);

    double[] expected = { 2.0, 0.0 };
    DenseVector probe = new DenseVector(new Serializable[] { "Monday", "September" });
    double[] encoded = preprocessor.apply(7, probe).features().asArray();

    assertArrayEquals(expected, encoded, 1e-8);
}
Also used : Serializable(java.io.Serializable) HashMap(java.util.HashMap) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) LocalDatasetBuilder(org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 5 with LocalDatasetBuilder

use of org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder in project ignite by apache.

the class MedianOfMedianConvergenceCheckerTest method testConvergenceChecking.

/**
 * Exercises a checker created from {@code MedianOfMedianConvergenceCheckerFactory} with precision 0.1.
 *
 * <p>Before the checker is built, one extra sample with label {@code 100000.0} is added to the fixture
 * data. The test then checks the error computed for a single point, the convergence verdict for the two
 * fixture models, and the mean error computed over a dataset built in one partition.
 */
@Test
public void testConvergenceChecking() {
    // Extra sample whose label is far larger than the one used for the single-point check below.
    data.put(666, VectorUtils.of(10, 11).labeled(100000.0));

    LocalDatasetBuilder<Integer, LabeledVector<Double>> builder = new LocalDatasetBuilder<>(data, 1);
    MedianOfMedianConvergenceCheckerFactory factory = new MedianOfMedianConvergenceCheckerFactory(0.1);
    ConvergenceChecker<Integer, LabeledVector<Double>> convergenceChecker = createChecker(factory, builder);

    double pointError = convergenceChecker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
    Assert.assertEquals(1.9, pointError, 0.01);

    LearningEnvironmentBuilder environment = TestUtils.testEnvBuilder();
    // Only the converged fixture model may be reported as converged.
    Assert.assertFalse(convergenceChecker.isConverged(environment, builder, notConvergedMdl));
    Assert.assertTrue(convergenceChecker.isConverged(environment, builder, convergedMdl));

    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> localDataset = builder.build(
        environment,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer),
        TestUtils.testEnvBuilder().buildForTrainer())) {
        double datasetError = convergenceChecker.computeMeanErrorOnDataset(localDataset, notConvergedMdl);
        Assert.assertEquals(1.6, datasetError, 0.01);
    } catch (Exception e) {
        // Rethrow unchecked so the test method needs no throws clause.
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) FeatureMatrixWithLabelsOnHeapData(org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LocalDatasetBuilder(org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder) ConvergenceCheckerTest(org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest) Test(org.junit.Test)

Aggregations

LocalDatasetBuilder (org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder)9 Test (org.junit.Test)9 HashMap (java.util.HashMap)7 TrainerTest (org.apache.ignite.ml.common.TrainerTest)7 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)7 Serializable (java.io.Serializable)5 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)5 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)3 ConvergenceCheckerTest (org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest)2 FeatureMatrixWithLabelsOnHeapData (org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData)2 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)2 LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder)2 KNNRegressionModel (org.apache.ignite.ml.knn.regression.KNNRegressionModel)2 KNNRegressionTrainer (org.apache.ignite.ml.knn.regression.KNNRegressionTrainer)2 EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance)2 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)2 Random (java.util.Random)1 UnknownCategorialValueException (org.apache.ignite.ml.math.exceptions.preprocessing.UnknownCategorialValueException)1 Rss (org.apache.ignite.ml.selection.scoring.metric.regression.Rss)1 SHA256UniformMapper (org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper)1