Usage of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.
Class IgniteModelDistributedInferenceExample, method main:
/**
 * Entry point of the example: trains a linear regression model on the mortality sandbox data set, wraps the
 * trained model into an Ignite distributed inference model and prints the prediction next to the ground truth
 * for every cached row.
 *
 * @param args Command line arguments (ignored).
 * @throws IOException Propagated from the example setup (presumably from loading the sandbox data set).
 * @throws ExecutionException If an asynchronous prediction completes exceptionally.
 * @throws InterruptedException If interrupted while waiting for a prediction.
 */
public static void main(String... args) throws IOException, ExecutionException, InterruptedException {
    System.out.println();
    System.out.println(">>> Linear regression model over cache based dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Remains null if the sandbox cache could not be filled, so the cleanup below skips it.
        IgniteCache<Integer, Vector> cache = null;

        try {
            cache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);

            System.out.println(">>> Create new linear regression trainer object.");
            LinearRegressionLSQRTrainer lsqrTrainer = new LinearRegressionLSQRTrainer();

            System.out.println(">>> Perform the training to get the model.");
            // The first coordinate of each cached vector is the label, the remaining ones are features.
            LinearRegressionModel linRegMdl = lsqrTrainer.fit(
                ignite,
                cache,
                new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
            );
            System.out.println(">>> Linear regression model: " + linRegMdl);

            System.out.println(">>> Preparing model reader and model parser.");
            ModelReader mdlReader = new InMemoryModelReader(linRegMdl);
            ModelParser<Vector, Double, ?> mdlParser = new IgniteModelParser<>();

            String separator = ">>> ---------------------------------";

            try (Model<Vector, Future<Double>> distributedMdl =
                     new IgniteDistributedModelBuilder(ignite, 4, 4).build(mdlReader, mdlParser)) {
                System.out.println(">>> Inference model is ready.");
                System.out.println(separator);
                System.out.println(">>> | Prediction\t| Ground Truth\t|");
                System.out.println(separator);

                try (QueryCursor<Cache.Entry<Integer, Vector>> cursor = cache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, Vector> entry : cursor) {
                        Vector row = entry.getValue();

                        // Split the row the same way the vectorizer did: label first, features after it.
                        Vector features = row.copyOfRange(1, row.size());
                        double label = row.get(0);

                        // Unboxed on purpose: the future's result is printed as a primitive double.
                        double prediction = distributedMdl.predict(features).get();

                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, label);
                    }
                }
            }

            System.out.println(separator);
            System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
        } finally {
            if (cache != null) {
                cache.destroy();
            }
        }
    } finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.
Class CatboostClassificationModelParserExample, method main:
/**
 * Runs the example: loads a CatBoost classification model from the file system, builds an Ignite distributed
 * inference model from it and prints the model prediction for every row of the test data set next to the
 * expected prediction read from the expected-results file.
 *
 * @param args Command line arguments (not used).
 * @throws ExecutionException If an asynchronous prediction completes exceptionally.
 * @throws InterruptedException If interrupted while waiting for a prediction.
 * @throws FileNotFoundException If a resolved test file cannot be opened.
 * @throws IllegalArgumentException If the model or a test resource cannot be resolved.
 * @throws IllegalStateException If the test data file has no header line, a data row has fewer values than the
 *      header has columns, or the expected-results file has fewer lines than the test data has rows.
 */
public static void main(String... args) throws ExecutionException, InterruptedException, FileNotFoundException {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        File mdlRsrc = IgniteUtils.resolveIgnitePath(TEST_MODEL_RES);
        if (mdlRsrc == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_MODEL_RES + "]");

        ModelReader reader = new FileSystemModelReader(mdlRsrc.getPath());
        AsyncModelBuilder mdlBuilder = new IgniteDistributedModelBuilder(ignite, 4, 4);

        File testData = IgniteUtils.resolveIgnitePath(TEST_DATA_RES);
        if (testData == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_DATA_RES + "]");

        File testExpRes = IgniteUtils.resolveIgnitePath(TEST_ER_RES);
        if (testExpRes == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_ER_RES + "]");

        try (Model<NamedVector, Future<Double>> mdl = mdlBuilder.build(reader, parser);
             Scanner testDataScanner = new Scanner(testData);
             Scanner testExpResultsScanner = new Scanner(testExpRes)) {
            // The first line of the test data file is a CSV header holding the feature names.
            if (!testDataScanner.hasNextLine())
                throw new IllegalStateException("Test data file is empty [path=" + testData.getPath() + "]");

            String[] columns = testDataScanner.nextLine().split(",");

            // 1-based number of the current data row (the header line is not counted).
            int rowNum = 0;

            while (testDataScanner.hasNextLine()) {
                String testDataStr = testDataScanner.nextLine();
                rowNum++;

                // Data rows and expected results are paired line by line; fail clearly instead of with a bare
                // NoSuchElementException when the expected-results file runs out first.
                if (!testExpResultsScanner.hasNextLine()) {
                    throw new IllegalStateException("Expected results file has fewer lines than test data " +
                        "[path=" + testExpRes.getPath() + ", row=" + rowNum + "]");
                }

                String testExpResultsStr = testExpResultsScanner.nextLine();

                String[] values = testDataStr.split(",");

                // Guard against short rows, which would otherwise fail with an ArrayIndexOutOfBoundsException.
                if (values.length < columns.length) {
                    throw new IllegalStateException("Test data row has fewer values than header columns " +
                        "[path=" + testData.getPath() + ", row=" + rowNum + ", expected=" + columns.length +
                        ", actual=" + values.length + "]");
                }

                HashMap<String, Double> testObj = new HashMap<>();
                for (int i = 0; i < columns.length; i++)
                    testObj.put(columns[i], Double.parseDouble(values[i]));

                double prediction = mdl.predict(VectorUtils.of(testObj)).get();
                double expPrediction = Double.parseDouble(testExpResultsStr);

                System.out.println("Expected: " + expPrediction + ", prediction: " + prediction);
            }
        }
    } finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.
Class H2OMojoParserTest, method testParseAndPredict:
/**
 * End-to-end test for {@code parse()} and {@code predict()} methods: builds an H2O MOJO model from the
 * {@code TEST_MODEL_RESOURCE} class path resource and checks its prediction for one hand-written input.
 */
@Test
public void testParseAndPredict() {
    URL url = H2OMojoParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    // URL#getPath() keeps percent-escapes (e.g. "%20" for a space), so the raw value does not name the real
    // file when the checkout path contains such characters; decode it through URI -> Path instead.
    String mdlPath;
    try {
        mdlPath = java.nio.file.Paths.get(url.toURI()).toString();
    } catch (java.net.URISyntaxException e) {
        throw new IllegalStateException("Malformed resource URL [url=" + url + "]", e);
    }

    ModelReader reader = new FileSystemModelReader(mdlPath);

    try (H2OMojoModel mdl = mdlBuilder.build(reader, parser)) {
        HashMap<String, Double> input = new HashMap<>();
        input.put("RACE", 1.0);
        input.put("DPROS", 2.0);
        input.put("DCAPS", 1.0);
        input.put("PSA", 1.4);
        input.put("VOL", 0.0);
        input.put("GLEASON", 6.0);

        double prediction = mdl.predict(VectorUtils.of(input));

        assertEquals(64.50328, prediction, 1e-5);
    }
}
Usage of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.
Class CatboostClassificationModelParserTest, method testParseAndPredict:
/**
 * End-to-end test for {@code parse()} and {@code predict()} methods: builds a CatBoost classification model from
 * the {@code TEST_MODEL_RESOURCE} class path resource and checks its prediction for one hand-written input.
 */
@Test
public void testParseAndPredict() {
    URL url = CatboostClassificationModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    // URL#getPath() keeps percent-escapes (e.g. "%20" for a space), so the raw value does not name the real
    // file when the checkout path contains such characters; decode it through URI -> Path instead.
    String mdlPath;
    try {
        mdlPath = java.nio.file.Paths.get(url.toURI()).toString();
    } catch (java.net.URISyntaxException e) {
        throw new IllegalStateException("Malformed resource URL [url=" + url + "]", e);
    }

    ModelReader reader = new FileSystemModelReader(mdlPath);

    try (CatboostClassificationModel mdl = mdlBuilder.build(reader, parser)) {
        HashMap<String, Double> input = new HashMap<>();
        input.put("ACTION", 1.0);
        input.put("RESOURCE", 39353.0);
        input.put("MGR_ID", 85475.0);
        input.put("ROLE_ROLLUP_1", 117961.0);
        input.put("ROLE_ROLLUP_2", 118300.0);
        input.put("ROLE_DEPTNAME", 123472.0);
        input.put("ROLE_TITLE", 117905.0);
        input.put("ROLE_FAMILY_DESC", 117906.0);
        input.put("ROLE_FAMILY", 290919.0);
        input.put("ROLE_CODE", 117908.0);

        double prediction = mdl.predict(VectorUtils.of(input));

        assertEquals(0.9928904609329371, prediction, 1e-5);
    }
}
Usage of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.
Class XGBoostModelParserTest, method testParseAndPredict:
/**
 * End-to-end test for {@code parse()} and {@code predict()} methods: builds an XGBoost model composition from the
 * {@code TEST_MODEL_RESOURCE} class path resource and compares its predictions on the agaricus test data set with
 * the expected results, line by line.
 */
@Test
public void testParseAndPredict() {
    URL url = XGBoostModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    // URL#getPath() keeps percent-escapes (e.g. "%20" for a space), so the raw value does not name the real
    // file when the checkout path contains such characters; decode it through URI -> Path instead.
    String mdlPath;
    try {
        mdlPath = java.nio.file.Paths.get(url.toURI()).toString();
    } catch (java.net.URISyntaxException e) {
        throw new IllegalStateException("Malformed resource URL [url=" + url + "]", e);
    }

    ModelReader reader = new FileSystemModelReader(mdlPath);

    try (XGModelComposition mdl = mdlBuilder.build(reader, parser);
         Scanner testDataScanner = new Scanner(XGBoostModelParserTest.class.getClassLoader().getResourceAsStream("datasets/agaricus-test-data.txt"));
         Scanner testExpResultsScanner = new Scanner(XGBoostModelParserTest.class.getClassLoader().getResourceAsStream("datasets/agaricus-test-expected-results.txt"))) {
        while (testDataScanner.hasNextLine()) {
            // Every data line must have a matching expected result.
            assertTrue(testExpResultsScanner.hasNextLine());

            String testDataStr = testDataScanner.nextLine();
            String testExpResultsStr = testExpResultsScanner.nextLine();

            // A data line is a space-separated list of "index:value" tokens; tokens that do not split into
            // exactly two parts on ':' are skipped. Features are keyed as "f" + index.
            HashMap<String, Double> testObj = new HashMap<>();
            for (String keyValueString : testDataStr.split(" ")) {
                String[] keyVal = keyValueString.split(":");
                if (keyVal.length == 2)
                    testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
            }

            double prediction = mdl.predict(VectorUtils.of(testObj));
            double expPrediction = Double.parseDouble(testExpResultsStr);

            assertEquals(expPrediction, prediction, 1e-6);
        }

        // Any non-blank expected result left over was never compared with a prediction, which would mean the
        // two resource files are out of sync. Trailing blank lines are tolerated.
        while (testExpResultsScanner.hasNextLine())
            assertTrue(testExpResultsScanner.nextLine().trim().isEmpty());
    }
}
Aggregations