Search in sources:

Example 16 with LogisticRegressionModel

Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.

Class OneVsRestTrainerTest, method testTrainWithTheLinearlySeparableCase.

/**
 * Test trainer on 2 linearly separable sets.
 *
 * <p>Trains a one-vs-rest model on top of an SGD logistic-regression trainer, checks that every
 * {@code toString} variant produces non-empty text, and checks that a point far to the left on the
 * first axis is classified as {@code 1} while a point far to the right is classified as {@code 0}.
 *
 * <p>NOTE(review): {@code twoLinearlySeparableClasses}, {@code parts} and {@code PRECISION} are
 * fixtures declared elsewhere in this test class; each row is read with its label in the first
 * coordinate ({@code LabelCoordinate.FIRST}).
 */
@Test
public void testTrainWithTheLinearlySeparableCase() {
    Map<Integer, double[]> rowsByKey = new HashMap<>();
    int key = 0;
    for (double[] row : twoLinearlySeparableClasses) {
        rowsByKey.put(key++, row);
    }

    // Fixed seed and iteration budget keep the SGD run deterministic.
    LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(1000)
        .withLocIterations(10)
        .withBatchSize(100)
        .withSeed(123L);

    OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer);

    Vectorizer<Integer, double[], Integer, Double> labelFirst =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    MultiClassModel mdl = trainer.fit(rowsByKey, parts, labelFirst);

    Assert.assertFalse(mdl.toString().isEmpty());
    Assert.assertFalse(mdl.toString(true).isEmpty());
    Assert.assertFalse(mdl.toString(false).isEmpty());

    TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(-100, 0)), PRECISION);
    TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 0)), PRECISION);
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) HashMap(java.util.HashMap) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 17 with LogisticRegressionModel

Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.

Class OneVsRestTrainerTest, method testUpdate.

/**
 * Test that updating a trained one-vs-rest model does not change its predictions.
 *
 * <p>A model is first fitted, then updated twice: once on the same data and once on an empty
 * dataset. For two probe points, both updated models must predict the same value as the original
 * model (within {@code PRECISION}).
 *
 * <p>NOTE(review): {@code twoLinearlySeparableClasses}, {@code parts} and {@code PRECISION} are
 * fixtures declared elsewhere in this test class.
 */
@Test
public void testUpdate() {
    Map<Integer, double[]> rowsByKey = new HashMap<>();
    int key = 0;
    for (double[] row : twoLinearlySeparableClasses) {
        rowsByKey.put(key++, row);
    }

    LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(1000)
        .withLocIterations(10)
        .withBatchSize(100)
        .withSeed(123L);

    OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer);

    Vectorizer<Integer, double[], Integer, Double> labelFirst =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    MultiClassModel baseline = trainer.fit(rowsByKey, parts, labelFirst);
    MultiClassModel retrainedSameData = trainer.update(baseline, rowsByKey, parts, labelFirst);
    // An update with no rows is expected to leave the model's behavior intact.
    MultiClassModel retrainedNoData = trainer.update(baseline, new HashMap<>(), parts, labelFirst);

    List<Vector> probes = Arrays.asList(VectorUtils.of(-100, 0), VectorUtils.of(100, 0));
    for (Vector probe : probes) {
        TestUtils.assertEquals(baseline.predict(probe), retrainedSameData.predict(probe), PRECISION);
        TestUtils.assertEquals(baseline.predict(probe), retrainedNoData.predict(probe), PRECISION);
    }
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) HashMap(java.util.HashMap) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 18 with LogisticRegressionModel

Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.

Class PipelineMdlTest, method testPredict.

/**
 * Test prediction of a pipeline model that wraps a logistic-regression model.
 *
 * <p>The wrapped model uses weights {@code (2.0, 3.0)} and has raw labels enabled. Building the
 * pipeline ({@code getMdl}) and checking its output ({@code verifyPredict}) are done by helpers
 * defined elsewhere in this test class.
 */
@Test
public void testPredict() {
    Vector coefficients = new DenseVector(new double[] {2.0, 3.0});
    // NOTE(review): presumably the intercept, matching a (weights, intercept) constructor; confirm.
    double intercept = 1.0;

    LogisticRegressionModel logReg = new LogisticRegressionModel(coefficients, intercept);

    verifyPredict(getMdl(logReg.withRawLabels(true)));
}
Also used: LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) Test(org.junit.Test)

Aggregations

LogisticRegressionModel (org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel)18 LogisticRegressionSGDTrainer (org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer)12 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)11 SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator)11 Test (org.junit.Test)10 HashMap (java.util.HashMap)6 Ignite (org.apache.ignite.Ignite)6 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)5 SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache)3 TrainerTest (org.apache.ignite.ml.common.TrainerTest)3 ParamGrid (org.apache.ignite.ml.selection.paramgrid.ParamGrid)3 FileNotFoundException (java.io.FileNotFoundException)2 DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer)2 EncoderTrainer (org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer)2 NormalizationTrainer (org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer)2 RandomStrategy (org.apache.ignite.ml.selection.paramgrid.RandomStrategy)2 IOException (java.io.IOException)1 Path (java.nio.file.Path)1 Configuration (org.apache.hadoop.conf.Configuration)1 Path (org.apache.hadoop.fs.Path)1