Search in sources :

Example 1 with DecisionTreeRegressionTrainer

Use of org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer in the project ignite by apache, in the method main of the class DecisionTreeRegressionExportImportExample.

/**
 * Runs the decision tree regression export/import example.
 * <p>
 * Starts an Ignite node, fills the {@code TRAINING_SET} cache through {@code generatePoints}, trains a
 * decision tree regression model, writes it to a temporary JSON file, reads it back and prints the
 * predictions of the re-imported model next to {@code Math.sin(x)} for integer {@code x} in {@code [0, 10)}.
 * The cache and the temporary file are removed before the method returns.
 *
 * @param args Command line arguments, none required.
 * @throws IOException If working with the temporary model file fails.
 */
public static void main(String... args) throws IOException {
    System.out.println(">>> Decision tree regression trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println("\n>>> Ignite grid started.");
        // Create cache with training data.
        CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
        trainingSetCfg.setName("TRAINING_SET");
        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
        // Both resources are declared outside the inner try so the finally block can clean up
        // whichever of them was actually created.
        IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
        Path jsonMdlPath = null;
        try {
            trainingSet = ignite.createCache(trainingSetCfg);
            // Fill training data.
            generatePoints(trainingSet);
            // Create regression trainer.
            DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
            // Train decision tree model.
            DecisionTreeModel mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
            System.out.println("\n>>> Exported Decision tree regression model: " + mdl);
            // Round-trip the trained model through a temporary JSON file.
            jsonMdlPath = Files.createTempFile(null, null);
            mdl.toJSON(jsonMdlPath);
            DecisionTreeModel modelImportedFromJSON = DecisionTreeModel.fromJSON(jsonMdlPath);
            System.out.println("\n>>> Imported Decision tree regression model: " + modelImportedFromJSON);
            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");
            // Calculate score with the imported model, so the table demonstrates that the
            // export/import round trip produced a usable model.
            for (int x = 0; x < 10; x++) {
                double predicted = modelImportedFromJSON.predict(VectorUtils.of(x));
                System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
            }
            System.out.println(">>> ---------------------------------");
            System.out.println("\n>>> Decision tree regression trainer example completed.");
        } finally {
            // Either reference may still be null if an earlier step failed.
            if (trainingSet != null)
                trainingSet.destroy();
            if (jsonMdlPath != null)
                Files.deleteIfExists(jsonMdlPath);
        }
    } finally {
        System.out.flush();
    }
}
Also used : Path(java.nio.file.Path) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DecisionTreeRegressionTrainer(org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 2 with DecisionTreeRegressionTrainer

Use of org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer in the project ignite by apache, in the method main of the class DecisionTreeRegressionTrainerExample.

/**
 * Runs the decision tree regression trainer example.
 * <p>
 * Starts an Ignite node, fills the {@code TRAINING_SET} cache through {@code generatePoints}, trains a
 * decision tree regression model on it and prints the model's predictions next to {@code Math.sin(x)}
 * for integer {@code x} in {@code [0, 10)}. The cache is destroyed before the method returns.
 *
 * @param args Command line arguments, none required.
 */
public static void main(String... args) {
    System.out.println(">>> Decision tree regression trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Create cache with training data.
        CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
        trainingSetCfg.setName("TRAINING_SET");
        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
        // Declared outside the inner try so the finally block can destroy the cache.
        IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
        try {
            trainingSet = ignite.createCache(trainingSetCfg);
            // Fill training data.
            generatePoints(trainingSet);
            // Create regression trainer.
            DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
            // Train decision tree model.
            DecisionTreeModel mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
            System.out.println(">>> Decision tree regression model: " + mdl);
            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");
            // Calculate score.
            for (int x = 0; x < 10; x++) {
                double predicted = mdl.predict(VectorUtils.of(x));
                System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
            }
            System.out.println(">>> ---------------------------------");
            System.out.println(">>> Decision tree regression trainer example completed.");
        } finally {
            // The cache reference is still null if createCache itself failed; without this guard
            // a NullPointerException thrown here would hide the original exception.
            if (trainingSet != null)
                trainingSet.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Also used : DecisionTreeRegressionTrainer(org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 3 with DecisionTreeRegressionTrainer

Use of org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer in the project gridgain by gridgain, in the method main of the class DecisionTreeRegressionTrainerExample.

/**
 * Runs the decision tree regression trainer example.
 * <p>
 * Starts an Ignite node, fills the {@code TRAINING_SET} cache through {@code generatePoints}, trains a
 * decision tree regression model on it and prints the model's predictions next to {@code Math.sin(x)}
 * for integer {@code x} in {@code [0, 10)}. The cache is destroyed before the method returns.
 *
 * @param args Command line arguments, none required.
 */
public static void main(String... args) {
    System.out.println(">>> Decision tree regression trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples-ml/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Create cache with training data.
        CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
        trainingSetCfg.setName("TRAINING_SET");
        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
        // Declared outside the inner try so the finally block can destroy the cache.
        IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
        try {
            trainingSet = ignite.createCache(trainingSetCfg);
            // Fill training data.
            generatePoints(trainingSet);
            // Create regression trainer.
            DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
            // Train decision tree model.
            DecisionTreeNode mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
            System.out.println(">>> Decision tree regression model: " + mdl);
            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");
            // Calculate score.
            for (int x = 0; x < 10; x++) {
                double predicted = mdl.predict(VectorUtils.of(x));
                System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
            }
            System.out.println(">>> ---------------------------------");
            System.out.println(">>> Decision tree regression trainer example completed.");
        } finally {
            // The cache reference is still null if createCache itself failed; without this guard
            // a NullPointerException thrown here would hide the original exception.
            if (trainingSet != null)
                trainingSet.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Also used : DecisionTreeRegressionTrainer(org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) DecisionTreeNode(org.apache.ignite.ml.tree.DecisionTreeNode)

Aggregations

Ignite (org.apache.ignite.Ignite)3 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)3 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)3 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)3 DecisionTreeRegressionTrainer (org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer)3 DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel)2 Path (java.nio.file.Path)1 DecisionTreeNode (org.apache.ignite.ml.tree.DecisionTreeNode)1