Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class MLDeployingTest, method fitAndTestModel.
/**
 * Fits a logistic regression model with a default SGD trainer on the given dataset and
 * runs a single prediction as a smoke check.
 *
 * @param datasetBuilder Dataset builder passed to the trainer.
 * @param preprocessor Preprocessor passed to the trainer.
 */
private void fitAndTestModel(CacheBasedDatasetBuilder<Integer, Vector> datasetBuilder, Preprocessor<Integer, Vector> preprocessor) {
    LogisticRegressionModel mdl = new LogisticRegressionSGDTrainer().fit(datasetBuilder, preprocessor);

    // The tolerance of 1 around 0 is intentional: for this case any answer is valid.
    final Vector origin = VectorUtils.of(0., 0.);
    assertEquals(0., mdl.predict(origin), 1.);
}
Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class SparkModelParser, method loadLogRegModel.
/**
 * Load logistic regression model from a parquet file.
 *
 * Every record of every row group is read; the intercept and coefficients of the last
 * record read are the ones used. If an {@link IOException} occurs, the error is logged,
 * the stack trace is printed, and the model is still built from whatever values were
 * read so far (initially {@code null} coefficients and a zero intercept).
 *
 * @param pathToMdl Path to model.
 * @param learningEnvironment Learning environment; its logger receives the I/O error message.
 * @return Logistic regression model built from the values read.
 */
private static Model loadLogRegModel(String pathToMdl, LearningEnvironment learningEnvironment) {
    Vector weights = null;
    double intercept = 0;

    try (ParquetFileReader reader = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        final MessageType schema = reader.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO columnIO = new ColumnIOFactory().getColumnIO(schema);

        // Walk the row groups one by one until the reader reports there are none left.
        for (PageReadStore rowGroup = reader.readNextRowGroup(); rowGroup != null; rowGroup = reader.readNextRowGroup()) {
            final long rowCnt = rowGroup.getRowCount();
            final RecordReader records = columnIO.getRecordReader(rowGroup, new GroupRecordConverter(schema));

            for (int rowIdx = 0; rowIdx < rowCnt; rowIdx++) {
                final SimpleGroup grp = (SimpleGroup) records.read();

                intercept = readInterceptor(grp);
                weights = readCoefficients(grp);
            }
        }
    }
    catch (IOException ex) {
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, "Error reading parquet file: " + ex.getMessage());
        ex.printStackTrace();
    }

    return new LogisticRegressionModel(weights, intercept);
}
Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class CrossValidationTest, method testBasicFunctionality.
/**
 * Runs 4-fold cross validation of a logistic regression SGD trainer (no pipeline) over
 * {@code twoLinearlySeparableClasses} and compares the per-fold accuracy with fixed
 * reference values.
 */
@Test
public void testBasicFunctionality() {
    Map<Integer, double[]> samples = new HashMap<>();
    for (int idx = 0; idx < twoLinearlySeparableClasses.length; idx++) {
        samples.put(idx, twoLinearlySeparableClasses[idx]);
    }

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    // The label is taken from the first coordinate of each row.
    Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> crossValidation = new DebugCrossValidation<>();
    crossValidation
        .withUpstreamMap(samples)
        .withAmountOfParts(1)
        .withTrainer(trainer)
        .withMetric(MetricName.ACCURACY)
        .withPreprocessor(vectorizer)
        .withAmountOfFolds(4)
        .isRunningOnPipeline(false);

    double[] actualScores = crossValidation.scoreByFolds();

    // Reference accuracy for each of the four folds, in fold order.
    double[] expectedScores = {0.8389830508474576, 0.9402985074626866, 0.8809523809523809, 0.9921259842519685};
    for (int fold = 0; fold < expectedScores.length; fold++) {
        assertEquals(expectedScores[fold], actualScores[fold], 1e-6);
    }
}
Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class CrossValidationTest, method testRandomSearchWithPipeline.
/**
 * Tunes hyperparameters of a logistic regression SGD trainer wrapped in a pipeline using
 * a random search strategy (at most 10 tries, seed 1234, satisfactory fitness 0.9) with
 * 4-fold cross validation, then checks the best average score and the scoring board size.
 */
@Test
public void testRandomSearchWithPipeline() {
    Map<Integer, double[]> data = new HashMap<>();
    for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
        data.put(i, twoLinearlySeparableClasses[i]);
    }

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    // The label is taken from the first coordinate of each row.
    Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    ParamGrid paramGrid = new ParamGrid()
        .withParameterSearchStrategy(new RandomStrategy().withMaxTries(10).withSeed(1234L).withSatisfactoryFitness(0.9))
        .addHyperParam("maxIterations", trainer::withMaxIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("locIterations", trainer::withLocIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("batchSize", trainer::withBatchSize, new Double[] { 1.0, 2.0, 4.0, 8.0, 16.0 });

    Pipeline<Integer, double[], Integer, Double> pipeline = new Pipeline<Integer, double[], Integer, Double>()
        .addVectorizer(vectorizer)
        .addTrainer(trainer);

    // The cast is needed because the fluent setters do not return the concrete debug type.
    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator =
        (DebugCrossValidation<LogisticRegressionModel, Integer, double[]>) new DebugCrossValidation<LogisticRegressionModel, Integer, double[]>()
            .withUpstreamMap(data)
            .withAmountOfParts(1)
            .withPipeline(pipeline)
            .withMetric(MetricName.ACCURACY)
            .withPreprocessor(vectorizer)
            .withAmountOfFolds(4)
            .isRunningOnPipeline(true)
            .withParamGrid(paramGrid);

    CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();

    // Expected value goes first so that a failure message labels the values correctly.
    assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6);
    assertEquals(10, crossValidationRes.getScoringBoard().size());
}
Use of org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel in the Apache Ignite project.
Class CrossValidationTest, method testRandomSearch.
/**
 * Tunes hyperparameters of a logistic regression SGD trainer (no pipeline) using a random
 * search strategy (at most 10 tries, seed 1234, satisfactory fitness 0.9) with 4-fold
 * cross validation, then checks the best average score and the scoring board size.
 */
@Test
public void testRandomSearch() {
    Map<Integer, double[]> data = new HashMap<>();
    for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
        data.put(i, twoLinearlySeparableClasses[i]);
    }

    LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
        .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
        .withMaxIterations(100000)
        .withLocIterations(100)
        .withBatchSize(14)
        .withSeed(123L);

    // The label is taken from the first coordinate of each row.
    Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

    ParamGrid paramGrid = new ParamGrid()
        .withParameterSearchStrategy(new RandomStrategy().withMaxTries(10).withSeed(1234L).withSatisfactoryFitness(0.9))
        .addHyperParam("maxIterations", trainer::withMaxIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("locIterations", trainer::withLocIterations, new Double[] { 10.0, 100.0, 1000.0, 10000.0 })
        .addHyperParam("batchSize", trainer::withBatchSize, new Double[] { 1.0, 2.0, 4.0, 8.0, 16.0 });

    // The cast is needed because the fluent setters do not return the concrete debug type.
    DebugCrossValidation<LogisticRegressionModel, Integer, double[]> scoreCalculator =
        (DebugCrossValidation<LogisticRegressionModel, Integer, double[]>) new DebugCrossValidation<LogisticRegressionModel, Integer, double[]>()
            .withUpstreamMap(data)
            .withAmountOfParts(1)
            .withTrainer(trainer)
            .withMetric(MetricName.ACCURACY)
            .withPreprocessor(vectorizer)
            .withAmountOfFolds(4)
            .isRunningOnPipeline(false)
            .withParamGrid(paramGrid);

    CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters();

    // Expected value goes first so that a failure message labels the values correctly.
    assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6);
    assertEquals(10, crossValidationRes.getScoringBoard().size());
}
Aggregations