Search in sources:

Example 1 with DecisionTreeModel

use of org.apache.ignite.ml.tree.DecisionTreeModel in project ignite by apache.

the class DecisionTreeRegressionExportImportExample method main.

/**
 * Runs the example: trains a decision tree regression model on a freshly created cache, exports it to a
 * temporary JSON file, imports it back and prints a prediction table.
 *
 * @param args Command line arguments, none required.
 * @throws IOException If the temporary model file cannot be created, written, read or deleted.
 */
public static void main(String... args) throws IOException {
    final String tableRule = ">>> ---------------------------------";

    System.out.println(">>> Decision tree regression trainer example started.");

    // The node is closed automatically when the try-with-resources block exits.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println("\n>>> Ignite grid started.");

        // Configuration of the cache that holds the training data.
        CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg = new CacheConfiguration<>();
        cacheCfg.setName("TRAINING_SET");
        cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        // Both resources start as null so the cleanup below only touches what was actually created.
        IgniteCache<Integer, LabeledVector<Double>> trainingData = null;
        Path exportedJson = null;

        try {
            trainingData = ignite.createCache(cacheCfg);
            generatePoints(trainingData);

            // Train the model.
            DecisionTreeRegressionTrainer regressionTrainer = new DecisionTreeRegressionTrainer(10, 0);
            DecisionTreeModel trainedMdl = regressionTrainer.fit(ignite, trainingData, new LabeledDummyVectorizer<>());
            System.out.println("\n>>> Exported Decision tree regression model: " + trainedMdl);

            // Round-trip the model through a temporary JSON file.
            exportedJson = Files.createTempFile(null, null);
            trainedMdl.toJSON(exportedJson);
            DecisionTreeModel restoredMdl = DecisionTreeModel.fromJSON(exportedJson);
            System.out.println("\n>>> Imported Decision tree regression model: " + restoredMdl);

            System.out.println(tableRule);
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(tableRule);

            // NOTE(review): the table is produced by the originally trained model, not the imported one —
            // confirm this is intended for an export/import example.
            int x = 0;
            while (x < 10) {
                double predicted = trainedMdl.predict(VectorUtils.of(x));
                System.out.format(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
                x++;
            }

            System.out.println(tableRule);
            System.out.println("\n>>> Decision tree regression trainer example completed.");
        } finally {
            if (trainingData != null) {
                trainingData.destroy();
            }
            if (exportedJson != null) {
                Files.deleteIfExists(exportedJson);
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used : Path(java.nio.file.Path) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DecisionTreeRegressionTrainer(org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 2 with DecisionTreeModel

use of org.apache.ignite.ml.tree.DecisionTreeModel in project ignite by apache.

the class DecisionTreeClassificationExportImportExample method main.

/**
 * Runs the example: trains a decision tree classifier on 1000 generated points, exports it to a temporary
 * JSON file, imports it back and prints the accuracy reported for both the exported and the imported model.
 *
 * @param args Command line arguments, none required.
 * @throws IOException If the temporary model file cannot be created, written, read or deleted.
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision tree classification trainer example started.");

    // The node is closed automatically when the try-with-resources block exits.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println("\n>>> Ignite grid started.");

        // Configuration of the cache that holds the training data.
        CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg = new CacheConfiguration<>();
        cacheCfg.setName("TRAINING_SET");
        cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        // Both resources start as null so the cleanup below only touches what was actually created.
        IgniteCache<Integer, LabeledVector<Double>> trainingData = null;
        Path exportedJson = null;

        try {
            trainingData = ignite.createCache(cacheCfg);

            // Fixed seed keeps the generated data reproducible between runs.
            Random rnd = new Random(0);
            int key = 0;
            while (key < 1000) {
                trainingData.put(key, generatePoint(rnd));
                key++;
            }

            // Train the model.
            DecisionTreeClassificationTrainer classificationTrainer = new DecisionTreeClassificationTrainer(4, 0);
            LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
            DecisionTreeModel trainedMdl = classificationTrainer.fit(ignite, trainingData, vectorizer);
            System.out.println("\n>>> Exported Decision tree classification model: " + trainedMdl);

            // NOTE(review): both evaluations below draw from the same Random instance, so the two models are
            // presumably scored on different samples — confirm against evaluateModel.
            int exportedCorrect = evaluateModel(rnd, trainedMdl);
            double exportedAccuracyPct = exportedCorrect / 10.0;
            System.out.println("\n>>> Accuracy for exported Decision tree classification model: " + exportedAccuracyPct + "%");

            // Round-trip the model through a temporary JSON file.
            exportedJson = Files.createTempFile(null, null);
            trainedMdl.toJSON(exportedJson);
            DecisionTreeModel restoredMdl = DecisionTreeModel.fromJSON(exportedJson);
            System.out.println("\n>>> Imported Decision tree classification model: " + restoredMdl);

            int importedCorrect = evaluateModel(rnd, restoredMdl);
            double importedAccuracyPct = importedCorrect / 10.0;
            System.out.println("\n>>> Accuracy for imported Decision tree classification model: " + importedAccuracyPct + "%");

            System.out.println("\n>>> Decision tree classification trainer example completed.");
        } finally {
            if (trainingData != null) {
                trainingData.destroy();
            }
            if (exportedJson != null) {
                Files.deleteIfExists(exportedJson);
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used : Path(java.nio.file.Path) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) Random(java.util.Random) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 3 with DecisionTreeModel

use of org.apache.ignite.ml.tree.DecisionTreeModel in project ignite by apache.

the class DecisionTreeRegressionFromSparkExample method main.

/**
 * Runs the example: parses a decision tree regression model from {@code SPARK_MDL_PATH} and prints its
 * prediction next to the ground-truth label for every entry of the passengers cache.
 *
 * @param args Command line arguments, not used.
 * @throws FileNotFoundException Declared for the calls made here; NOTE(review): presumably thrown when the
 *      passengers data file is missing — confirm against {@code TitanicUtils}.
 */
public static void main(String[] args) throws FileNotFoundException {
    System.out.println();
    System.out.println(">>> Decision tree regression model loaded from Spark through serialization over partitioned dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = TitanicUtils.readPassengersWithoutNulls(ignite);
            // NOTE(review): presumably coordinates 0, 1, 5, 6 are features and coordinate 4 is the label — confirm.
            final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 1, 5, 6).labeled(4);
            // SPARK_MDL_PATH and env are defined outside this method.
            DecisionTreeModel mdl = (DecisionTreeModel) SparkModelParser.parse(SPARK_MDL_PATH, SupportedSparkModels.DECISION_TREE_REGRESSION, env);
            System.out.println(">>> Decision tree regression model: " + mdl);
            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");
            // The cursor is closed by try-with-resources once all entries have been scanned.
            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, Vector> observation : observations) {
                    LabeledVector<Double> lv = vectorizer.apply(observation.getKey(), observation.getValue());
                    Vector inputs = lv.features();
                    double groundTruth = lv.label();
                    double prediction = mdl.predict(inputs);
                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                }
            }
            System.out.println(">>> ---------------------------------");
        } finally {
            // The cache is still null when loading the data failed; without this guard a
            // NullPointerException would replace the original exception.
            if (dataCache != null)
                dataCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Also used : DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) Ignite(org.apache.ignite.Ignite) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) IgniteCache(org.apache.ignite.IgniteCache) Cache(javax.cache.Cache)

Example 4 with DecisionTreeModel

use of org.apache.ignite.ml.tree.DecisionTreeModel in project ignite by apache.

the class DecisionTreeClassificationTrainerSQLInferenceExample method main.

/**
 * Runs the example: creates the Titanic train/test SQL tables, trains a decision tree classifier on the
 * training table, saves the model under the name {@code titanic_model_tree} and runs inference through the
 * SQL {@code predict} function.
 *
 * @param args Command line arguments, not used.
 * @throws IOException Declared for the calls made here; NOTE(review): presumably thrown while loading the
 *      Titanic datasets — confirm against {@code loadTitanicDatasets}.
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision tree classification trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME).setSqlSchema("PUBLIC").setSqlFunctionClasses(SQLFunctions.class);
        IgniteCache<?, ?> cache = null;
        try {
            cache = ignite.getOrCreateCache(cacheCfg);
            System.out.println(">>> Creating table with training data...");
            createTitanicTable(cache, "titanic_train");
            System.out.println(">>> Creating table with test data...");
            createTitanicTable(cache, "titanic_test");
            loadTitanicDatasets(ignite, cache);
            System.out.println(">>> Prepare trainer...");
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
            System.out.println(">>> Perform training...");
            // Numeric columns are used as-is; "sex" is mapped to 1.0 for "male" and 0.0 otherwise.
            DecisionTreeModel mdl = trainer.fit(new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"), new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare").withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0)).labeled("survived"));
            System.out.println(">>> Saving model...");
            // Model storage is used to store raw serialized model.
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "titanic_model_tree");
            // Making inference using saved model.
            System.out.println("Inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " + "survived as truth, " + "predict('titanic_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction" + " from titanic_train"))) {
                // Print inference result.
                System.out.println("| Truth | Prediction |");
                System.out.println("|--------------------|");
                for (List<?> row : cursor) System.out.println("|     " + row.get(0) + " |        " + row.get(1) + " |");
            }
            IgniteModelStorageUtil.removeModel(ignite, "titanic_model_tree");
        } finally {
            // The cache is still null when getOrCreateCache failed; nothing to clean up in that case.
            if (cache != null) {
                try {
                    // IF EXISTS keeps cleanup from failing (and masking the original error) when a table
                    // was never created.
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_train"));
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_test"));
                } finally {
                    // Destroy the dummy cache even if dropping a table failed.
                    cache.destroy();
                }
            }
        }
    } finally {
        System.out.flush();
    }
}

/**
 * Creates a table with the Titanic passenger columns using the "partitioned" cache template.
 *
 * @param cache Cache used to execute the DDL statement.
 * @param tblName Name of the table to create.
 */
private static void createTitanicTable(IgniteCache<?, ?> cache, String tblName) {
    cache.query(new SqlFieldsQuery("create table " + tblName + " (\n" + "    passengerid int primary key,\n" + "    pclass int,\n" + "    survived int,\n" + "    name varchar(255),\n" + "    sex varchar(255),\n" + "    age float,\n" + "    sibsp int,\n" + "    parch int,\n" + "    ticket varchar(255),\n" + "    fare float,\n" + "    cabin varchar(255),\n" + "    embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll();
}
Also used : DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) Ignite(org.apache.ignite.Ignite) List(java.util.List) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) SqlFieldsQuery(org.apache.ignite.cache.query.SqlFieldsQuery) SqlDatasetBuilder(org.apache.ignite.ml.sql.SqlDatasetBuilder)

Example 5 with DecisionTreeModel

use of org.apache.ignite.ml.tree.DecisionTreeModel in project ignite by apache.

the class Step_1_Read_and_Learn method main.

/**
 * Runs tutorial step 1: reads the passengers cache, trains a decision tree classifier on it and prints the
 * trained model together with its accuracy and test error measured on the same cache.
 *
 * @param args Command line arguments, not used.
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> Tutorial step 1 (read and learn) example started.");
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = TitanicUtils.readPassengersWithoutNulls(ignite);
            // NOTE(review): presumably coordinates 0, 5, 6 are features and coordinate 1 is the label — confirm.
            final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 5, 6).labeled(1);
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
            DecisionTreeModel mdl = trainer.fit(ignite, dataCache, vectorizer);
            System.out.println("\n>>> Trained model: " + mdl);
            double accuracy = Evaluator.evaluate(dataCache, mdl, vectorizer, new Accuracy<>());
            System.out.println("\n>>> Accuracy " + accuracy);
            System.out.println("\n>>> Test Error " + (1 - accuracy));
            System.out.println(">>> Tutorial step 1 (read and learn) example completed.");
        } catch (FileNotFoundException e) {
            // Kept as in the original example: the failure is reported and the example ends.
            e.printStackTrace();
        } finally {
            // Destroy the data cache once the example is done; it is still null when loading failed.
            if (dataCache != null)
                dataCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Also used : DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) FileNotFoundException(java.io.FileNotFoundException) Ignite(org.apache.ignite.Ignite) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Aggregations

DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel): 32; Ignite (org.apache.ignite.Ignite): 27; DecisionTreeClassificationTrainer (org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer): 26; Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 20; FileNotFoundException (java.io.FileNotFoundException): 18; EncoderTrainer (org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer): 12; CrossValidation (org.apache.ignite.ml.selection.cv.CrossValidation): 9; CrossValidationResult (org.apache.ignite.ml.selection.cv.CrossValidationResult): 7; ParamGrid (org.apache.ignite.ml.selection.paramgrid.ParamGrid): 7; CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration): 6; LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 6; HashMap (java.util.HashMap): 5; RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction): 5; NormalizationTrainer (org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer): 5; Test (org.junit.Test): 4; Random (java.util.Random): 3; SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache): 3; LabeledDummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer): 3; Path (java.nio.file.Path): 2; List (java.util.List): 2