Usage of `org.apache.ignite.ml.math.primitives.vector.Vector` in the Apache Ignite project.
Class `KNNClassificationTest`, method `testBinaryClassification`.
/**
 * Checks that an unweighted 3-NN classifier with Euclidean distance separates two clusters of
 * 2-D points. The label of each training row is its last coordinate: 1.0 for the cluster in the
 * positive quadrant and 2.0 for the cluster in the negative quadrant. Also checks that every
 * {@code toString} variant of the trained model is non-empty.
 */
@Test
public void testBinaryClassification() {
    // Row layout: {x, y, label}.
    Map<Integer, double[]> samples = new HashMap<>();

    // Cluster labelled 1.0 (positive quadrant).
    samples.put(0, new double[] {1.0, 1.0, 1.0});
    samples.put(1, new double[] {1.0, 2.0, 1.0});
    samples.put(2, new double[] {2.0, 1.0, 1.0});

    // Cluster labelled 2.0 (negative quadrant).
    samples.put(3, new double[] {-1.0, -1.0, 2.0});
    samples.put(4, new double[] {-1.0, -2.0, 2.0});
    samples.put(5, new double[] {-2.0, -1.0, 2.0});

    KNNClassificationTrainer knnTrainer = new KNNClassificationTrainer()
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    KNNClassificationModel mdl = knnTrainer.fit(
        samples,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));

    // Each textual representation of the model must contain something.
    String dfltRepr = mdl.toString();
    assertTrue(dfltRepr.length() > 0);

    String prettyRepr = mdl.toString(true);
    assertTrue(prettyRepr.length() > 0);

    String plainRepr = mdl.toString(false);
    assertTrue(plainRepr.length() > 0);

    // A point beyond the positive cluster must be classified as 1.0 ...
    Vector nearPositive = VectorUtils.of(2.0, 2.0);
    assertEquals(1.0, mdl.predict(nearPositive), 0);

    // ... and a point beyond the negative cluster as 2.0.
    Vector nearNegative = VectorUtils.of(-2.0, -2.0);
    assertEquals(2.0, mdl.predict(nearNegative), 0);
}
Usage of `org.apache.ignite.ml.math.primitives.vector.Vector` in the Apache Ignite project.
Class `BlasTest`, method `testSyrNonSquareMatrix`.
/**
 * Verifies that the 'syr' operation rejects a dense matrix that is not square: the call is
 * expected to throw {@link NonSquareMatrixException}.
 */
@Test(expected = NonSquareMatrixException.class)
public void testSyrNonSquareMatrix() {
    final double alpha = 3.0;

    // Two rows of three values each, so the matrix cannot be square.
    double[][] nonSquare = {
        {10.0, 11.0, 12.0},
        {0.0, 1.0, 2.0}
    };

    // NOTE(review): the trailing 2 looks like a storage-mode constant — confirm against
    // DenseMatrix's constructor and consider using the named constant instead.
    DenseMatrix a = new DenseMatrix(nonSquare, 2);

    Vector x = new DenseVector(new double[] {1.0, 2.0});

    Blas blas = new Blas();
    blas.syr(alpha, x, a);
}
Usage of `org.apache.ignite.ml.math.primitives.vector.Vector` in the Apache Ignite project.
Class `BlasTest`, method `testScalSparse`.
/**
 * Tests the 'scal' operation on a sparse vector. After {@code Blas.scal(alpha, v)}, the vector
 * {@code v} must equal an independently built vector scaled with {@code times(alpha)}.
 */
@Test
public void testScalSparse() {
    double[] data = new double[] {1.0, 1.0};
    double alpha = 2.0;

    SparseVector v = sparseFromArray(data);

    // The expected value comes from its own sparseFromArray instance, so it is independent of v.
    Vector exp = sparseFromArray(data).times(alpha);

    // The result (if any) of scal is ignored; v itself is expected to hold the scaled values.
    Blas.scal(alpha, v);

    // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
    // messages report the values the right way round.
    Assert.assertEquals(exp, v);
}
Usage of `org.apache.ignite.ml.math.primitives.vector.Vector` in the Apache Ignite project.
Class `MLDeployingTest`, method `createVectorizer`.
/**
 * Loads the vectorizer class named by {@code EXT_VECTORIZER} through the external class loader,
 * instantiates it via its public no-argument constructor, and configures it to take the label
 * from the last coordinate.
 *
 * @return Vectorizer created from the externally loaded class, labelled on the last coordinate.
 * @throws ClassNotFoundException If the external class loader cannot find {@code EXT_VECTORIZER}.
 * @throws NoSuchMethodException If the loaded class has no public no-argument constructor.
 * @throws InstantiationException If the loaded class cannot be instantiated (e.g. it is abstract).
 * @throws IllegalAccessException If the constructor is not accessible.
 * @throws java.lang.reflect.InvocationTargetException If the constructor itself throws.
 */
private Vectorizer<Integer, Vector, Integer, Double> createVectorizer() throws ClassNotFoundException, NoSuchMethodException, InstantiationException, IllegalAccessException, java.lang.reflect.InvocationTargetException {
    ClassLoader ldr = getExternalClassLoader();
    Class<?> clazz = ldr.loadClass(EXT_VECTORIZER);
    Constructor<?> ctor = clazz.getConstructor();

    // The class is only known at runtime, so its generic parameters cannot be checked here.
    // The cast trusts EXT_VECTORIZER to implement Vectorizer<Integer, Vector, Integer, Double>.
    @SuppressWarnings("unchecked")
    Vectorizer<Integer, Vector, Integer, Double> vectorizer =
        (Vectorizer<Integer, Vector, Integer, Double>) ctor.newInstance();

    return vectorizer.labeled(Vectorizer.LabelCoordinate.LAST);
}
Usage of `org.apache.ignite.ml.math.primitives.vector.Vector` in the Apache Ignite project.
Class `KNNRegressionTest`, method `testLongly`.
/**
 * Trains a 3-NN regressor (Euclidean distance) on 15 rows of what appears to be the Longley
 * economic dataset. The label is the first column of each row. The test then checks the
 * prediction for a held-out row to within a tolerance of 2000.
 *
 * @param weighted Value passed to {@code KNNRegressionTrainer#withWeighted}.
 */
private void testLongly(boolean weighted) {
    // Row layout: {label, feature1 .. feature6}; the last feature looks like a year.
    Map<Integer, double[]> data = new HashMap<>();
    data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
    data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
    data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949});
    data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950});
    data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951});
    data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952});
    data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953});
    data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954});
    data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955});
    data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957});
    data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958});
    data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959});
    data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960});
    data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961});
    data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962});

    KNNRegressionTrainer trainer = new KNNRegressionTrainer()
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(weighted);

    KNNRegressionModel knnMdl = trainer.fit(
        new LocalDatasetBuilder<>(data, parts),
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));

    // Features for the 1956 row. That row is deliberately absent from the training data above
    // (the last column jumps from 1955 to 1957), so this is a held-out prediction.
    Vector vector = VectorUtils.of(104.6, 419180.0, 2822.0, 2857.0, 118734.0, 1956.0);

    // Predict once and check that same value; boxed so assertNotNull sees exactly what
    // predict returned instead of failing early on an unboxing NPE.
    Double prediction = knnMdl.predict(vector);
    assertNotNull(prediction);
    assertEquals(67857, prediction, 2000);
}
Aggregations