Usage of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.
Class DecisionTreeClassificationTrainerSQLInferenceExample, method main.
/**
 * Runs the example: creates the {@code titanic_train} / {@code titanic_test} SQL tables, trains a decision tree
 * on {@code titanic_train}, saves it into the Ignite model storage under {@code "titanic_model_tree"} and runs
 * inference through the {@code predict} SQL function used in the query below.
 *
 * <p>All created resources (model, tables, dummy cache) are cleaned up even if a step fails.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If an I/O error occurs (presumably while loading the datasets — not visible here).
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision tree classification trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME).setSqlSchema("PUBLIC").setSqlFunctionClasses(SQLFunctions.class);
        // Acquired before the try block: if this fails there is nothing to clean up, and the cleanup below must
        // never dereference a null cache (that NullPointerException would hide the original failure).
        IgniteCache<?, ?> cache = ignite.getOrCreateCache(cacheCfg);
        try {
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table titanic_train (\n" + " passengerid int primary key,\n" + " pclass int,\n" + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + " sibsp int,\n" + " parch int,\n" + " ticket varchar(255),\n" + " fare float,\n" + " cabin varchar(255),\n" + " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Creating table with test data...");
            cache.query(new SqlFieldsQuery("create table titanic_test (\n" + " passengerid int primary key,\n" + " pclass int,\n" + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + " sibsp int,\n" + " parch int,\n" + " ticket varchar(255),\n" + " fare float,\n" + " cabin varchar(255),\n" + " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll();
            loadTitanicDatasets(ignite, cache);
            System.out.println(">>> Prepare trainer...");
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
            System.out.println(">>> Perform training...");
            // "SQL_PUBLIC_TITANIC_TRAIN" is the cache backing the titanic_train table; "sex" is mapped to 1.0 for
            // "male" and 0.0 otherwise, "survived" is the label.
            DecisionTreeModel mdl = trainer.fit(new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"), new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare").withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0)).labeled("survived"));
            System.out.println(">>> Saving model...");
            // Model storage is used to store raw serialized model.
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "titanic_model_tree");
            try {
                // Making inference using saved model; the SQL "case" expression mirrors the "sex" mapping above.
                System.out.println("Inference...");
                try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " + "survived as truth, " + "predict('titanic_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction" + " from titanic_train"))) {
                    // Print inference result.
                    System.out.println("| Truth | Prediction |");
                    System.out.println("|--------------------|");
                    for (List<?> row : cursor) {
                        System.out.println("| " + row.get(0) + " | " + row.get(1) + " |");
                    }
                }
            } finally {
                // Remove the saved model even if inference fails, so no stale entry stays in the model storage.
                IgniteModelStorageUtil.removeModel(ignite, "titanic_model_tree");
            }
        } finally {
            // IF EXISTS: a table may be missing when setup failed part-way; a throwing DROP here would mask the
            // original exception and skip the remaining cleanup steps.
            cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_train")).getAll();
            cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_test")).getAll();
            cache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.
Class MovieLensSQLExample, method main.
/**
 * Runs the example: loads MovieLens ratings into the {@code ratings} SQL table, trains a
 * {@link RecommendationTrainer} model on it, saves the model into the Ignite model storage under
 * {@code "movielens_model"} and runs inference through the {@code predictRecommendation} SQL function.
 *
 * <p>The saved model, the table and the dummy cache are cleaned up even if a step fails.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If an I/O error occurs (presumably while loading the dataset — not visible here).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME).setSqlSchema("PUBLIC").setSqlFunctionClasses(SQLFunctions.class);
        // Acquired before the try block: if this fails there is nothing to clean up, and the cleanup below must
        // never dereference a null cache (that NullPointerException would hide the original failure).
        IgniteCache<?, ?> cache = ignite.getOrCreateCache(cacheCfg);
        try {
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table ratings (\n" + " rating_id int primary key,\n" + " movie_id int,\n" + " user_id int,\n" + " rating float\n" + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Loading data...");
            loadMovieLensDataset(ignite, cache, 10_000);
            // Fixed RNG seed for reproducible training.
            LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);
            RecommendationTrainer trainer = new RecommendationTrainer().withMaxIterations(100).withBatchSize(10).withLearningRate(10).withLearningEnvironmentBuilder(envBuilder).withTrainerEnvironment(envBuilder.buildForTrainer());
            System.out.println(">>> Training model...");
            // "SQL_PUBLIC_RATINGS" is the cache backing the ratings table.
            RecommendationModel<Serializable, Serializable> mdl = trainer.fit(new SqlDatasetBuilder(ignite, "SQL_PUBLIC_RATINGS"), "movie_id", "user_id", "rating");
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "movielens_model");
            try {
                System.out.println("Inference...");
                try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " + "rating, " + "predictRecommendation('movielens_model', movie_id, user_id) as prediction " + "from ratings"))) {
                    for (List<?> row : cursor) {
                        double rating = (Double) row.get(0);
                        double prediction = (Double) row.get(1);
                        System.out.println("Rating: " + rating + ", prediction: " + prediction);
                    }
                }
            } finally {
                // Remove the saved model so it does not outlive the example, even if inference fails.
                IgniteModelStorageUtil.removeModel(ignite, "movielens_model");
            }
        } finally {
            // IF EXISTS: the table may be missing when setup failed; a throwing DROP here would mask the original
            // exception and skip cache.destroy().
            cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS ratings")).getAll();
            cache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.
Class RecommendationTrainerSQLTest, method testFit.
/**
 * Trains a {@link RecommendationTrainer} over an SQL table (through {@code SqlDatasetBuilder}) filled with a
 * randomly sampled 100x100 block pattern of ratings, and checks that the model reproduces the pattern
 * (within {@code 1e-5}) on a random sample of cells.
 */
@Test
public void testFit() {
    // Dummy cache is required to perform SQL queries.
    CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME).setSqlSchema("PUBLIC").setSqlFunctionClasses(SQLFunctions.class);
    // Acquired before the try block so the cleanup below never dereferences a null cache; otherwise a failure here
    // would surface as a NullPointerException instead of the real cause.
    IgniteCache<?, ?> cache = ignite.getOrCreateCache(cacheCfg);
    try {
        System.out.println(">>> Creating table with training data...");
        cache.query(new SqlFieldsQuery("create table ratings (\n" + " rating_id int primary key,\n" + " obj_id int,\n" + " subj_id int,\n" + " rating float\n" + ") with \"template=partitioned\";")).getAll();
        int size = 100;
        Random rnd = new Random(0L);
        SqlFieldsQuery qry = new SqlFieldsQuery("insert into ratings (rating_id, obj_id, subj_id, rating) values (?, ?, ?, ?)");
        // Quadrant I contains "0", quadrant II contains "1", quadrant III contains "0", quadrant IV contains "1".
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                // Insert roughly half of the cells, chosen by the seeded RNG.
                if (rnd.nextBoolean()) {
                    double rating = ((i > size / 2) ^ (j > size / 2)) ? 1.0 : 0.0;
                    qry.setArgs(i * size + j, i, j, rating);
                    // Consume the cursor so it is not left open (matches the CREATE TABLE call above).
                    cache.query(qry).getAll();
                }
            }
        }
        RecommendationTrainer trainer = new RecommendationTrainer().withMaxIterations(100).withLearningRate(50.0).withBatchSize(10).withK(2).withLearningEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1)).withTrainerEnvironment(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1).buildForTrainer());
        RecommendationModel<Serializable, Serializable> mdl = trainer.fit(new SqlDatasetBuilder(ignite, "SQL_PUBLIC_RATINGS"), "obj_id", "subj_id", "rating");
        int incorrect = 0;
        // The same RNG keeps advancing, so the cells checked here are a fresh random sample, generally not
        // identical to the cells inserted above.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                if (rnd.nextBoolean()) {
                    double rating = ((i > size / 2) ^ (j > size / 2)) ? 1.0 : 0.0;
                    double prediction = mdl.predict(new ObjectSubjectPair<>(i, j));
                    if (Math.abs(prediction - rating) >= 1e-5) {
                        incorrect++;
                    }
                }
            }
        }
        assertEquals(0, incorrect);
    } finally {
        // IF EXISTS: if table creation failed, a throwing DROP would mask the real test failure and skip destroy().
        cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS ratings")).getAll();
        cache.destroy();
    }
}
Usage of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.
Class DecisionTreeClassificationTrainerSQLTableExample, method main.
/**
 * Runs the example: creates the {@code titanic_train} / {@code titanic_test} SQL tables, trains a decision tree
 * on {@code titanic_train}, prints it, and predicts survival for every passenger of {@code titanic_test}
 * in-process.
 *
 * <p>The tables and the dummy cache are cleaned up even if a step fails.
 *
 * @param args Command line arguments (not used).
 * @throws IgniteCheckedException Declared by the example (exact source not visible here; presumably dataset
 *      loading).
 * @throws IOException If an I/O error occurs (presumably while loading the datasets — not visible here).
 */
public static void main(String[] args) throws IgniteCheckedException, IOException {
    System.out.println(">>> Decision tree classification trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME).setSqlSchema("PUBLIC");
        // Acquired before the try block: if this fails there is nothing to clean up, and the cleanup below must
        // never dereference a null cache (that NullPointerException would hide the original failure).
        IgniteCache<?, ?> cache = ignite.getOrCreateCache(cacheCfg);
        try {
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table titanic_train (\n" + " passengerid int primary key,\n" + " pclass int,\n" + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + " sibsp int,\n" + " parch int,\n" + " ticket varchar(255),\n" + " fare float,\n" + " cabin varchar(255),\n" + " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Creating table with test data...");
            cache.query(new SqlFieldsQuery("create table titanic_test (\n" + " passengerid int primary key,\n" + " pclass int,\n" + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + " sibsp int,\n" + " parch int,\n" + " ticket varchar(255),\n" + " fare float,\n" + " cabin varchar(255),\n" + " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll();
            loadTitanicDatasets(ignite, cache);
            System.out.println(">>> Prepare trainer...");
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
            System.out.println(">>> Perform training...");
            // "sex" is mapped to 1.0 for "male" and 0.0 otherwise; "survived" is the label.
            DecisionTreeModel mdl = trainer.fit(new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"), new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare").withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0)).labeled("survived"));
            System.out.println("Tree is here: " + mdl.toString(true));
            System.out.println(">>> Perform inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " + "pclass, " + "sex, " + "age, " + "sibsp, " + "parch, " + "fare from titanic_test"))) {
                for (List<?> passenger : cursor) {
                    // Build the feature vector by hand, applying the same "sex" encoding as in training.
                    Vector input = VectorUtils.of(new Double[] { asDouble(passenger.get(0)), "male".equals(passenger.get(1)) ? 1.0 : 0.0, asDouble(passenger.get(2)), asDouble(passenger.get(3)), asDouble(passenger.get(4)), asDouble(passenger.get(5)) });
                    double prediction = mdl.predict(input);
                    System.out.printf("Passenger %s will %s.\n", passenger, prediction == 0 ? "die" : "survive");
                }
            }
            System.out.println(">>> Example completed.");
        } finally {
            // IF EXISTS: a table may be missing when setup failed part-way; a throwing DROP here would mask the
            // original exception and skip the remaining cleanup steps.
            cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_train")).getAll();
            cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_test")).getAll();
            cache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Aggregations