Search in sources :

Example 11 with LogisticRegressionModel

use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in project ignite by apache.

the class MLDeployingTest method fitAndTestModel.

/**
 * Fits a logistic regression model with a default-constructed SGD trainer and checks that the
 * trained model yields a prediction for the point {@code (0, 0)}.
 *
 * @param datasetBuilder Dataset builder passed to the trainer.
 * @param preprocessor Preprocessor passed to the trainer.
 */
private void fitAndTestModel(CacheBasedDatasetBuilder<Integer, Vector> datasetBuilder, Preprocessor<Integer, Vector> preprocessor) {
    LogisticRegressionModel mdl = new LogisticRegressionSGDTrainer().fit(datasetBuilder, preprocessor);

    Vector origin = VectorUtils.of(0., 0.);
    double prediction = mdl.predict(origin);

    // Any answer is valid for this case: a delta of 1 around 0 accepts every value in [-1, 1]
    // (presumably the model emits labels 0 or 1 - confirm against LogisticRegressionModel).
    assertEquals(0., prediction, 1.);
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel)

Example 12 with LogisticRegressionModel

use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in project ignite by apache.

the class SparkModelParser method loadLogRegModel.

/**
 * Load logistic regression model from a parquet file.
 *
 * <p>Every record of every row group is read in order; the values extracted from a record replace
 * those of the previous one, so the returned model is built from the last record read.
 *
 * <p>If an {@link IOException} is raised while reading, its message is logged through the learning
 * environment logger, the stack trace is printed, and the model is still created from whatever
 * was read so far (coefficients stay {@code null} and the intercept stays {@code 0} when nothing
 * was read).
 *
 * @param pathToMdl Path to model.
 * @param learningEnvironment Learning environment; its logger receives read errors.
 * @return Logistic regression model built from the values found in the file.
 */
private static Model loadLogRegModel(String pathToMdl, LearningEnvironment learningEnvironment) {
    Vector weights = null;
    double intercept = 0;

    try (ParquetFileReader reader = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        // The schema comes from the file footer and drives record assembly for all row groups.
        final MessageType schema = reader.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);

        for (PageReadStore rowGroup = reader.readNextRowGroup(); rowGroup != null; rowGroup = reader.readNextRowGroup()) {
            final long rowCnt = rowGroup.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(rowGroup, new GroupRecordConverter(schema));

            int rowIdx = 0;
            while (rowIdx < rowCnt) {
                final SimpleGroup record = (SimpleGroup) recordReader.read();

                // Later records overwrite earlier ones.
                intercept = readInterceptor(record);
                weights = readCoefficients(record);

                rowIdx++;
            }
        }
    } catch (IOException e) {
        String errMsg = "Error reading parquet file: " + e.getMessage();

        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, errMsg);
        e.printStackTrace();
    }

    return new LogisticRegressionModel(weights, intercept);
}
Also used : Path(org.apache.hadoop.fs.Path) GroupRecordConverter(org.apache.parquet.example.data.simple.convert.GroupRecordConverter) Configuration(org.apache.hadoop.conf.Configuration) ParquetFileReader(org.apache.parquet.hadoop.ParquetFileReader) RecordReader(org.apache.parquet.io.RecordReader) SimpleGroup(org.apache.parquet.example.data.simple.SimpleGroup) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) IOException(java.io.IOException) MessageColumnIO(org.apache.parquet.io.MessageColumnIO) ColumnIOFactory(org.apache.parquet.io.ColumnIOFactory) PageReadStore(org.apache.parquet.column.page.PageReadStore) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) MessageType(org.apache.parquet.schema.MessageType)

Example 13 with LogisticRegressionModel

use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in project ignite by apache.

the class CrossValidationTest method testBasicFunctionality.

/**
 * Runs 4-fold cross validation of a logistic regression SGD trainer on the
 * {@code twoLinearlySeparableClasses} data and compares the accuracy of each fold with the
 * expected value.
 */
@Test
public void testBasicFunctionality() {
    Map<Integer, double[]> data = new HashMap<>();

    for (int rowIdx = 0; rowIdx < twoLinearlySeparableClasses.length; rowIdx++) {
        data.put(rowIdx, twoLinearlySeparableClasses[rowIdx]);
    }

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    // The label is taken from the first coordinate of every row.
    Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator = new DebugCrossValidation<>();

    scoreCalculator
        .withUpstreamMap(data)
        .withAmountOfParts(1)
        .withTrainer(trainer)
        .withMetric(MetricName.ACCURACY)
        .withPreprocessor(vectorizer)
        .withAmountOfFolds(4)
        .isRunningOnPipeline(false);

    double[] scores = scoreCalculator.scoreByFolds();

    // Expected accuracy per fold, in fold order.
    double[] expScores = {0.8389830508474576, 0.9402985074626866, 0.8809523809523809, 0.9921259842519685};

    for (int fold = 0; fold < expScores.length; fold++) {
        assertEquals(expScores[fold], scores[fold], 1e-6);
    }
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) HashMap(java.util.HashMap) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Test(org.junit.Test)

Example 14 with LogisticRegressionModel

use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in project ignite by apache.

the class CrossValidationTest method testRandomSearchWithPipeline.

/**
 * Tunes hyperparameters of a logistic regression SGD trainer wrapped in a pipeline, using a
 * random search strategy over {@code maxIterations}, {@code locIterations} and {@code batchSize},
 * and checks the best average score and the size of the scoring board.
 */
@Test
public void testRandomSearchWithPipeline() {
    Map<Integer, double[]> data = new HashMap<>();

    for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
        data.put(i, twoLinearlySeparableClasses[i]);
    }

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    // The label is taken from the first coordinate of every row.
    Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    // Random search: at most 10 tries, fixed seed, satisfactory fitness 0.9.
    ParamGrid paramGrid = new ParamGrid()
        .withParameterSearchStrategy(new RandomStrategy().withMaxTries(10).withSeed(1234L).withSatisfactoryFitness(0.9))
        .addHyperParam("maxIterations", trainer::withMaxIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("locIterations", trainer::withLocIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("batchSize", trainer::withBatchSize, new Double[] { 1.0, 2.0, 4.0, 8.0, 16.0 });

    Pipeline<Integer, double[], Integer, Double> pipeline = new Pipeline<Integer, double[], Integer, Double>()
        .addVectorizer(vectorizer)
        .addTrainer(trainer);

    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator = (DebugCrossValidation<LogisticRegressionModel, Integer, double[]>) new DebugCrossValidation<LogisticRegressionModel, Integer, double[]>()
        .withUpstreamMap(data)
        .withAmountOfParts(1)
        .withPipeline(pipeline)
        .withMetric(MetricName.ACCURACY)
        .withPreprocessor(vectorizer)
        .withAmountOfFolds(4)
        .isRunningOnPipeline(true)
        .withParamGrid(paramGrid);

    CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();

    // JUnit expects (expected, actual); this order keeps failure messages correct.
    assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6);
    assertEquals(10, crossValidationRes.getScoringBoard().size());
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) HashMap(java.util.HashMap) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) Pipeline(org.apache.ignite.ml.pipeline.Pipeline) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) Test(org.junit.Test)

Example 15 with LogisticRegressionModel

use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in project ignite by apache.

the class CrossValidationTest method testRandomSearch.

/**
 * Tunes hyperparameters of a logistic regression SGD trainer (without a pipeline), using a
 * random search strategy over {@code maxIterations}, {@code locIterations} and {@code batchSize},
 * and checks the best average score and the size of the scoring board.
 */
@Test
public void testRandomSearch() {
    Map<Integer, double[]> data = new HashMap<>();

    for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
        data.put(i, twoLinearlySeparableClasses[i]);
    }

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    // The label is taken from the first coordinate of every row.
    Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    // Random search: at most 10 tries, fixed seed, satisfactory fitness 0.9.
    ParamGrid paramGrid = new ParamGrid()
        .withParameterSearchStrategy(new RandomStrategy().withMaxTries(10).withSeed(1234L).withSatisfactoryFitness(0.9))
        .addHyperParam("maxIterations", trainer::withMaxIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("locIterations", trainer::withLocIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("batchSize", trainer::withBatchSize, new Double[] { 1.0, 2.0, 4.0, 8.0, 16.0 });

    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator = (DebugCrossValidation<LogisticRegressionModel, Integer, double[]>) new DebugCrossValidation<LogisticRegressionModel, Integer, double[]>()
        .withUpstreamMap(data)
        .withAmountOfParts(1)
        .withTrainer(trainer)
        .withMetric(MetricName.ACCURACY)
        .withPreprocessor(vectorizer)
        .withAmountOfFolds(4)
        .isRunningOnPipeline(false)
        .withParamGrid(paramGrid);

    CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();

    // JUnit expects (expected, actual); this order keeps failure messages correct.
    assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6);
    assertEquals(10, crossValidationRes.getScoringBoard().size());
}
Also used : LogisticRegressionSGDTrainer(org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer) HashMap(java.util.HashMap) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) Test(org.junit.Test)

Aggregations

LogisticRegressionModel (org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel)18 LogisticRegressionSGDTrainer (org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer)12 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)11 SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator)11 Test (org.junit.Test)10 HashMap (java.util.HashMap)6 Ignite (org.apache.ignite.Ignite)6 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)5 SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache)3 TrainerTest (org.apache.ignite.ml.common.TrainerTest)3 ParamGrid (org.apache.ignite.ml.selection.paramgrid.ParamGrid)3 FileNotFoundException (java.io.FileNotFoundException)2 DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer)2 EncoderTrainer (org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer)2 NormalizationTrainer (org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer)2 RandomStrategy (org.apache.ignite.ml.selection.paramgrid.RandomStrategy)2 IOException (java.io.IOException)1 Path (java.nio.file.Path)1 Configuration (org.apache.hadoop.conf.Configuration)1 Path (org.apache.hadoop.fs.Path)1