Usage of org.apache.ignite.ml.math.primitives.vector.Vector in the Apache Ignite project.
Class RegressionMetricsTest, method testCalculation:
/**
 * Evaluates two models on the labeled points (0 -> 0), (1 -> 1), (2 -> 2), (3 -> 3) and checks the
 * resulting regression metrics: an identity model, whose error metrics are all zero, and a squaring
 * model, whose absolute errors are 0, 0, 2 and 6.
 */
@Test
public void testCalculation() {
    // Plain map instead of "double-brace" initialization, which would define an anonymous HashMap subclass.
    Map<Vector, Double> linearSet = new HashMap<>();
    linearSet.put(VectorUtils.of(0.), 0.);
    linearSet.put(VectorUtils.of(1.), 1.);
    linearSet.put(VectorUtils.of(2.), 2.);
    linearSet.put(VectorUtils.of(3.), 3.);

    IgniteModel<Vector, Double> linearModel = v -> v.get(0);
    IgniteModel<Vector, Double> squareModel = v -> Math.pow(v.get(0), 2);

    // Identity model: every prediction equals its label, so all errors are zero.
    EvaluationResult linearRes = Evaluator.evaluateRegression(linearSet, linearModel, Vector::labeled);
    assertEquals(0., linearRes.get(MetricName.MAE), 0.01);
    assertEquals(0., linearRes.get(MetricName.MSE), 0.01);
    assertEquals(0., linearRes.get(MetricName.R2), 0.01);
    assertEquals(0., linearRes.get(MetricName.RSS), 0.01);
    assertEquals(0., linearRes.get(MetricName.RMSE), 0.01);

    // Squaring model: predictions 0, 1, 4, 9 vs labels 0, 1, 2, 3 -> absolute errors 0, 0, 2, 6.
    // MAE = 8 / 4 = 2, RSS = 0 + 0 + 4 + 36 = 40, MSE = 40 / 4 = 10, RMSE = sqrt(10).
    EvaluationResult squareRes = Evaluator.evaluateRegression(linearSet, squareModel, Vector::labeled);
    assertEquals(2., squareRes.get(MetricName.MAE), 0.01);
    assertEquals(10., squareRes.get(MetricName.MSE), 0.01);
    // NOTE(review): label mean is 1.5, so TSS = 2.25 + 0.25 + 0.25 + 2.25 = 5 and RSS / TSS = 40 / 5 = 8
    // (and 0 for the identity model). The conventional R^2 = 1 - RSS / TSS would be -7 here (and 1 above);
    // confirm these expectations against the R2 metric implementation.
    assertEquals(8., squareRes.get(MetricName.R2), 0.01);
    assertEquals(40., squareRes.get(MetricName.RSS), 0.01);
    assertEquals(Math.sqrt(10), squareRes.get(MetricName.RMSE), 0.01);
}
Usage of org.apache.ignite.ml.math.primitives.vector.Vector in the Apache Ignite project.
Class RandomForestClassifierTrainerTest, method testUpdate:
/**
 * Trains a random forest classifier (100 trees, feature-count selection strategy that always returns 2)
 * on 1000 rows with features (i, i/10, i/100, i/1000) and label i % 2, then checks that models updated
 * on the same data and on an empty map predict the same value as the original model at one probe point.
 */
@Test
public void testUpdate() {
    final int sampleSize = 1000;

    Map<Integer, LabeledVector<Double>> sample = new HashMap<>();
    for (int idx = 0; idx < sampleSize; idx++) {
        double first = idx;
        double second = first / 10.0;
        double third = second / 10.0;
        double fourth = third / 10.0;
        double lbl = (double) idx % 2;
        sample.put(idx, VectorUtils.of(first, second, third, fourth).labeled(lbl));
    }

    ArrayList<FeatureMeta> meta = new ArrayList<>();
    for (int featureIdx = 0; featureIdx < 4; featureIdx++) {
        meta.add(new FeatureMeta("", featureIdx, false));
    }

    DatasetTrainer<RandomForestModel, Double> trainer = new RandomForestClassifierTrainer(meta)
        .withAmountOfTrees(100)
        .withFeaturesCountSelectionStrgy(featureCnt -> 2)
        .withEnvironmentBuilder(TestUtils.testEnvBuilder());

    // NOTE(review): `parts` is not declared here; presumably a field of the enclosing test class — TODO confirm.
    RandomForestModel originalMdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
    RandomForestModel updatedOnSameDS = trainer.update(originalMdl, sample, parts, new LabeledDummyVectorizer<>());

    Map<Integer, LabeledVector<Double>> emptySample = new HashMap<>();
    RandomForestModel updatedOnEmptyDS = trainer.update(originalMdl, emptySample, parts, new LabeledDummyVectorizer<>());

    Vector probe = VectorUtils.of(5, 0.5, 0.05, 0.005);
    assertEquals(originalMdl.predict(probe), updatedOnSameDS.predict(probe), 0.01);
    assertEquals(originalMdl.predict(probe), updatedOnEmptyDS.predict(probe), 0.01);
}
Usage of org.apache.ignite.ml.math.primitives.vector.Vector in the Apache Ignite project.
Class LogisticRegressionModelTest, method testPredictOnAnObservationWithWrongCardinality:
/**
 * Checks that predicting on an observation whose size does not match the weights vector
 * fails with {@link CardinalityException}.
 */
@Test(expected = CardinalityException.class)
public void testPredictOnAnObservationWithWrongCardinality() {
    double[] weightVals = {2.0, 3.0};
    double[] featureVals = {1.0};

    LogisticRegressionModel mdl = new LogisticRegressionModel(new DenseVector(weightVals), 1.0);

    // Two weights but only one feature: the prediction is expected to throw.
    mdl.predict(new DenseVector(featureVals));
}
Usage of org.apache.ignite.ml.math.primitives.vector.Vector in the Apache Ignite project.
Class SVMModelTest, method testPredictOnAnObservationWithWrongCardinality:
/**
 * Checks that a linear SVM model with two weights rejects a one-component observation
 * with {@link CardinalityException}.
 */
@Test(expected = CardinalityException.class)
public void testPredictOnAnObservationWithWrongCardinality() {
    Vector twoWeights = new DenseVector(new double[] {2.0, 3.0});
    Vector oneFeature = new DenseVector(new double[] {1.0});

    // Expected to throw: the observation has one component, the model has two weights.
    new SVMLinearClassificationModel(twoWeights, 1.0).predict(oneFeature);
}
Usage of org.apache.ignite.ml.math.primitives.vector.Vector in the Apache Ignite project.
Class DataStreamGeneratorTest, method testAsDatasetBuilder:
/**
 * Builds three dataset builders from a generator that emits the values 0, 1, 2, ... labeled with
 * value % 2:
 * <ul>
 *   <li>an unfiltered one;</li>
 *   <li>one filtered to label 0;</li>
 *   <li>one filtered to label 1 that also uses a test upstream transformer.</li>
 * </ul>
 * Each builder is then validated with {@code checkDataset}, passing an expected count and a per-row predicate.
 */
@Test
public void testAsDatasetBuilder() throws Exception {
    AtomicInteger counter = new AtomicInteger();

    // Infinite stream of one-component vectors (0), (1), (2), ... labeled with value % 2.
    DataStreamGenerator gen = new DataStreamGenerator() {
        @Override
        public Stream<LabeledVector<Double>> labeled() {
            return Stream.generate(() -> {
                int next = counter.getAndIncrement();
                double lbl = (double) next % 2;
                return new LabeledVector<>(VectorUtils.of(next), lbl);
            });
        }
    };

    int sampleCnt = 100;

    // The counter is reset before each builder so every builder sees the sequence starting at 0.
    counter.set(0);
    DatasetBuilder<Vector, Double> unfiltered = gen.asDatasetBuilder(sampleCnt, 2);

    counter.set(0);
    DatasetBuilder<Vector, Double> zerosOnly = gen.asDatasetBuilder(sampleCnt, (vec, lbl) -> lbl == 0, 2);

    counter.set(0);
    DatasetBuilder<Vector, Double> onesTransformed = gen.asDatasetBuilder(sampleCnt, (vec, lbl) -> lbl == 1, 2,
        new UpstreamTransformerBuilder() {
            @Override
            public UpstreamTransformer build(LearningEnvironment env) {
                return new UpstreamTransformerForTest();
            }
        });

    checkDataset(sampleCnt, unfiltered, row -> (Double) row.label() == 0 || (Double) row.label() == 1);
    checkDataset(sampleCnt / 2, zerosOnly, row -> (Double) row.label() == 0);
    // NOTE(review): negative labels are expected here, presumably produced by UpstreamTransformerForTest — confirm.
    checkDataset(sampleCnt / 2, onesTransformed, row -> (Double) row.label() < 0);
}
Aggregations