Usage of `org.apache.ignite.ml.svm.SVMLinearClassificationModel` in the Apache Ignite project.
Class `SVMFromSparkExample`, method `main`.
/**
 * Run example.
 * <p>
 * Starts an Ignite node from {@code examples/config/example-ignite.xml}, loads the Titanic passenger data
 * into a cache via {@code TitanicUtils.readPassengers}, parses a linear SVM model exported by Spark from
 * {@code SPARK_MDL_PATH}, and prints the model together with its accuracy and test error on the cached data.
 * The data cache is destroyed before the node is stopped.
 *
 * @param args Command line arguments (not used).
 * @throws FileNotFoundException Propagated from data loading (presumably when the passenger data file is
 *     missing — confirm against {@code TitanicUtils.readPassengers}).
 */
public static void main(String[] args) throws FileNotFoundException {
    System.out.println();
    System.out.println(">>> SVM model loaded from Spark through serialization over partitioned dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = TitanicUtils.readPassengers(ignite);
            // Features are columns 0, 5 and 6; the label is column 1.
            final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 5, 6).labeled(1);
            SVMLinearClassificationModel mdl = (SVMLinearClassificationModel) SparkModelParser.parse(SPARK_MDL_PATH, SupportedSparkModels.LINEAR_SVM, env);
            System.out.println(">>> SVM: " + mdl);
            double accuracy = Evaluator.evaluate(dataCache, mdl, vectorizer, new Accuracy<>());
            System.out.println("\n>>> Accuracy " + accuracy);
            System.out.println("\n>>> Test Error " + (1 - accuracy));
        } finally {
            // readPassengers may throw before dataCache is assigned; without this guard the
            // finally block would throw a NullPointerException and hide the original failure.
            if (dataCache != null) {
                dataCache.destroy();
            }
        }
    }
}
Usage of `org.apache.ignite.ml.svm.SVMLinearClassificationModel` in the Apache Ignite project.
Class `SparkModelParser`, method `loadLinearSVMModel`.
/**
 * Load SVM model.
 * <p>
 * Opens the Parquet file at {@code pathToMdl} and reads it one row group at a time. The intercept and
 * coefficient vector are extracted from every record, so if the file holds several records the values of
 * the last record read are used.
 * <p>
 * An {@link IOException} is not rethrown. Its message is logged through the learning environment's logger
 * and its stack trace is printed. The model is then still built from whatever was read so far, which may
 * mean {@code null} coefficients and a zero intercept.
 *
 * @param pathToMdl Path to model.
 * @param learningEnvironment Learning environment.
 * @return Linear SVM classification model built from the last record read.
 */
private static Model loadLinearSVMModel(String pathToMdl, LearningEnvironment learningEnvironment) {
    Vector coefficients = null;
    double interceptor = 0;
    try (ParquetFileReader r = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(pathToMdl), new Configuration()))) {
        PageReadStore pages;
        final MessageType schema = r.getFooter().getFileMetaData().getSchema();
        final MessageColumnIO colIO = new ColumnIOFactory().getColumnIO(schema);
        while (null != (pages = r.readNextRowGroup())) {
            final long rows = pages.getRowCount();
            final RecordReader recordReader = colIO.getRecordReader(pages, new GroupRecordConverter(schema));
            // The row count is a long. An int counter would overflow and loop forever
            // on a row group with more than Integer.MAX_VALUE rows.
            for (long i = 0; i < rows; i++) {
                final SimpleGroup g = (SimpleGroup) recordReader.read();
                interceptor = readSVMInterceptor(g);
                coefficients = readSVMCoefficients(g);
            }
        }
    } catch (IOException e) {
        String msg = "Error reading parquet file: " + e.getMessage();
        learningEnvironment.logger().log(MLLogger.VerboseLevel.HIGH, msg);
        // The logger call above only accepts a message, so the stack trace is surfaced here.
        e.printStackTrace();
    }
    return new SVMLinearClassificationModel(coefficients, interceptor);
}
Usage of `org.apache.ignite.ml.svm.SVMLinearClassificationModel` in the Apache Ignite project.
Class `LocalModelsTest`, method `importExportSVMBinaryClassificationModelTest`.
/**
 * Round-trips a linear SVM model through a {@link FileExporter}.
 * <p>
 * Builds a model with weights {@code [1, 2]} and intercept {@code 3}, saves it to the file path supplied
 * by {@code executeModelTest}, and loads it back. The loaded model must be non-null and equal to the
 * original.
 */
@Test
public void importExportSVMBinaryClassificationModelTest() throws IOException {
    executeModelTest(path -> {
        DenseVector weights = new DenseVector(new double[] {1, 2});
        SVMLinearClassificationModel original = new SVMLinearClassificationModel(weights, 3);

        // Persist the model and read it back through the same exporter instance.
        Exporter<SVMLinearClassificationModel, String> fileExporter = new FileExporter<>();
        original.saveModel(fileExporter, path);
        SVMLinearClassificationModel restored = fileExporter.load(path);

        Assert.assertNotNull(restored);
        Assert.assertEquals("", original, restored);

        return null;
    });
}
Aggregations