Search in sources :

Example 1 with InMemoryModelReader

Use of org.apache.ignite.ml.inference.reader.InMemoryModelReader in the Apache Ignite project.

The example below is the main method of the class IgniteModelDistributedInferenceExample.

/**
 * Entry point of the example. Starts an Ignite node from the example configuration, fills a cache with
 * the mortality sandbox dataset, trains a linear regression model on it (the label is taken from the
 * first vector coordinate), wraps the trained model into a distributed inference model and prints the
 * prediction next to the ground truth for every cached observation.
 * <p>
 * The cache is destroyed and standard output is flushed regardless of whether the example succeeds.
 *
 * @param args Command line arguments (not used).
 * @throws IOException Propagated if one of the calls made by the example reports an I/O failure.
 * @throws ExecutionException If waiting on a prediction future fails because the computation threw.
 * @throws InterruptedException If the thread is interrupted while waiting on a prediction future.
 */
public static void main(String... args) throws IOException, ExecutionException, InterruptedException {
    // Horizontal rule used above, inside and below the results table.
    final String tblBorder = ">>> ---------------------------------";

    System.out.println();
    System.out.println(">>> Linear regression model over cache based dataset usage example started.");

    // The Ignite node lives for the duration of this try-with-resources block.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Stays null until the cache has actually been created, so cleanup can tell the two cases apart.
        IgniteCache<Integer, Vector> mortalityCache = null;

        try {
            SandboxMLCache cacheFiller = new SandboxMLCache(ignite);
            mortalityCache = cacheFiller.fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);

            System.out.println(">>> Create new linear regression trainer object.");
            LinearRegressionLSQRTrainer lsqrTrainer = new LinearRegressionLSQRTrainer();

            System.out.println(">>> Perform the training to get the model.");
            LinearRegressionModel trainedMdl = lsqrTrainer.fit(
                ignite,
                mortalityCache,
                new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
            );
            System.out.println(">>> Linear regression model: " + trainedMdl);

            System.out.println(">>> Preparing model reader and model parser.");
            ModelReader mdlReader = new InMemoryModelReader(trainedMdl);
            ModelParser<Vector, Double, ?> mdlParser = new IgniteModelParser<>();
            // NOTE(review): the meaning of the two numeric builder arguments is not visible here - confirm.
            IgniteDistributedModelBuilder mdlBuilder = new IgniteDistributedModelBuilder(ignite, 4, 4);

            try (Model<Vector, Future<Double>> servingMdl = mdlBuilder.build(mdlReader, mdlParser)) {
                System.out.println(">>> Inference model is ready.");
                System.out.println(tblBorder);
                System.out.println(">>> | Prediction\t| Ground Truth\t|");
                System.out.println(tblBorder);

                try (QueryCursor<Cache.Entry<Integer, Vector>> cursor = mortalityCache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, Vector> entry : cursor) {
                        Vector row = entry.getValue();

                        // Coordinate 0 holds the label; the remaining coordinates are the features.
                        Vector features = row.copyOfRange(1, row.size());
                        double expected = row.get(0);

                        // Predictions are delivered as futures; block until this one is available.
                        Future<Double> pending = servingMdl.predict(features);
                        double predicted = pending.get();

                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, expected);
                    }
                }
            }

            System.out.println(tblBorder);
            System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
        }
        finally {
            if (mortalityCache != null) {
                mortalityCache.destroy();
            }
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : SandboxMLCache(org.apache.ignite.examples.ml.util.SandboxMLCache) LinearRegressionModel(org.apache.ignite.ml.regressions.linear.LinearRegressionModel) IgniteModelParser(org.apache.ignite.ml.inference.parser.IgniteModelParser) DummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer) InMemoryModelReader(org.apache.ignite.ml.inference.reader.InMemoryModelReader) LinearRegressionLSQRTrainer(org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer) InMemoryModelReader(org.apache.ignite.ml.inference.reader.InMemoryModelReader) ModelReader(org.apache.ignite.ml.inference.reader.ModelReader) Future(java.util.concurrent.Future) Ignite(org.apache.ignite.Ignite) IgniteDistributedModelBuilder(org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) IgniteCache(org.apache.ignite.IgniteCache) SandboxMLCache(org.apache.ignite.examples.ml.util.SandboxMLCache) Cache(javax.cache.Cache)

Aggregations

Future (java.util.concurrent.Future)1 Cache (javax.cache.Cache)1 Ignite (org.apache.ignite.Ignite)1 IgniteCache (org.apache.ignite.IgniteCache)1 SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache)1 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)1 IgniteDistributedModelBuilder (org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder)1 IgniteModelParser (org.apache.ignite.ml.inference.parser.IgniteModelParser)1 InMemoryModelReader (org.apache.ignite.ml.inference.reader.InMemoryModelReader)1 ModelReader (org.apache.ignite.ml.inference.reader.ModelReader)1 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)1 LinearRegressionLSQRTrainer (org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer)1 LinearRegressionModel (org.apache.ignite.ml.regressions.linear.LinearRegressionModel)1