Use of org.apache.ignite.ml.tree.DecisionTreeModel in the Apache Ignite project.
Class DecisionTreeRegressionExportImportExample, method main.
/**
 * Executes example.
 * <p>
 * Trains a decision tree regression model on a cache filled by {@code generatePoints}, exports the model to a
 * temporary JSON file, imports it back and prints the predictions of the <b>imported</b> model next to
 * {@code Math.sin(x)} for {@code x} in {@code [0, 10)}.
 *
 * @param args Command line arguments, none required.
 * @throws IOException If working with the temporary model file fails.
 */
public static void main(String... args) throws IOException {
    System.out.println(">>> Decision tree regression trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println("\n>>> Ignite grid started.");

        // Create cache with training data.
        CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
        trainingSetCfg.setName("TRAINING_SET");
        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        // Both resources are declared outside the try so that the finally block can release
        // whichever of them was actually created.
        IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
        Path jsonMdlPath = null;
        try {
            trainingSet = ignite.createCache(trainingSetCfg);

            // Fill training data.
            generatePoints(trainingSet);

            // Create regression trainer.
            DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);

            // Train decision tree model.
            DecisionTreeModel mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
            System.out.println("\n>>> Exported Decision tree regression model: " + mdl);

            // Round trip the model through a temporary JSON file.
            jsonMdlPath = Files.createTempFile(null, null);
            mdl.toJSON(jsonMdlPath);
            DecisionTreeModel modelImportedFromJSON = DecisionTreeModel.fromJSON(jsonMdlPath);
            System.out.println("\n>>> Imported Decision tree regression model: " + modelImportedFromJSON);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            // Calculate score with the imported model so that the export/import round trip is
            // actually exercised, not just printed.
            for (int x = 0; x < 10; x++) {
                double predicted = modelImportedFromJSON.predict(VectorUtils.of(x));
                System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
            }

            System.out.println(">>> ---------------------------------");
            System.out.println("\n>>> Decision tree regression trainer example completed.");
        } finally {
            if (trainingSet != null)
                trainingSet.destroy();
            if (jsonMdlPath != null)
                Files.deleteIfExists(jsonMdlPath);
        }
    } finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.tree.DecisionTreeModel in the Apache Ignite project.
Class DecisionTreeClassificationExportImportExample, method main.
/**
 * Executes example.
 * <p>
 * Fills a cache with 1000 points produced by {@code generatePoint}, trains a decision tree classifier on it,
 * reports the accuracy of the trained model, then exports the model to a temporary JSON file, imports it back
 * and reports the accuracy of the imported copy.
 *
 * @param args Command line arguments, none required.
 * @throws IOException If working with the temporary model file fails.
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision tree classification trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println("\n>>> Ignite grid started.");

        // Describe the cache that will hold the training data.
        CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg = new CacheConfiguration<>();
        cacheCfg.setName("TRAINING_SET");
        cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        IgniteCache<Integer, LabeledVector<Double>> trainingData = null;
        Path exportedMdlFile = null;
        try {
            trainingData = ignite.createCache(cacheCfg);

            // A single seeded generator feeds data generation and both evaluations, in that order.
            Random random = new Random(0);

            // Fill training data.
            int key = 0;
            while (key < 1000) {
                trainingData.put(key, generatePoint(random));
                key++;
            }

            // Create classification trainer.
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

            // Train decision tree model.
            LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
            DecisionTreeModel trainedMdl = trainer.fit(ignite, trainingData, vectorizer);
            System.out.println("\n>>> Exported Decision tree classification model: " + trainedMdl);

            int correctForTrained = evaluateModel(random, trainedMdl);
            double trainedAccuracy = correctForTrained / 10.0;
            System.out.println("\n>>> Accuracy for exported Decision tree classification model: " + trainedAccuracy + "%");

            // Round trip the model through a temporary JSON file.
            exportedMdlFile = Files.createTempFile(null, null);
            trainedMdl.toJSON(exportedMdlFile);
            DecisionTreeModel importedMdl = DecisionTreeModel.fromJSON(exportedMdlFile);
            System.out.println("\n>>> Imported Decision tree classification model: " + importedMdl);

            int correctForImported = evaluateModel(random, importedMdl);
            double importedAccuracy = correctForImported / 10.0;
            System.out.println("\n>>> Accuracy for imported Decision tree classification model: " + importedAccuracy + "%");

            System.out.println("\n>>> Decision tree classification trainer example completed.");
        } finally {
            // Release only what was actually created.
            if (trainingData != null) {
                trainingData.destroy();
            }
            if (exportedMdlFile != null) {
                Files.deleteIfExists(exportedMdlFile);
            }
        }
    } finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.tree.DecisionTreeModel in the Apache Ignite project.
Class DecisionTreeRegressionFromSparkExample, method main.
/**
 * Run example.
 * <p>
 * Loads the Titanic passengers into a cache, parses a decision tree regression model from
 * {@code SPARK_MDL_PATH} and prints the model's prediction next to the label for every cache entry.
 *
 * @param args Command line arguments, not used.
 * @throws FileNotFoundException If a file required by the data loader or the model parser is missing
 *      (assumption: which of the two calls declares it is not visible here).
 */
public static void main(String[] args) throws FileNotFoundException {
    System.out.println();
    System.out.println(">>> Decision tree regression model loaded from Spark through serialization over partitioned dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = TitanicUtils.readPassengersWithoutNulls(ignite);

            // Columns 0, 1, 5 and 6 are used as features, column 4 as the label.
            final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 1, 5, 6).labeled(4);

            DecisionTreeModel mdl = (DecisionTreeModel) SparkModelParser.parse(SPARK_MDL_PATH, SupportedSparkModels.DECISION_TREE_REGRESSION, env);
            System.out.println(">>> Decision tree regression model: " + mdl);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, Vector> observation : observations) {
                    LabeledVector<Double> lv = vectorizer.apply(observation.getKey(), observation.getValue());
                    Vector inputs = lv.features();
                    double groundTruth = lv.label();
                    double prediction = mdl.predict(inputs);
                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                }
            }

            System.out.println(">>> ---------------------------------");
        } finally {
            // The cache is still null when loading the data failed; an unguarded destroy() would
            // throw a NullPointerException here and hide the original exception.
            if (dataCache != null)
                dataCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.tree.DecisionTreeModel in the Apache Ignite project.
Class DecisionTreeClassificationTrainerSQLInferenceExample, method main.
/**
 * Run example.
 * <p>
 * Creates the {@code titanic_train} and {@code titanic_test} SQL tables, loads the Titanic datasets, trains a
 * decision tree classifier on the training table, saves the model under the name {@code titanic_model_tree} and
 * runs inference through the SQL {@code predict} function. The tables and the dummy cache are removed at the end,
 * also when the example fails part way through.
 *
 * @param args Command line arguments, not used.
 * @throws IOException Declared by the original example; presumably raised while loading the datasets or saving
 *      the model (not visible here).
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision tree classification trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME).setSqlSchema("PUBLIC").setSqlFunctionClasses(SQLFunctions.class);

        IgniteCache<?, ?> cache = null;
        try {
            cache = ignite.getOrCreateCache(cacheCfg);

            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table titanic_train (\n"
                + " passengerid int primary key,\n"
                + " pclass int,\n"
                + " survived int,\n"
                + " name varchar(255),\n"
                + " sex varchar(255),\n"
                + " age float,\n"
                + " sibsp int,\n"
                + " parch int,\n"
                + " ticket varchar(255),\n"
                + " fare float,\n"
                + " cabin varchar(255),\n"
                + " embarked varchar(255)\n"
                + ") with \"template=partitioned\";")).getAll();

            System.out.println(">>> Creating table with test data...");
            cache.query(new SqlFieldsQuery("create table titanic_test (\n"
                + " passengerid int primary key,\n"
                + " pclass int,\n"
                + " survived int,\n"
                + " name varchar(255),\n"
                + " sex varchar(255),\n"
                + " age float,\n"
                + " sibsp int,\n"
                + " parch int,\n"
                + " ticket varchar(255),\n"
                + " fare float,\n"
                + " cabin varchar(255),\n"
                + " embarked varchar(255)\n"
                + ") with \"template=partitioned\";")).getAll();

            loadTitanicDatasets(ignite, cache);

            System.out.println(">>> Prepare trainer...");
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

            System.out.println(">>> Perform training...");
            DecisionTreeModel mdl = trainer.fit(
                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"),
                new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
                    .withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
                    .labeled("survived"));

            System.out.println(">>> Saving model...");
            // Model storage is used to store raw serialized model.
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "titanic_model_tree");

            // Making inference using saved model.
            System.out.println("Inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select "
                + "survived as truth, "
                + "predict('titanic_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction"
                + " from titanic_train"))) {
                // Print inference result.
                System.out.println("| Truth | Prediction |");
                System.out.println("|--------------------|");
                for (List<?> row : cursor)
                    System.out.println("| " + row.get(0) + " | " + row.get(1) + " |");
            }

            IgniteModelStorageUtil.removeModel(ignite, "titanic_model_tree");
        } finally {
            // The cache is still null when getOrCreateCache failed; nothing to clean up in that case.
            if (cache != null) {
                try {
                    // IF EXISTS keeps cleanup from failing (and hiding the original error) when the
                    // example broke before one or both tables were created.
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_train")).getAll();
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_test")).getAll();
                } finally {
                    // Destroy the dummy cache even if dropping a table failed.
                    cache.destroy();
                }
            }
        }
    } finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.tree.DecisionTreeModel in the Apache Ignite project.
Class Step_1_Read_and_Learn, method main.
/**
 * Run example.
 * <p>
 * Loads the Titanic passengers into a cache, trains a decision tree classifier on it and prints the trained
 * model together with its accuracy and test error measured on the same cache. The cache is destroyed before the
 * method returns.
 *
 * @param args Command line arguments, not used.
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> Tutorial step 1 (read and learn) example started.");
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        // Declared outside the try so that the finally block can destroy the cache.
        IgniteCache<Integer, Vector> dataCache = null;
        try {
            dataCache = TitanicUtils.readPassengersWithoutNulls(ignite);

            // Columns 0, 5 and 6 are used as features, column 1 as the label.
            final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 5, 6).labeled(1);

            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
            DecisionTreeModel mdl = trainer.fit(ignite, dataCache, vectorizer);
            System.out.println("\n>>> Trained model: " + mdl);

            double accuracy = Evaluator.evaluate(dataCache, mdl, vectorizer, new Accuracy<>());
            System.out.println("\n>>> Accuracy " + accuracy);
            System.out.println("\n>>> Test Error " + (1 - accuracy));

            System.out.println(">>> Tutorial step 1 (read and learn) example completed.");
        } catch (FileNotFoundException e) {
            // Kept as in the original: the signature declares no checked exceptions, so the
            // failure is reported on stderr instead of being propagated.
            e.printStackTrace();
        } finally {
            // Do not leave the data cache behind; it is still null when loading the data failed.
            if (dataCache != null)
                dataCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Aggregations