Usage of org.apache.ignite.ml.inference.reader.FileSystemModelReader in the Apache Ignite project.
Class H2OMojoModelParserExample, method main:
/**
 * Runs the example: starts an Ignite node, builds a distributed model from the model resource
 * {@code MODEL_RES} using {@code parser}, then prints the actual label and the model prediction for
 * every non-blank line of the test data resource {@code DATA_RES}.
 *
 * <p>Each test data line is split on whitespace. Tokens of the form {@code index:value} become the
 * feature {@code "C" + (index + 1)} with the parsed double value; any other token is taken as the
 * actual label (if a line has no such token, {@code null} is printed as the label).
 *
 * @param args Command line arguments (unused).
 * @throws ExecutionException If an asynchronous prediction fails.
 * @throws InterruptedException If interrupted while waiting for a prediction.
 * @throws FileNotFoundException If the test data file cannot be opened.
 * @throws IllegalArgumentException If the model or test data resource cannot be resolved.
 */
public static void main(String... args) throws ExecutionException, InterruptedException, FileNotFoundException {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        File mdlRsrc = IgniteUtils.resolveIgnitePath(MODEL_RES);
        if (mdlRsrc == null)
            throw new IllegalArgumentException("File not found [resource_path=" + MODEL_RES + "]");

        ModelReader reader = new FileSystemModelReader(mdlRsrc.getPath());
        AsyncModelBuilder mdlBuilder = new IgniteDistributedModelBuilder(ignite, 4, 4);

        File testData = IgniteUtils.resolveIgnitePath(DATA_RES);
        if (testData == null)
            throw new IllegalArgumentException("File not found [resource_path=" + DATA_RES + "]");

        try (Model<NamedVector, Future<Double>> mdl = mdlBuilder.build(reader, parser);
             Scanner testDataScanner = new Scanner(testData)) {
            while (testDataScanner.hasNextLine()) {
                String testDataStr = testDataScanner.nextLine().trim();

                // Blank lines carry no features; predicting on them would only print noise.
                if (testDataStr.isEmpty())
                    continue;

                String actual = null;
                HashMap<String, Double> testObj = new HashMap<>();

                // Split on runs of whitespace: with split(" "), repeated separators yield empty
                // tokens, and an empty token would be mistaken for (and overwrite) the label.
                for (String keyValueString : testDataStr.split("\\s+")) {
                    String[] keyVal = keyValueString.split(":");

                    if (keyVal.length == 2)
                        // Feature indices in the data are 0-based; model column names are "C1", "C2", ...
                        testObj.put("C" + (1 + Integer.parseInt(keyVal[0])), Double.parseDouble(keyVal[1]));
                    else
                        actual = keyValueString;
                }

                double prediction = mdl.predict(VectorUtils.of(testObj)).get();

                System.out.println("Actual: " + actual + ", prediction: " + prediction);
            }
        }
    } finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.inference.reader.FileSystemModelReader in the Apache Ignite project.
Class CatboostRegressionModelParserExample, method main:
/**
 * Runs the example: starts an Ignite node, builds a distributed model from the model resource
 * {@code TEST_MODEL_RES} using {@code parser}, and for each line of the test data resource
 * {@code TEST_DATA_RES} prints the matching expected value from {@code TEST_ER_RES} next to the
 * model prediction.
 *
 * <p>Each test data line must contain at least 13 comma-separated numeric values, mapped in order to
 * features {@code f_0} ... {@code f_12}. The expected results file is read line-by-line in lock-step
 * with the test data and each of its lines must be a single double value.
 *
 * @param args Command line arguments (unused).
 * @throws ExecutionException If an asynchronous prediction fails.
 * @throws InterruptedException If interrupted while waiting for a prediction.
 * @throws FileNotFoundException If a test file cannot be opened.
 * @throws IllegalArgumentException If a resource cannot be resolved or a test data line has too few values.
 * @throws IllegalStateException If the expected results file has fewer lines than the test data.
 */
public static void main(String... args) throws ExecutionException, InterruptedException, FileNotFoundException {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        File mdlRsrc = IgniteUtils.resolveIgnitePath(TEST_MODEL_RES);
        if (mdlRsrc == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_MODEL_RES + "]");

        ModelReader reader = new FileSystemModelReader(mdlRsrc.getPath());
        AsyncModelBuilder mdlBuilder = new IgniteDistributedModelBuilder(ignite, 4, 4);

        File testData = IgniteUtils.resolveIgnitePath(TEST_DATA_RES);
        if (testData == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_DATA_RES + "]");

        File testExpRes = IgniteUtils.resolveIgnitePath(TEST_ER_RES);
        if (testExpRes == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_ER_RES + "]");

        try (Model<NamedVector, Future<Double>> mdl = mdlBuilder.build(reader, parser);
             Scanner testDataScanner = new Scanner(testData);
             Scanner testExpResultsScanner = new Scanner(testExpRes)) {
            String[] columns = new String[] {
                "f_0", "f_1", "f_2", "f_3", "f_4", "f_5", "f_6", "f_7", "f_8", "f_9", "f_10", "f_11", "f_12"
            };

            int lineNum = 0;

            while (testDataScanner.hasNextLine()) {
                String testDataStr = testDataScanner.nextLine();
                lineNum++;

                // Both files are consumed in lock-step; report a clear error instead of a bare
                // NoSuchElementException when the expected results run out first.
                if (!testExpResultsScanner.hasNextLine())
                    throw new IllegalStateException("Expected results file has fewer lines than test data " +
                        "[resource_path=" + TEST_ER_RES + ", line=" + lineNum + "]");

                String testExpResultsStr = testExpResultsScanner.nextLine();

                String[] values = testDataStr.split(",");

                // Guard against short rows, which would otherwise fail with an index-out-of-bounds error.
                if (values.length < columns.length)
                    throw new IllegalArgumentException("Not enough values in test data line " +
                        "[resource_path=" + TEST_DATA_RES + ", line=" + lineNum +
                        ", expected=" + columns.length + ", actual=" + values.length + "]");

                HashMap<String, Double> testObj = new HashMap<>();

                for (int i = 0; i < columns.length; i++)
                    testObj.put(columns[i], Double.valueOf(values[i]));

                double prediction = mdl.predict(VectorUtils.of(testObj)).get();
                double expPrediction = Double.parseDouble(testExpResultsStr);

                System.out.println("Expected: " + expPrediction + ", prediction: " + prediction);
            }
        }
    } finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.inference.reader.FileSystemModelReader in the Apache Ignite project.
Class CatboostRegressionModelParserTest, method testParseAndPredict:
/**
 * End-to-end test for {@code parse()} and {@code predict()} methods.
 *
 * <p>Loads the model from the classpath resource {@code TEST_MODEL_RESOURCE}, predicts a single
 * 13-feature input ({@code f_0} ... {@code f_12}) and checks the result against a reference value
 * with a tolerance of {@code 1e-5}.
 */
@Test
public void testParseAndPredict() {
    URL url = CatboostRegressionModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    // URL.getPath() returns the percent-encoded path (e.g. "%20" for a space), which is not a valid
    // file system path when the checkout directory contains spaces or non-ASCII characters.
    // Converting through the URI decodes it properly.
    String mdlPath;

    try {
        mdlPath = java.nio.file.Paths.get(url.toURI()).toString();
    }
    catch (java.net.URISyntaxException e) {
        throw new IllegalStateException("Invalid resource URL [url=" + url + "]", e);
    }

    ModelReader reader = new FileSystemModelReader(mdlPath);

    try (CatboostRegressionModel mdl = mdlBuilder.build(reader, parser)) {
        HashMap<String, Double> input = new HashMap<>();
        input.put("f_0", 0.02731d);
        input.put("f_1", 0.0d);
        input.put("f_2", 7.07d);
        input.put("f_3", 0d);
        input.put("f_4", 0.469d);
        input.put("f_5", 6.421d);
        input.put("f_6", 78.9d);
        input.put("f_7", 4.9671d);
        input.put("f_8", 2d);
        input.put("f_9", 242.0d);
        input.put("f_10", 17.8d);
        input.put("f_11", 396.9d);
        input.put("f_12", 9.14d);

        double prediction = mdl.predict(VectorUtils.of(input));

        assertEquals(21.164552741740483, prediction, 1e-5);
    }
}
Aggregations