Search in sources:

Example 11 with SimpleGDUpdateCalculator

Use of org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator in the Apache Ignite project.

The class LogisticRegressionSGDTrainerTest, method testUpdate.

/**
 * Trains a logistic regression model, then updates it once on the same dataset and once on an
 * empty dataset, and checks that both updated models predict the same values as the original
 * model on two probe points (within {@code PRECISION}).
 */
@Test
public void testUpdate() {
    Map<Integer, double[]> cacheMock = new HashMap<>();
    for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
        cacheMock.put(i, twoLinearlySeparableClasses[i]);

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(10)
        .withSeed(123L);

    // A single vectorizer (label taken from the first column) is shared by the initial fit and
    // both updates, so all three models see identically extracted features and labels.
    Vectorizer<Integer, double[], Integer, Double> vectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    LogisticRegressionModel originalMdl = trainer.fit(cacheMock, parts, vectorizer);
    LogisticRegressionModel updatedOnSameDS = trainer.update(originalMdl, cacheMock, parts, vectorizer);
    LogisticRegressionModel updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<>(), parts, vectorizer);

    Vector v1 = VectorUtils.of(100, 10);
    Vector v2 = VectorUtils.of(10, 100);

    TestUtils.assertEquals(originalMdl.predict(v1), updatedOnSameDS.predict(v1), PRECISION);
    TestUtils.assertEquals(originalMdl.predict(v2), updatedOnSameDS.predict(v2), PRECISION);
    TestUtils.assertEquals(originalMdl.predict(v2), updatedOnEmptyDS.predict(v2), PRECISION);
    TestUtils.assertEquals(originalMdl.predict(v1), updatedOnEmptyDS.predict(v1), PRECISION);
}
Also used : DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) HashMap(java.util.HashMap) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 12 with SimpleGDUpdateCalculator

Use of org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator in the Apache Ignite project.

The class CrossValidationTest, method testBasicFunctionality.

/**
 * Runs 4-fold cross-validation of a logistic regression SGD trainer directly (not through a
 * pipeline) and compares each fold's accuracy against recorded reference values.
 */
@Test
public void testBasicFunctionality() {
    Map<Integer, double[]> data = new HashMap<>();
    int rowIdx = 0;
    for (double[] sample : twoLinearlySeparableClasses)
        data.put(rowIdx++, sample);

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    Vectorizer<Integer, double[], Integer, Double> vectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    final int foldCnt = 4;

    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator = new DebugCrossValidation<>();
    scoreCalculator
        .withUpstreamMap(data)
        .withAmountOfParts(1)
        .withTrainer(trainer)
        .withMetric(MetricName.ACCURACY)
        .withPreprocessor(vectorizer)
        .withAmountOfFolds(foldCnt)
        .isRunningOnPipeline(false);

    double[] scores = scoreCalculator.scoreByFolds();

    // Reference accuracy for each fold, in fold order.
    double[] expScores = {
        0.8389830508474576,
        0.9402985074626866,
        0.8809523809523809,
        0.9921259842519685
    };

    for (int fold = 0; fold < expScores.length; fold++)
        assertEquals(expScores[fold], scores[fold], 1e-6);
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) HashMap(java.util.HashMap) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Test(org.junit.Test)

Example 13 with SimpleGDUpdateCalculator

Use of org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator in the Apache Ignite project.

The class CrossValidationTest, method testRandomSearchWithPipeline.

/**
 * Tunes the maxIterations, locIterations and batchSize hyper-parameters of a logistic
 * regression trainer wrapped in a {@link Pipeline}, using a seeded random search over a
 * {@link ParamGrid} with 4-fold cross-validation, and checks the best average accuracy
 * and the size of the scoring board.
 */
@Test
public void testRandomSearchWithPipeline() {
    Map<Integer, double[]> data = new HashMap<>();
    for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
        data.put(i, twoLinearlySeparableClasses[i]);

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    Vectorizer<Integer, double[], Integer, Double> vectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    ParamGrid paramGrid = new ParamGrid()
        .withParameterSearchStrategy(new RandomStrategy()
            .withMaxTries(10)
            .withSeed(1234L)
            .withSatisfactoryFitness(0.9))
        .addHyperParam("maxIterations", trainer::withMaxIterations, new Double[] {10.0, 100.0, 1000.0, 10000.0})
        .addHyperParam("locIterations", trainer::withLocIterations, new Double[] {10.0, 100.0, 1000.0, 10000.0})
        .addHyperParam("batchSize", trainer::withBatchSize, new Double[] {1.0, 2.0, 4.0, 8.0, 16.0});

    Pipeline<Integer, double[], Integer, Double> pipeline = new Pipeline<Integer, double[], Integer, Double>()
        .addVectorizer(vectorizer)
        .addTrainer(trainer);

    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator =
        (DebugCrossValidation<LogisticRegressionModel, Integer, double[]>)
            new DebugCrossValidation<LogisticRegressionModel, Integer, double[]>()
                .withUpstreamMap(data)
                .withAmountOfParts(1)
                .withPipeline(pipeline)
                .withMetric(MetricName.ACCURACY)
                .withPreprocessor(vectorizer)
                .withAmountOfFolds(4)
                .isRunningOnPipeline(true)
                .withParamGrid(paramGrid);

    CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();

    // Expected value first, matching JUnit's assertEquals(expected, actual[, delta]) contract,
    // so a failure reports the reference and observed values the right way round.
    assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6);
    assertEquals(10, crossValidationRes.getScoringBoard().size());
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) HashMap(java.util.HashMap) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) Pipeline(org.apache.ignite.ml.pipeline.Pipeline) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) Test(org.junit.Test)

Example 14 with SimpleGDUpdateCalculator

Use of org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator in the Apache Ignite project.

The class CrossValidationTest, method testRandomSearch.

/**
 * Tunes the maxIterations, locIterations and batchSize hyper-parameters of a logistic
 * regression trainer (used directly, not through a pipeline) with a seeded random search over
 * a {@link ParamGrid} and 4-fold cross-validation, and checks the best average accuracy and
 * the size of the scoring board.
 */
@Test
public void testRandomSearch() {
    Map<Integer, double[]> data = new HashMap<>();
    for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
        data.put(i, twoLinearlySeparableClasses[i]);

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    Vectorizer<Integer, double[], Integer, Double> vectorizer =
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    ParamGrid paramGrid = new ParamGrid()
        .withParameterSearchStrategy(new RandomStrategy()
            .withMaxTries(10)
            .withSeed(1234L)
            .withSatisfactoryFitness(0.9))
        .addHyperParam("maxIterations", trainer::withMaxIterations, new Double[] {10.0, 100.0, 1000.0, 10000.0})
        .addHyperParam("locIterations", trainer::withLocIterations, new Double[] {10.0, 100.0, 1000.0, 10000.0})
        .addHyperParam("batchSize", trainer::withBatchSize, new Double[] {1.0, 2.0, 4.0, 8.0, 16.0});

    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator =
        (DebugCrossValidation<LogisticRegressionModel, Integer, double[]>)
            new DebugCrossValidation<LogisticRegressionModel, Integer, double[]>()
                .withUpstreamMap(data)
                .withAmountOfParts(1)
                .withTrainer(trainer)
                .withMetric(MetricName.ACCURACY)
                .withPreprocessor(vectorizer)
                .withAmountOfFolds(4)
                .isRunningOnPipeline(false)
                .withParamGrid(paramGrid);

    CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();

    // Expected value first, matching JUnit's assertEquals(expected, actual[, delta]) contract,
    // so a failure reports the reference and observed values the right way round.
    assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6);
    assertEquals(10, crossValidationRes.getScoringBoard().size());
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) HashMap(java.util.HashMap) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) Test(org.junit.Test)

Example 15 with SimpleGDUpdateCalculator

Use of org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator in the Apache Ignite project.

The class OneVsRestTrainerTest, method testTrainWithTheLinearlySeparableCase.

/**
 * Test trainer on 2 linearly separable sets: trains a one-vs-rest classifier on top of a
 * logistic regression SGD trainer, checks that every {@code toString} variant of the model is
 * non-empty, and checks the predicted class of one point on each side of the separation.
 */
@Test
public void testTrainWithTheLinearlySeparableCase() {
    Map<Integer, double[]> cacheMock = new HashMap<>();
    for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
        cacheMock.put(i, twoLinearlySeparableClasses[i]);

    LogisticRegressionSGDTrainer binaryTrainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.2),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG))
        .withMaxIterations(1000)
        .withLocIterations(10)
        .withBatchSize(100)
        .withSeed(123L);

    OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer);

    MultiClassModel mdl = trainer.fit(
        cacheMock,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));

    // assertFalse states the intent directly instead of negating inside assertTrue.
    Assert.assertFalse(mdl.toString().isEmpty());
    Assert.assertFalse(mdl.toString(true).isEmpty());
    Assert.assertFalse(mdl.toString(false).isEmpty());

    TestUtils.assertEquals(1, mdl.predict(VectorUtils.of(-100, 0)), PRECISION);
    TestUtils.assertEquals(0, mdl.predict(VectorUtils.of(100, 0)), PRECISION);
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) HashMap(java.util.HashMap) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Aggregations

SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator)16 LogisticRegressionSGDTrainer (org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer)12 LogisticRegressionModel (org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel)11 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)10 Test (org.junit.Test)10 HashMap (java.util.HashMap)7 Ignite (org.apache.ignite.Ignite)6 TrainerTest (org.apache.ignite.ml.common.TrainerTest)6 DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer)5 SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache)3 MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture)3 SimpleGDParameterUpdate (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate)3 ParamGrid (org.apache.ignite.ml.selection.paramgrid.ParamGrid)3 FileNotFoundException (java.io.FileNotFoundException)2 OnMajorityPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator)2 StackedVectorDatasetTrainer (org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer)2 Matrix (org.apache.ignite.ml.math.primitives.matrix.Matrix)2 DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix)2 VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils)2 MLPTrainer (org.apache.ignite.ml.nn.MLPTrainer)2