Use of org.apache.ignite.ml.knn.regression.KNNRegressionTrainer in the Apache Ignite project.
From the class KNNRegressionTest, method testUpdate.
/**
 * Fits a 1-NN regression model on an empty dataset, then updates it with six labeled rows.
 * <p>
 * Expects the model fitted on the empty dataset to predict {@code null}, and the updated model to
 * predict {@code 15.0} for a query vector equal to the features of the row labeled {@code 15.0}.
 */
@Test
public void testUpdate() {
    // Row i is {11 + i, 0, ..., 0}; for i > 0 position i additionally holds the value i + 1.
    // The vectorizer below is configured with the first coordinate as the label.
    Map<Integer, double[]> samples = new HashMap<>();
    for (int i = 0; i < 6; i++) {
        double[] row = new double[6];
        row[0] = 11.0 + i;
        if (i > 0) {
            row[i] = i + 1.0;
        }
        samples.put(i, row);
    }

    Vector query = VectorUtils.of(0.0, 0.0, 0.0, 5.0, 0.0);

    KNNRegressionTrainer knnTrainer = new KNNRegressionTrainer()
        .withK(1)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    KNNRegressionModel emptyMdl = knnTrainer.fit(
        new HashMap<>(),
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    );

    KNNRegressionModel updatedMdl = knnTrainer.update(
        emptyMdl,
        samples,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    );

    // Both predictions are taken only after the update call, as in the original test.
    assertNull(emptyMdl.predict(query));
    assertEquals(Double.valueOf(15.0), updatedMdl.predict(query));
}
Use of org.apache.ignite.ml.knn.regression.KNNRegressionTrainer in the Apache Ignite project.
From the class KNNRegressionTest, method testSimpleRegressionWithOneNeighbour.
/**
 * Trains a KNN regression model with {@code k = 1}, Euclidean distance and weighting disabled
 * on six labeled rows.
 * <p>
 * Expects a prediction of {@code 15} (within {@code 1E-12}) for a query vector equal to the
 * features of the row labeled {@code 15.0}.
 */
@Test
public void testSimpleRegressionWithOneNeighbour() {
    // Row i is {11 + i, 0, ..., 0}; for i > 0 position i additionally holds the value i + 1.
    // The vectorizer below is configured with the first coordinate as the label.
    Map<Integer, double[]> samples = new HashMap<>();
    for (int i = 0; i < 6; i++) {
        double[] row = new double[6];
        row[0] = 11.0 + i;
        if (i > 0) {
            row[i] = i + 1.0;
        }
        samples.put(i, row);
    }

    KNNRegressionTrainer knnTrainer = new KNNRegressionTrainer()
        .withK(1)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    KNNRegressionModel mdl = knnTrainer.fit(
        new LocalDatasetBuilder<>(samples, parts),
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    );

    Double predicted = mdl.predict(VectorUtils.of(0.0, 0.0, 0.0, 5.0, 0.0));

    assertEquals(15, predicted, 1E-12);
}
Aggregations