Search in sources:

Example 1 with ModelReader

Use of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.

Class IgniteModelDistributedInferenceExample, method main.

/**
 * Run example.
 * <p>
 * Starts an Ignite node and fills a sandbox cache with the {@code MORTALITY_DATA} dataset. It then trains a linear
 * regression model with {@link LinearRegressionLSQRTrainer} and wraps the trained model in a distributed inference
 * model. Finally, it prints the prediction for every cached row next to the row's ground truth, which is the first
 * vector coordinate. The sandbox cache is destroyed when the example finishes, even if it fails.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If an I/O error occurs (presumably while loading the sandbox dataset).
 * @throws ExecutionException If a prediction future completes exceptionally.
 * @throws InterruptedException If interrupted while waiting for a prediction.
 */
public static void main(String... args) throws IOException, ExecutionException, InterruptedException {
    System.out.println();
    System.out.println(">>> Linear regression model over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
            System.out.println(">>> Create new linear regression trainer object.");
            LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
            System.out.println(">>> Perform the training to get the model.");
            // The label is the FIRST coordinate of each cached vector; the remaining coordinates are features.
            LinearRegressionModel mdl = trainer.fit(ignite, dataCache, new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));
            System.out.println(">>> Linear regression model: " + mdl);
            System.out.println(">>> Preparing model reader and model parser.");
            // Reads the already trained in-memory model, so the inference model needs no file.
            ModelReader reader = new InMemoryModelReader(mdl);
            ModelParser<Vector, Double, ?> parser = new IgniteModelParser<>();
            try (Model<Vector, Future<Double>> infMdl = new IgniteDistributedModelBuilder(ignite, 4, 4).build(reader, parser)) {
                System.out.println(">>> Inference model is ready.");
                System.out.println(">>> ---------------------------------");
                System.out.println(">>> | Prediction\t| Ground Truth\t|");
                System.out.println(">>> ---------------------------------");
                try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, Vector> observation : observations) {
                        Vector val = observation.getValue();
                        // Strip the label (coordinate 0) to get the feature vector used for prediction.
                        Vector inputs = val.copyOfRange(1, val.size());
                        double groundTruth = val.get(0);
                        // Prediction is asynchronous; block until this row's result is available.
                        double prediction = infMdl.predict(inputs).get();
                        // %n emits the platform line separator, unlike a hard-coded \n.
                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|%n", prediction, groundTruth);
                    }
                }
            }
            System.out.println(">>> ---------------------------------");
            System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
        } finally {
            // Drop the sandbox cache even if training or inference failed.
            if (dataCache != null)
                dataCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Also used : SandboxMLCache(org.apache.ignite.examples.ml.util.SandboxMLCache) LinearRegressionModel(org.apache.ignite.ml.regressions.linear.LinearRegressionModel) IgniteModelParser(org.apache.ignite.ml.inference.parser.IgniteModelParser) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) InMemoryModelReader(org.apache.ignite.ml.inference.reader.InMemoryModelReader) LinearRegressionLSQRTrainer(org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer) ModelReader(org.apache.ignite.ml.inference.reader.ModelReader) Future(java.util.concurrent.Future) Ignite(org.apache.ignite.Ignite) IgniteDistributedModelBuilder(org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) IgniteCache(org.apache.ignite.IgniteCache) Cache(javax.cache.Cache)

Example 2 with ModelReader

Use of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.

Class CatboostClassificationModelParserExample, method main.

/**
 * Run example.
 * <p>
 * Loads the CatBoost model from {@code TEST_MODEL_RES} and builds a distributed inference model on the started
 * Ignite node. For every data row of the CSV file {@code TEST_DATA_RES}, it prints the prediction next to the value
 * on the matching line of {@code TEST_ER_RES}.
 *
 * @param args Command line arguments (not used).
 * @throws ExecutionException If a prediction future completes exceptionally.
 * @throws InterruptedException If interrupted while waiting for a prediction.
 * @throws FileNotFoundException If a resolved test data file cannot be opened.
 * @throws IllegalArgumentException If one of the resources cannot be resolved.
 * @throws IllegalStateException If the test data file has no header line, a data row has fewer values than the
 *     header has columns, or the expected results file has fewer lines than there are data rows.
 */
public static void main(String... args) throws ExecutionException, InterruptedException, FileNotFoundException {
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        File mdlRsrc = IgniteUtils.resolveIgnitePath(TEST_MODEL_RES);
        if (mdlRsrc == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_MODEL_RES + "]");
        ModelReader reader = new FileSystemModelReader(mdlRsrc.getPath());
        AsyncModelBuilder mdlBuilder = new IgniteDistributedModelBuilder(ignite, 4, 4);
        File testData = IgniteUtils.resolveIgnitePath(TEST_DATA_RES);
        if (testData == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_DATA_RES + "]");
        File testExpRes = IgniteUtils.resolveIgnitePath(TEST_ER_RES);
        if (testExpRes == null)
            throw new IllegalArgumentException("File not found [resource_path=" + TEST_ER_RES + "]");
        try (Model<NamedVector, Future<Double>> mdl = mdlBuilder.build(reader, parser);
            Scanner testDataScanner = new Scanner(testData);
            Scanner testExpResultsScanner = new Scanner(testExpRes)) {
            // The first line of the test data is a CSV header holding the feature names.
            if (!testDataScanner.hasNextLine())
                throw new IllegalStateException("Test data has no header line [resource_path=" + TEST_DATA_RES + "]");
            String header = testDataScanner.nextLine();
            String[] columns = header.split(",");
            int row = 0;
            while (testDataScanner.hasNextLine()) {
                String testDataStr = testDataScanner.nextLine();
                row++;
                // Each data row is paired with the line at the same position in the expected results file.
                if (!testExpResultsScanner.hasNextLine())
                    throw new IllegalStateException("Expected results file has fewer lines than test data rows " +
                        "[resource_path=" + TEST_ER_RES + ", row=" + row + "]");
                String testExpResultsStr = testExpResultsScanner.nextLine();
                String[] values = testDataStr.split(",");
                // Extra trailing values are ignored (as before); missing ones cannot be mapped to a column.
                if (values.length < columns.length)
                    throw new IllegalStateException("Test data row has fewer values than header columns " +
                        "[resource_path=" + TEST_DATA_RES + ", row=" + row + ", expected=" + columns.length +
                        ", actual=" + values.length + "]");
                HashMap<String, Double> testObj = new HashMap<>();
                for (int i = 0; i < columns.length; i++)
                    testObj.put(columns[i], Double.valueOf(values[i]));
                double prediction = mdl.predict(VectorUtils.of(testObj)).get();
                double expPrediction = Double.parseDouble(testExpResultsStr);
                System.out.println("Expected: " + expPrediction + ", prediction: " + prediction);
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used : Scanner(java.util.Scanner) HashMap(java.util.HashMap) FileSystemModelReader(org.apache.ignite.ml.inference.reader.FileSystemModelReader) ModelReader(org.apache.ignite.ml.inference.reader.ModelReader) NamedVector(org.apache.ignite.ml.math.primitives.vector.NamedVector) AsyncModelBuilder(org.apache.ignite.ml.inference.builder.AsyncModelBuilder) Future(java.util.concurrent.Future) Ignite(org.apache.ignite.Ignite) IgniteDistributedModelBuilder(org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder) File(java.io.File)

Example 3 with ModelReader

Use of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.

Class H2OMojoParserTest, method testParseAndPredict.

/**
 * End-to-end test for {@code parse()} and {@code predict()}: loads the H2O MOJO model resource and checks the
 * prediction for a single known input row.
 */
@Test
public void testParseAndPredict() {
    ClassLoader clsLdr = H2OMojoParserTest.class.getClassLoader();
    URL mdlUrl = clsLdr.getResource(TEST_MODEL_RESOURCE);

    if (mdlUrl == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    ModelReader mdlReader = new FileSystemModelReader(mdlUrl.getPath());

    try (H2OMojoModel mdl = mdlBuilder.build(mdlReader, parser)) {
        // Probe row: feature names paired index-by-index with their values, inserted in this order.
        String[] names = {"RACE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON"};
        double[] vals = {1.0, 2.0, 1.0, 1.4, 0.0, 6.0};

        HashMap<String, Double> features = new HashMap<>();

        for (int i = 0; i < names.length; i++)
            features.put(names[i], vals[i]);

        double actual = mdl.predict(VectorUtils.of(features));

        assertEquals(64.50328, actual, 1e-5);
    }
}
Also used : FileSystemModelReader(org.apache.ignite.ml.inference.reader.FileSystemModelReader) ModelReader(org.apache.ignite.ml.inference.reader.ModelReader) HashMap(java.util.HashMap) URL(java.net.URL) Test(org.junit.Test)

Example 4 with ModelReader

Use of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.

Class CatboostClassificationModelParserTest, method testParseAndPredict.

/**
 * End-to-end test for {@code parse()} and {@code predict()}: loads the CatBoost classification model resource and
 * checks the predicted value for a single known input row.
 */
@Test
public void testParseAndPredict() {
    ClassLoader clsLdr = CatboostClassificationModelParserTest.class.getClassLoader();
    URL mdlUrl = clsLdr.getResource(TEST_MODEL_RESOURCE);

    if (mdlUrl == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    ModelReader mdlReader = new FileSystemModelReader(mdlUrl.getPath());

    try (CatboostClassificationModel mdl = mdlBuilder.build(mdlReader, parser)) {
        // Probe row: feature names paired index-by-index with their values, inserted in this order.
        String[] names = {
            "ACTION", "RESOURCE", "MGR_ID", "ROLE_ROLLUP_1", "ROLE_ROLLUP_2",
            "ROLE_DEPTNAME", "ROLE_TITLE", "ROLE_FAMILY_DESC", "ROLE_FAMILY", "ROLE_CODE"
        };
        double[] vals = {
            1.0, 39353.0, 85475.0, 117961.0, 118300.0,
            123472.0, 117905.0, 117906.0, 290919.0, 117908.0
        };

        HashMap<String, Double> features = new HashMap<>();

        for (int i = 0; i < names.length; i++)
            features.put(names[i], vals[i]);

        double actual = mdl.predict(VectorUtils.of(features));

        assertEquals(0.9928904609329371, actual, 1e-5);
    }
}
Also used : FileSystemModelReader(org.apache.ignite.ml.inference.reader.FileSystemModelReader) ModelReader(org.apache.ignite.ml.inference.reader.ModelReader) HashMap(java.util.HashMap) CatboostClassificationModel(org.apache.ignite.ml.catboost.CatboostClassificationModel) URL(java.net.URL) Test(org.junit.Test)

Example 5 with ModelReader

Use of org.apache.ignite.ml.inference.reader.ModelReader in the Apache Ignite project.

Class XGBoostModelParserTest, method testParseAndPredict.

/**
 * End-to-end test for {@code parse()} and {@code predict()} methods.
 * <p>
 * Each line of {@code datasets/agaricus-test-data.txt} is split on spaces. Every {@code index:value} token becomes
 * feature {@code "f" + index}; other tokens (presumably the leading label of a LIBSVM-style row) are skipped. The
 * prediction for a line is compared with the line at the same position in
 * {@code datasets/agaricus-test-expected-results.txt}. Both files must have the same number of lines, and the test
 * data must not be empty.
 */
@Test
public void testParseAndPredict() {
    URL url = XGBoostModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");
    ModelReader reader = new FileSystemModelReader(url.getPath());
    try (XGModelComposition mdl = mdlBuilder.build(reader, parser);
        Scanner testDataScanner = openResourceScanner("datasets/agaricus-test-data.txt");
        Scanner testExpResultsScanner = openResourceScanner("datasets/agaricus-test-expected-results.txt")) {
        int rows = 0;
        while (testDataScanner.hasNextLine()) {
            assertTrue("Expected results have fewer lines than test data", testExpResultsScanner.hasNextLine());
            String testDataStr = testDataScanner.nextLine();
            String testExpResultsStr = testExpResultsScanner.nextLine();
            HashMap<String, Double> testObj = new HashMap<>();
            for (String keyValueString : testDataStr.split(" ")) {
                String[] keyVal = keyValueString.split(":");
                if (keyVal.length == 2)
                    testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
            }
            double prediction = mdl.predict(VectorUtils.of(testObj));
            double expPrediction = Double.parseDouble(testExpResultsStr);
            assertEquals(expPrediction, prediction, 1e-6);
            rows++;
        }
        // Without these checks an empty data file or surplus expected results would pass silently.
        assertTrue("Test data is empty", rows > 0);
        assertTrue("Expected results have more lines than test data", !testExpResultsScanner.hasNextLine());
    }
}

/**
 * Opens a {@link Scanner} over a classpath resource.
 *
 * @param rsrcName Classpath resource name.
 * @return Scanner reading the resource contents.
 * @throws IllegalStateException If the resource cannot be found.
 */
private static Scanner openResourceScanner(String rsrcName) {
    java.io.InputStream in = XGBoostModelParserTest.class.getClassLoader().getResourceAsStream(rsrcName);
    // new Scanner(null) would fail with a message-less NullPointerException.
    if (in == null)
        throw new IllegalStateException("File not found [resource_name=" + rsrcName + "]");
    return new Scanner(in);
}
Also used : Scanner(java.util.Scanner) XGModelComposition(org.apache.ignite.ml.xgboost.XGModelComposition) HashMap(java.util.HashMap) FileSystemModelReader(org.apache.ignite.ml.inference.reader.FileSystemModelReader) URL(java.net.URL) ModelReader(org.apache.ignite.ml.inference.reader.ModelReader) Test(org.junit.Test)

Aggregations

ModelReader (org.apache.ignite.ml.inference.reader.ModelReader)9 HashMap (java.util.HashMap)8 FileSystemModelReader (org.apache.ignite.ml.inference.reader.FileSystemModelReader)8 Scanner (java.util.Scanner)5 Future (java.util.concurrent.Future)5 Ignite (org.apache.ignite.Ignite)5 IgniteDistributedModelBuilder (org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder)5 File (java.io.File)4 URL (java.net.URL)4 AsyncModelBuilder (org.apache.ignite.ml.inference.builder.AsyncModelBuilder)4 NamedVector (org.apache.ignite.ml.math.primitives.vector.NamedVector)4 Test (org.junit.Test)4 Cache (javax.cache.Cache)1 IgniteCache (org.apache.ignite.IgniteCache)1 SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache)1 CatboostClassificationModel (org.apache.ignite.ml.catboost.CatboostClassificationModel)1 CatboostRegressionModel (org.apache.ignite.ml.catboost.CatboostRegressionModel)1 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)1 IgniteModelParser (org.apache.ignite.ml.inference.parser.IgniteModelParser)1 InMemoryModelReader (org.apache.ignite.ml.inference.reader.InMemoryModelReader)1