Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class LinearRegressionModelTest, method testPredict.
/**
 * Checks that {@link LinearRegressionModel} produces non-empty string renderings and that its
 * predictions equal {@code 1.0 + 2.0 * x0 + 3.0 * x1} for several observations, including ones with
 * negative features.
 */
@Test
public void testPredict() {
    Vector weights = new DenseVector(new double[] {2.0, 3.0});
    LinearRegressionModel mdl = new LinearRegressionModel(weights, 1.0);

    // Each textual rendering of the model must contain something.
    assertTrue(mdl.toString().length() > 0);
    assertTrue(mdl.toString(true).length() > 0);
    assertTrue(mdl.toString(false).length() > 0);

    double[][] features = {
        {1.0, 1.0},
        {2.0, 1.0},
        {1.0, 2.0},
        {-2.0, 1.0},
        {1.0, -2.0}
    };

    // Intercept 1.0 plus weights (2.0, 3.0) applied to the matching row of `features`.
    double[] expected = {
        1.0 + 2.0 * 1.0 + 3.0 * 1.0,
        1.0 + 2.0 * 2.0 + 3.0 * 1.0,
        1.0 + 2.0 * 1.0 + 3.0 * 2.0,
        1.0 - 2.0 * 2.0 + 3.0 * 1.0,
        1.0 + 2.0 * 1.0 - 3.0 * 2.0
    };

    for (int row = 0; row < features.length; row++)
        TestUtils.assertEquals(expected[row], mdl.predict(new DenseVector(features[row])), PRECISION);
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class LogisticRegressionModelTest, method testPredict.
/**
 * Checks the accessors and string renderings of {@link LogisticRegressionModel}, then verifies its
 * predictions for models assembled through different paths: constructor arguments,
 * {@code withWeights}, {@code withThreshold} and {@code withIntercept}.
 */
@Test
public void testPredict() {
    Vector weights = new DenseVector(new double[] {2.0, 3.0});

    // A model built only from weights and intercept reports that it does not keep raw labels.
    assertFalse(new LogisticRegressionModel(weights, 1.0).isKeepingRawLabels());

    // The threshold passed to withThreshold must be returned by threshold() exactly (delta 0).
    assertEquals(0.1, new LogisticRegressionModel(weights, 1.0).withThreshold(0.1).threshold(), 0);

    // Every textual rendering must be non-empty; assertFalse states that intent directly.
    LogisticRegressionModel printable = new LogisticRegressionModel(weights, 1.0);
    assertFalse(printable.toString().isEmpty());
    assertFalse(printable.toString(true).isEmpty());
    assertFalse(printable.toString(false).isEmpty());

    // The same expected predictions must hold regardless of how weights/threshold/intercept were set.
    verifyPredict(new LogisticRegressionModel(weights, 1.0).withRawLabels(true));
    verifyPredict(new LogisticRegressionModel(null, 1.0).withRawLabels(true).withWeights(weights));
    verifyPredict(new LogisticRegressionModel(weights, 1.0).withRawLabels(true).withThreshold(0.5));
    verifyPredict(new LogisticRegressionModel(weights, 0.0).withRawLabels(true).withIntercept(1.0));
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class LogisticRegressionModelTest, method verifyPredict.
/**
 * Asserts that {@code mdl} predicts {@code sigmoid(1.0 + 2.0 * x0 + 3.0 * x1)} for a fixed set of
 * observations, including ones with negative features.
 *
 * <p>The expected values are raw sigmoid outputs, so the model is presumably configured to return
 * raw labels rather than thresholded classes.
 *
 * @param mdl Model under test.
 */
private void verifyPredict(LogisticRegressionModel mdl) {
    double[][] features = {
        {1.0, 1.0},
        {2.0, 1.0},
        {1.0, 2.0},
        {-2.0, 1.0},
        {1.0, -2.0}
    };

    // Linear score (before the sigmoid) for the matching row of `features`.
    double[] linearScores = {
        1.0 + 2.0 * 1.0 + 3.0 * 1.0,
        1.0 + 2.0 * 2.0 + 3.0 * 1.0,
        1.0 + 2.0 * 1.0 + 3.0 * 2.0,
        1.0 - 2.0 * 2.0 + 3.0 * 1.0,
        1.0 + 2.0 * 1.0 - 3.0 * 2.0
    };

    for (int row = 0; row < features.length; row++) {
        Vector point = new DenseVector(features[row]);

        TestUtils.assertEquals(sigmoid(linearScores[row]), mdl.predict(point), PRECISION);
    }
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class PipelineMdlTest, method testPredict.
/**
 * Wraps a raw-label logistic regression model (weights (2.0, 3.0), intercept 1.0) with
 * {@code getMdl} and checks the predictions of the result via {@code verifyPredict}.
 */
@Test
public void testPredict() {
    LogisticRegressionModel logRegMdl =
        new LogisticRegressionModel(new DenseVector(new double[] {2.0, 3.0}), 1.0).withRawLabels(true);

    verifyPredict(getMdl(logRegMdl));
}
Use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class VectorUtils, method concat.
/**
 * Concatenates the given vectors, in order, into a single vector.
 *
 * <p>With no arguments a new {@link DenseVector} created by its no-argument constructor is
 * returned. With exactly one argument that same instance is returned as-is (no copy is made).
 * Otherwise the vectors are folded left to right with {@code concat(acc, next)}.
 *
 * @param vs Vectors to concatenate, in order.
 * @return Concatenation result.
 */
public static Vector concat(Vector... vs) {
    if (vs.length == 0)
        return new DenseVector();

    Vector acc = vs[0];

    // NOTE: relies on a fixed-arity concat(Vector, Vector) overload, which Java prefers over this
    // varargs method; without it this call would recurse into the varargs form indefinitely.
    for (int idx = 1; idx < vs.length; idx++)
        acc = concat(acc, vs[idx]);

    return acc;
}
Aggregations