Usage of org.apache.ignite.ml.regressions.linear.LinearRegressionModel in the Apache Ignite project.
Class LinearRegressionSGDTrainerExample, method main:
/**
 * Runs the example.
 * <p>
 * Fills a sandbox cache with the mortality dataset and trains a linear regression model with
 * {@code LinearRegressionSGDTrainer} using RProp updates. It then prints the model and its RMSE on the same
 * cached data. The cache is destroyed afterwards even if training fails.
 *
 * @param args Command line arguments (not used).
 * @throws IOException Propagated from the example code (presumably from filling the sandbox cache — confirm).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    // The banner previously mentioned a "sparse distributed matrix API", but this example trains over a
    // cache-based dataset; keep it consistent with the completion message below.
    System.out.println(">>> Linear regression model over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);

            System.out.println(">>> Create new linear regression trainer object.");
            // NOTE(review): the numeric arguments are presumably (maxIterations, batchSize, locIterations, seed);
            // confirm against the LinearRegressionSGDTrainer constructor.
            LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(
                new UpdatesStrategy<>(
                    new RPropUpdateCalculator(),
                    RPropParameterUpdate.SUM_LOCAL,
                    RPropParameterUpdate.AVG
                ),
                100000, 10, 100, 123L
            );

            System.out.println(">>> Perform the training to get the model.");
            // The label is taken from the first coordinate of each cached vector.
            Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
                .labeled(Vectorizer.LabelCoordinate.FIRST);

            LinearRegressionModel mdl = trainer.fit(ignite, dataCache, vectorizer);
            System.out.println(">>> Linear regression model: " + mdl);

            // Evaluated on the same data the model was trained on.
            double rmse = Evaluator.evaluate(dataCache, mdl, vectorizer, MetricName.RMSE);
            System.out.println("\n>>> Rmse = " + rmse);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
        }
        finally {
            // Remove the sandbox cache so repeated runs start from a clean state.
            if (dataCache != null)
                dataCache.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.regressions.linear.LinearRegressionModel in the Apache Ignite project.
Class LinearRegressionExportImportExample, method main:
/**
 * Runs the example.
 * <p>
 * Trains a linear regression model with {@code LinearRegressionLSQRTrainer} on the mortality sandbox dataset
 * and evaluates its RMSE. It then exports the model to a temporary JSON file and imports it back. Finally it
 * evaluates the imported model on the same data, so both RMSE values can be compared to confirm the round trip.
 * The cache and the temporary file are removed afterwards.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If the temporary file cannot be created or deleted. May also be propagated from
 *     cache filling or JSON export/import (confirm).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Linear regression model over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, Vector> dataCache = null;
        Path jsonMdlPath = null;
        try {
            dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);

            // One vectorizer (label = first coordinate) shared by training and both evaluations.
            Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
                .labeled(Vectorizer.LabelCoordinate.FIRST);

            System.out.println("\n>>> Create new linear regression trainer object.");
            LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();

            System.out.println("\n>>> Perform the training to get the model.");
            LinearRegressionModel mdl = trainer.fit(ignite, dataCache, vectorizer);
            System.out.println("\n>>> Exported LinearRegression model: " + mdl);

            double rmse = Evaluator.evaluate(dataCache, mdl, vectorizer, MetricName.RMSE);
            System.out.println("\n>>> RMSE for exported LinearRegression model: " + rmse);

            jsonMdlPath = Files.createTempFile(null, null);
            mdl.toJSON(jsonMdlPath);

            LinearRegressionModel modelImportedFromJSON = LinearRegressionModel.fromJSON(jsonMdlPath);
            System.out.println("\n>>> Imported LinearRegression model: " + modelImportedFromJSON);

            // Evaluate the *imported* model (previously the original model was evaluated here by mistake,
            // so the export/import round trip was never actually checked).
            rmse = Evaluator.evaluate(dataCache, modelImportedFromJSON, vectorizer, MetricName.RMSE);
            System.out.println("\n>>> RMSE for imported LinearRegression model: " + rmse);

            System.out.println("\n>>> Linear regression model over cache based dataset usage example completed.");
        }
        finally {
            if (dataCache != null)
                dataCache.destroy();
            if (jsonMdlPath != null)
                Files.deleteIfExists(jsonMdlPath);
        }
    }
    finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.regressions.linear.LinearRegressionModel in the Apache Ignite project.
Class SparkModelParser, method loadLinRegModel:
/**
 * Loads a linear regression model from a Parquet file.
 * <p>
 * Every row of every row group is read. The intercept and coefficients of the last row read are the ones
 * used to build the model.
 * <p>
 * If an {@link IOException} occurs, it is logged through the learning environment and its stack trace is
 * printed. It is not rethrown. In that case, or when the file contains no rows, the returned model has
 * {@code null} coefficients and a zero intercept.
 *
 * @param pathToMdl Path to model.
 * @param learningEnvironment Learning environment whose logger receives read errors.
 * @return Linear regression model built from the last row read.
 */
private static Model loadLinRegModel(String pathToMdl, LearningEnvironment learningEnvironment) {
    Vector coefficients = null;
    double interceptor = 0;

    try (ParquetFileReader reader = ParquetFileReader.open(
        HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        MessageType schema = reader.getFooter().getFileMetaData().getSchema();
        MessageColumnIO columnIO = new ColumnIOFactory().getColumnIO(schema);

        // Walk the row groups until the reader reports there are none left.
        for (PageReadStore rowGrp = reader.readNextRowGroup(); rowGrp != null; rowGrp = reader.readNextRowGroup()) {
            long rowCnt = rowGrp.getRowCount();
            RecordReader grpReader = columnIO.getRecordReader(rowGrp, new GroupRecordConverter(schema));

            for (int row = 0; row < rowCnt; row++) {
                SimpleGroup grp = (SimpleGroup) grpReader.read();

                // Later rows overwrite earlier ones.
                interceptor = readLinRegInterceptor(grp);
                coefficients = readLinRegCoefficients(grp);
            }
        }
    }
    catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        e.printStackTrace();
    }

    return new LinearRegressionModel(coefficients, interceptor);
}
Usage of org.apache.ignite.ml.regressions.linear.LinearRegressionModel in the Apache Ignite project.
Class StackingTest, method testSimpleVectorStack:
/**
 * Tests simple stack training.
 * <p>
 * A single MLP submodel is trained on the XOR data. Its output is aggregated by an LSQR linear regression
 * trained on labels scaled by a constant factor. The stacked model must reproduce the scaled XOR values
 * within a tolerance of {@code 0.3}.
 */
@Test
public void testSimpleVectorStack() {
    final double labelScale = 3;

    UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> gdUpdates = new UpdatesStrategy<>(
        new SimpleGDUpdateCalculator(0.2),
        SimpleGDParameterUpdate.SUM_LOCAL,
        SimpleGDParameterUpdate.AVG
    );

    // Two inputs -> 10 ReLU units -> one sigmoid output.
    MLPArchitecture mlpArch = new MLPArchitecture(2)
        .withAddedLayer(10, true, Activators.RELU)
        .withAddedLayer(1, false, Activators.SIGMOID);

    DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer =
        new MLPTrainer<>(mlpArch, LossFunctions.MSE, gdUpdates, 3000, 10, 50, 123L)
            .withConvertedLabels(VectorUtils::num2Arr);

    StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> stackTrainer =
        new StackedVectorDatasetTrainer<>();

    StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = stackTrainer
        .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * labelScale))
        .addMatrix2MatrixTrainer(mlpTrainer)
        .withEnvironmentBuilder(TestUtils.testEnvBuilder())
        .fit(getCacheMock(xor), parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));

    // XOR truth table; each expected label is scaled the same way as the aggregator's training labels.
    double[][] inputs = {{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}};
    double[] xorOutputs = {0.0, 1.0, 1.0, 0.0};

    for (int idx = 0; idx < inputs.length; idx++)
        assertEquals(xorOutputs[idx] * labelScale, mdl.predict(VectorUtils.of(inputs[idx])), 0.3);
}
Aggregations