Usage of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class OneVsRestTrainerTest, method testTrainWithTheLinearlySeparableCase.
/**
 * Trains a one-vs-rest classifier, backed by a logistic regression SGD trainer,
 * on two linearly separable sets and checks the resulting model.
 *
 * <p>The fitted model must produce a non-empty textual description in every
 * {@code toString} variant. It must also assign class 1 to a point far to the
 * left, {@code (-100, 0)}, and class 0 to a point far to the right, {@code (100, 0)}.
 */
@Test
public void testTrainWithTheLinearlySeparableCase() {
    // Key every sample row by its position in the source array.
    Map<Integer, double[]> data = new HashMap<>();
    int rowIdx = 0;
    for (double[] row : twoLinearlySeparableClasses) {
        data.put(rowIdx, row);
        rowIdx++;
    }

    LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(1000)
        .withLocIterations(10)
        .withBatchSize(100)
        .withSeed(123L);

    OneVsRestTrainer<LogisticRegressionModel> oneVsRest = new OneVsRestTrainer<>(binaryTrainer);

    // The label is taken from the first coordinate of each row.
    Vectorizer<Integer, double[], Integer, Double> vectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    MultiClassModel mdl = oneVsRest.fit(data, parts, vectorizer);

    Assert.assertFalse(mdl.toString().isEmpty());
    Assert.assertFalse(mdl.toString(true).isEmpty());
    Assert.assertFalse(mdl.toString(false).isEmpty());

    TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(-100, 0)), PRECISION);
    TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 0)), PRECISION);
}
Usage of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class OneVsRestTrainerTest, method testUpdate.
/**
 * Checks that updating a fitted one-vs-rest model keeps its predictions stable.
 *
 * <p>The model is first fitted on two linearly separable sets. It is then
 * updated once on the same dataset and once on an empty dataset. Both updated
 * models must predict the same values as the original model on the probe
 * points {@code (-100, 0)} and {@code (100, 0)}.
 */
@Test
public void testUpdate() {
    Map<Integer, double[]> trainData = new HashMap<>();
    for (int row = 0; row < twoLinearlySeparableClasses.length; row++) {
        trainData.put(row, twoLinearlySeparableClasses[row]);
    }

    LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(1000)
        .withLocIterations(10)
        .withBatchSize(100)
        .withSeed(123L);

    OneVsRestTrainer<LogisticRegressionModel> oneVsRest = new OneVsRestTrainer<>(binaryTrainer);

    Vectorizer<Integer, double[], Integer, Double> labeledVectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    MultiClassModel fitted = oneVsRest.fit(trainData, parts, labeledVectorizer);
    MultiClassModel refitOnSame = oneVsRest.update(fitted, trainData, parts, labeledVectorizer);
    MultiClassModel refitOnEmpty = oneVsRest.update(fitted, new HashMap<>(), parts, labeledVectorizer);

    Vector[] probes = {VectorUtils.of(-100, 0), VectorUtils.of(100, 0)};
    for (Vector probe : probes) {
        Double expected = fitted.predict(probe);

        TestUtils.assertEquals(expected, refitOnSame.predict(probe), PRECISION);
        TestUtils.assertEquals(expected, refitOnEmpty.predict(probe), PRECISION);
    }
}
Usage of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class PipelineMdlTest, method testPredict.
/**
 * Checks prediction of a pipeline model built around a logistic regression model.
 *
 * <p>The wrapped model uses weights {@code (2.0, 3.0)} and the constructor
 * argument {@code 1.0} (presumably the intercept; confirm against
 * {@code LogisticRegressionModel}), with raw labels enabled. It is wrapped via
 * {@code getMdl} and the result is checked by {@code verifyPredict}.
 */
@Test
public void testPredict() {
    double[] coefficients = {2.0, 3.0};
    Vector weights = new DenseVector(coefficients);

    LogisticRegressionModel logRegMdl = new LogisticRegressionModel(weights, 1.0);

    verifyPredict(getMdl(logRegMdl.withRawLabels(true)));
}
Aggregations