
Example 6 with KNNRegressionTrainer

Use of org.apache.ignite.ml.knn.regression.KNNRegressionTrainer in project ignite by Apache.

From class KNNRegressionTest, method testUpdate.

/**
 * Checks that a model fitted on an empty dataset predicts null, and that updating it
 * with labelled rows yields a model that predicts the label of the exactly matching row.
 */
@Test
public void testUpdate() {
    // Column 0 of each row is the label; columns 1..5 are the features.
    Map<Integer, double[]> data = new HashMap<>();
    data.put(0, new double[] { 11.0, 0, 0, 0, 0, 0 });
    data.put(1, new double[] { 12.0, 2.0, 0, 0, 0, 0 });
    data.put(2, new double[] { 13.0, 0, 3.0, 0, 0, 0 });
    data.put(3, new double[] { 14.0, 0, 0, 4.0, 0, 0 });
    data.put(4, new double[] { 15.0, 0, 0, 0, 5.0, 0 });
    data.put(5, new double[] { 16.0, 0, 0, 0, 0, 6.0 });

    KNNRegressionTrainer trainer = new KNNRegressionTrainer()
        .withK(1)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    // `parts` (number of dataset partitions) comes from the TrainerTest base class.
    KNNRegressionModel originalMdlOnEmptyDataset = trainer.fit(new HashMap<>(), parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));

    // Update the model trained on nothing with the six labelled rows.
    KNNRegressionModel updatedOnDataset = trainer.update(originalMdlOnEmptyDataset, data, parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));

    // Same features as row 4, whose label is 15.0.
    Vector vector = VectorUtils.of(0.0, 0.0, 0.0, 5.0, 0.0);
    assertNull(originalMdlOnEmptyDataset.predict(vector));
    assertEquals(Double.valueOf(15.0), updatedOnDataset.predict(vector));
}
Also used: EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance), KNNRegressionTrainer (org.apache.ignite.ml.knn.regression.KNNRegressionTrainer), DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer), HashMap (java.util.HashMap), KNNRegressionModel (org.apache.ignite.ml.knn.regression.KNNRegressionModel), Vector (org.apache.ignite.ml.math.primitives.vector.Vector), TrainerTest (org.apache.ignite.ml.common.TrainerTest), Test (org.junit.Test)
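testUpdate depends on two behaviours: a model fitted on an empty map has no neighbours to consult and returns null, and a model updated with the six labelled rows finds row 4 (label 15.0) at distance 0 from the query. The plain-Java sketch below reproduces that contract with a brute-force k-NN search, assuming Euclidean distance, an unweighted mean of the k nearest labels, and the label stored in column 0 as with Vectorizer.LabelCoordinate.FIRST. BruteForceKnnRegressor and its update/predict methods are hypothetical names for illustration only; this is not Ignite's KNNRegressionModel, nor how Ignite stores or searches its data.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical brute-force k-NN regressor that illustrates the testUpdate semantics.
 * It is not Ignite's KNNRegressionModel and does not mirror its internals.
 */
class BruteForceKnnRegressor {
    private final int k;

    /** Stored training rows: column 0 is the label, the remaining columns are features. */
    private final List<double[]> rows = new ArrayList<>();

    BruteForceKnnRegressor(int k) {
        this.k = k;
    }

    /** Returns a new regressor holding this one's rows plus the given rows; this instance is unchanged. */
    BruteForceKnnRegressor update(Map<Integer, double[]> data) {
        BruteForceKnnRegressor updated = new BruteForceKnnRegressor(k);
        updated.rows.addAll(rows);
        updated.rows.addAll(data.values());
        return updated;
    }

    /** Euclidean distance between a stored row's features (label skipped) and a query vector. */
    private static double distance(double[] row, double[] features) {
        double sum = 0.0;
        for (int i = 0; i < features.length; i++) {
            double d = row[i + 1] - features[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Unweighted mean label of the k nearest rows, or null if no rows have been seen yet. */
    Double predict(double[] features) {
        if (rows.isEmpty())
            return null;
        List<double[]> sorted = new ArrayList<>(rows);
        sorted.sort(Comparator.comparingDouble(r -> distance(r, features)));
        int n = Math.min(k, sorted.size());
        double sum = 0.0;
        for (int i = 0; i < n; i++)
            sum += sorted.get(i)[0];
        return sum / n;
    }

    public static void main(String[] args) {
        Map<Integer, double[]> data = new HashMap<>();
        data.put(0, new double[] {11.0, 0.0, 0.0, 0.0, 0.0, 0.0});
        data.put(1, new double[] {12.0, 2.0, 0.0, 0.0, 0.0, 0.0});
        data.put(2, new double[] {13.0, 0.0, 3.0, 0.0, 0.0, 0.0});
        data.put(3, new double[] {14.0, 0.0, 0.0, 4.0, 0.0, 0.0});
        data.put(4, new double[] {15.0, 0.0, 0.0, 0.0, 5.0, 0.0});
        data.put(5, new double[] {16.0, 0.0, 0.0, 0.0, 0.0, 6.0});

        double[] query = {0.0, 0.0, 0.0, 5.0, 0.0};
        BruteForceKnnRegressor empty = new BruteForceKnnRegressor(1);
        BruteForceKnnRegressor updated = empty.update(data);

        System.out.println(empty.predict(query));   // prints "null"
        System.out.println(updated.predict(query)); // prints "15.0"
    }
}
```

Returning a new instance from update, rather than mutating the receiver, keeps the empty model's null prediction intact; the test relies on this, since it queries the original model only after the update has run.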

Example 7 with KNNRegressionTrainer

Use of org.apache.ignite.ml.knn.regression.KNNRegressionTrainer in project ignite by Apache.

From class KNNRegressionTest, method testSimpleRegressionWithOneNeighbour.

/**
 * Checks that 1-NN regression with Euclidean distance returns the label of the
 * training row whose features match the query exactly.
 */
@Test
public void testSimpleRegressionWithOneNeighbour() {
    // Column 0 of each row is the label; columns 1..5 are the features.
    Map<Integer, double[]> data = new HashMap<>();
    data.put(0, new double[] { 11.0, 0.0, 0.0, 0.0, 0.0, 0.0 });
    data.put(1, new double[] { 12.0, 2.0, 0.0, 0.0, 0.0, 0.0 });
    data.put(2, new double[] { 13.0, 0.0, 3.0, 0.0, 0.0, 0.0 });
    data.put(3, new double[] { 14.0, 0.0, 0.0, 4.0, 0.0, 0.0 });
    data.put(4, new double[] { 15.0, 0.0, 0.0, 0.0, 5.0, 0.0 });
    data.put(5, new double[] { 16.0, 0.0, 0.0, 0.0, 0.0, 6.0 });

    KNNRegressionTrainer trainer = new KNNRegressionTrainer()
        .withK(1)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    // Build the dataset locally; `parts` comes from the TrainerTest base class.
    KNNRegressionModel knnMdl = trainer.fit(new LocalDatasetBuilder<>(data, parts),
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));

    // The query equals the features of row 4 (label 15.0), so that row is the single nearest neighbour.
    assertEquals(15, knnMdl.predict(VectorUtils.of(0.0, 0.0, 0.0, 5.0, 0.0)), 1E-12);
}
Also used: EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance), KNNRegressionTrainer (org.apache.ignite.ml.knn.regression.KNNRegressionTrainer), DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer), HashMap (java.util.HashMap), KNNRegressionModel (org.apache.ignite.ml.knn.regression.KNNRegressionModel), TrainerTest (org.apache.ignite.ml.common.TrainerTest), Test (org.junit.Test)
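Why k = 1 returns exactly 15 becomes clear from the Euclidean distances between the query (0, 0, 0, 5, 0) and each row's features: row 4 matches exactly (distance 0), row 0 lies at distance 5, and the remaining rows are farther away (√29, √34, √41 and √61). The sketch below, plain Java with no Ignite dependency, prints those distances and the unweighted k-NN mean for k = 1 and k = 2. KnnDistanceWalkthrough is a hypothetical class name, and the averaging rule is the textbook unweighted k-NN mean, assumed here rather than taken from Ignite's source.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical walkthrough of the arithmetic behind testSimpleRegressionWithOneNeighbour (no Ignite). */
class KnnDistanceWalkthrough {
    /** Same rows as the test: column 0 is the label, columns 1..5 are the features. */
    static final double[][] ROWS = {
        {11.0, 0.0, 0.0, 0.0, 0.0, 0.0},
        {12.0, 2.0, 0.0, 0.0, 0.0, 0.0},
        {13.0, 0.0, 3.0, 0.0, 0.0, 0.0},
        {14.0, 0.0, 0.0, 4.0, 0.0, 0.0},
        {15.0, 0.0, 0.0, 0.0, 5.0, 0.0},
        {16.0, 0.0, 0.0, 0.0, 0.0, 6.0}
    };

    /** Euclidean distance from a row's features (label in column 0 skipped) to the query. */
    static double distance(double[] row, double[] query) {
        double sum = 0.0;
        for (int i = 0; i < query.length; i++) {
            double d = row[i + 1] - query[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Unweighted k-NN regression: mean label of the k rows closest to the query. */
    static double predict(double[] query, int k) {
        if (k < 1 || k > ROWS.length)
            throw new IllegalArgumentException("k must be between 1 and " + ROWS.length);
        Integer[] order = new Integer[ROWS.length];
        for (int i = 0; i < ROWS.length; i++)
            order[i] = i;
        // Sort row indices by distance to the query, nearest first.
        Arrays.sort(order, Comparator.comparingDouble((Integer idx) -> distance(ROWS[idx], query)));
        double sum = 0.0;
        for (int i = 0; i < k; i++)
            sum += ROWS[order[i]][0];
        return sum / k;
    }

    public static void main(String[] args) {
        double[] query = {0.0, 0.0, 0.0, 5.0, 0.0};
        for (double[] row : ROWS)
            System.out.printf("label %.1f -> distance %.4f%n", row[0], distance(row, query));
        System.out.println("k = 1: " + predict(query, 1)); // prints "k = 1: 15.0"
        System.out.println("k = 2: " + predict(query, 2)); // prints "k = 2: 13.0"
    }
}
```

With k = 2 the second-nearest row (row 0, label 11.0) joins the vote, so the unweighted mean drops to (15 + 11) / 2 = 13.0; the test's choice of k = 1 is what makes the expected value exactly 15.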

Aggregations

KNNRegressionModel (org.apache.ignite.ml.knn.regression.KNNRegressionModel): 7
KNNRegressionTrainer (org.apache.ignite.ml.knn.regression.KNNRegressionTrainer): 7
Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 6
HashMap (java.util.HashMap): 5
EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance): 5
TrainerTest (org.apache.ignite.ml.common.TrainerTest): 4
Test (org.junit.Test): 4
DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer): 3
Ignite (org.apache.ignite.Ignite): 2
SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache): 2
LocalDatasetBuilder (org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder): 2
ManhattanDistance (org.apache.ignite.ml.math.distances.ManhattanDistance): 2
Random (java.util.Random): 1
Rmse (org.apache.ignite.ml.selection.scoring.metric.regression.Rmse): 1
Rss (org.apache.ignite.ml.selection.scoring.metric.regression.Rss): 1
SHA256UniformMapper (org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper): 1