Search in sources:

Example 1 with SqlDatasetBuilder

Use of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.

Class DecisionTreeClassificationTrainerSQLInferenceExample, method main.

/**
 * Runs the example: starts an Ignite grid, creates the {@code titanic_train} and {@code titanic_test}
 * SQL tables, trains a decision tree on the training table through {@link SqlDatasetBuilder}, saves the
 * model under the name {@code titanic_model_tree} and then runs inference over the training table with
 * the SQL {@code predict} function, printing truth/prediction pairs.
 *
 * <p>The tables and the dummy cache are removed on exit, whether or not the example succeeded.
 *
 * @param args Command line arguments (not used).
 * @throws IOException Declared by the signature; presumably propagated from dataset loading (to be confirmed).
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision tree classification trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
            .setSqlSchema("PUBLIC")
            .setSqlFunctionClasses(SQLFunctions.class);
        IgniteCache<?, ?> cache = null;
        try {
            cache = ignite.getOrCreateCache(cacheCfg);
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table titanic_train (\n"
                + "    passengerid int primary key,\n"
                + "    pclass int,\n"
                + "    survived int,\n"
                + "    name varchar(255),\n"
                + "    sex varchar(255),\n"
                + "    age float,\n"
                + "    sibsp int,\n"
                + "    parch int,\n"
                + "    ticket varchar(255),\n"
                + "    fare float,\n"
                + "    cabin varchar(255),\n"
                + "    embarked varchar(255)\n"
                + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Creating table with test data...");
            cache.query(new SqlFieldsQuery("create table titanic_test (\n"
                + "    passengerid int primary key,\n"
                + "    pclass int,\n"
                + "    survived int,\n"
                + "    name varchar(255),\n"
                + "    sex varchar(255),\n"
                + "    age float,\n"
                + "    sibsp int,\n"
                + "    parch int,\n"
                + "    ticket varchar(255),\n"
                + "    fare float,\n"
                + "    cabin varchar(255),\n"
                + "    embarked varchar(255)\n"
                + ") with \"template=partitioned\";")).getAll();
            loadTitanicDatasets(ignite, cache);
            System.out.println(">>> Prepare trainer...");
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
            System.out.println(">>> Perform training...");
            // "sex" is a string column, so it is mapped to a number: "male" -> 1.0, anything else -> 0.0.
            DecisionTreeModel mdl = trainer.fit(
                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"),
                new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
                    .withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
                    .labeled("survived"));
            System.out.println(">>> Saving model...");
            // Model storage is used to store raw serialized model.
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "titanic_model_tree");
            // Making inference using saved model.
            System.out.println("Inference...");
            // The CASE expression mirrors the "sex" mapping that was used during training.
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select "
                + "survived as truth, "
                + "predict('titanic_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction"
                + " from titanic_train"))) {
                // Print inference result.
                System.out.println("| Truth | Prediction |");
                System.out.println("|--------------------|");
                for (List<?> row : cursor) {
                    System.out.println("|     " + row.get(0) + " |        " + row.get(1) + " |");
                }
            }
            IgniteModelStorageUtil.removeModel(ignite, "titanic_model_tree");
        } finally {
            // cache is still null when getOrCreateCache itself failed; dereferencing it here
            // would replace the real failure with a NullPointerException.
            if (cache != null) {
                try {
                    // IF EXISTS keeps cleanup from failing when table creation did not get that far.
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_train")).getAll();
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_test")).getAll();
                } finally {
                    // Destroy the dummy cache even if dropping a table failed.
                    cache.destroy();
                }
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used : DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) Ignite(org.apache.ignite.Ignite) List(java.util.List) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) SqlFieldsQuery(org.apache.ignite.cache.query.SqlFieldsQuery) SqlDatasetBuilder(org.apache.ignite.ml.sql.SqlDatasetBuilder)

Example 2 with SqlDatasetBuilder

Use of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.

Class MovieLensSQLExample, method main.

/**
 * Runs the example: starts an Ignite grid, creates the {@code ratings} SQL table, loads the MovieLens
 * data into it, trains a recommendation model on the table through {@link SqlDatasetBuilder}, saves the
 * model under the name {@code movielens_model} and prints the stored rating next to the value returned
 * by the SQL {@code predictRecommendation} function for every row.
 *
 * <p>The table and the dummy cache are removed on exit, whether or not the example succeeded.
 *
 * @param args Command line arguments (not used).
 * @throws IOException Declared by the signature; presumably propagated from dataset loading (to be confirmed).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
            .setSqlSchema("PUBLIC")
            .setSqlFunctionClasses(SQLFunctions.class);
        IgniteCache<?, ?> cache = null;
        try {
            cache = ignite.getOrCreateCache(cacheCfg);
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table ratings (\n"
                + "    rating_id int primary key,\n"
                + "    movie_id int,\n"
                + "    user_id int,\n"
                + "    rating float\n"
                + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Loading data...");
            loadMovieLensDataset(ignite, cache, 10_000);
            // The same seeded environment builder is used for both the learning and the trainer environment.
            LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);
            RecommendationTrainer trainer = new RecommendationTrainer()
                .withMaxIterations(100)
                .withBatchSize(10)
                .withLearningRate(10)
                .withLearningEnvironmentBuilder(envBuilder)
                .withTrainerEnvironment(envBuilder.buildForTrainer());
            System.out.println(">>> Training model...");
            RecommendationModel<Serializable, Serializable> mdl = trainer.fit(
                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_RATINGS"), "movie_id", "user_id", "rating");
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "movielens_model");
            System.out.println("Inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select "
                + "rating, "
                + "predictRecommendation('movielens_model', movie_id, user_id) as prediction "
                + "from ratings"))) {
                for (List<?> row : cursor) {
                    double rating = (Double) row.get(0);
                    double prediction = (Double) row.get(1);
                    System.out.println("Rating: " + rating + ", prediction: " + prediction);
                }
            }
        } finally {
            // cache is still null when getOrCreateCache itself failed; dereferencing it here
            // would replace the real failure with a NullPointerException.
            if (cache != null) {
                try {
                    // IF EXISTS keeps cleanup from failing when table creation did not get that far.
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS ratings")).getAll();
                } finally {
                    // Destroy the dummy cache even if dropping the table failed.
                    cache.destroy();
                }
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used : Serializable(java.io.Serializable) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) RecommendationTrainer(org.apache.ignite.ml.recommendation.RecommendationTrainer) SqlFieldsQuery(org.apache.ignite.cache.query.SqlFieldsQuery) SqlDatasetBuilder(org.apache.ignite.ml.sql.SqlDatasetBuilder) Ignite(org.apache.ignite.Ignite) List(java.util.List) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 3 with SqlDatasetBuilder

Use of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.

Class RecommendationTrainerSQLTest, method testFit.

/**
 * Fills the {@code ratings} SQL table with a random subset of a {@code size x size} grid whose rating is
 * {@code 1.0} when exactly one of the two indices is greater than {@code size / 2} and {@code 0.0}
 * otherwise, trains a recommendation model on the table through {@link SqlDatasetBuilder} and asserts
 * that every sampled prediction is within {@code 1e-5} of the expected rating.
 *
 * <p>The table and the dummy cache are removed afterwards, whether or not the test passed.
 */
@Test
public void testFit() {
    // Dummy cache is required to perform SQL queries.
    CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
        .setSqlSchema("PUBLIC")
        .setSqlFunctionClasses(SQLFunctions.class);
    IgniteCache<?, ?> cache = null;
    try {
        cache = ignite.getOrCreateCache(cacheCfg);
        System.out.println(">>> Creating table with training data...");
        cache.query(new SqlFieldsQuery("create table ratings (\n"
            + "    rating_id int primary key,\n"
            + "    obj_id int,\n"
            + "    subj_id int,\n"
            + "    rating float\n"
            + ") with \"template=partitioned\";")).getAll();
        int size = 100;
        Random rnd = new Random(0L);
        SqlFieldsQuery qry = new SqlFieldsQuery("insert into ratings (rating_id, obj_id, subj_id, rating) values (?, ?, ?, ?)");
        // Quadrant I contains "0", quadrant II contains "1", quadrant III contains "0", quadrant IV contains "1".
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                // Only a random subset of the grid cells is inserted as training data.
                if (rnd.nextBoolean()) {
                    double rating = ((i > size / 2) ^ (j > size / 2)) ? 1.0 : 0.0;
                    qry.setArgs(i * size + j, i, j, rating);
                    cache.query(qry);
                }
            }
        }
        RecommendationTrainer trainer = new RecommendationTrainer()
            .withMaxIterations(100)
            .withLearningRate(50.0)
            .withBatchSize(10)
            .withK(2)
            .withLearningEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1))
            .withTrainerEnvironment(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1).buildForTrainer());
        RecommendationModel<Serializable, Serializable> mdl = trainer.fit(
            new SqlDatasetBuilder(ignite, "SQL_PUBLIC_RATINGS"), "obj_id", "subj_id", "rating");
        int incorrect = 0;
        // The same Random instance keeps advancing, so the cells checked here are a fresh random subset.
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                if (rnd.nextBoolean()) {
                    double rating = ((i > size / 2) ^ (j > size / 2)) ? 1.0 : 0.0;
                    double prediction = mdl.predict(new ObjectSubjectPair<>(i, j));
                    if (Math.abs(prediction - rating) >= 1e-5)
                        incorrect++;
                }
            }
        }
        assertEquals(0, incorrect);
    } finally {
        // cache is still null when getOrCreateCache itself failed; dereferencing it here
        // would replace the real failure with a NullPointerException.
        if (cache != null) {
            try {
                // IF EXISTS keeps cleanup from failing when table creation did not get that far.
                cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS ratings")).getAll();
            } finally {
                // Destroy the dummy cache even if dropping the table failed.
                cache.destroy();
            }
        }
    }
}
Also used : Serializable(java.io.Serializable) Random(java.util.Random) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) SqlFieldsQuery(org.apache.ignite.cache.query.SqlFieldsQuery) SqlDatasetBuilder(org.apache.ignite.ml.sql.SqlDatasetBuilder) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Example 4 with SqlDatasetBuilder

Use of org.apache.ignite.ml.sql.SqlDatasetBuilder in the Apache Ignite project.

Class DecisionTreeClassificationTrainerSQLTableExample, method main.

/**
 * Runs the example: starts an Ignite grid, creates the {@code titanic_train} and {@code titanic_test}
 * SQL tables, trains a decision tree on the training table through {@link SqlDatasetBuilder}, prints
 * the tree and then predicts, row by row, the outcome for every passenger read from the test table.
 *
 * <p>The tables and the dummy cache are removed on exit, whether or not the example succeeded.
 *
 * @param args Command line arguments (not used).
 * @throws IgniteCheckedException Declared by the signature; the throwing call is not visible here.
 * @throws IOException Declared by the signature; presumably propagated from dataset loading (to be confirmed).
 */
public static void main(String[] args) throws IgniteCheckedException, IOException {
    System.out.println(">>> Decision tree classification trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME).setSqlSchema("PUBLIC");
        IgniteCache<?, ?> cache = null;
        try {
            cache = ignite.getOrCreateCache(cacheCfg);
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table titanic_train (\n"
                + "    passengerid int primary key,\n"
                + "    pclass int,\n"
                + "    survived int,\n"
                + "    name varchar(255),\n"
                + "    sex varchar(255),\n"
                + "    age float,\n"
                + "    sibsp int,\n"
                + "    parch int,\n"
                + "    ticket varchar(255),\n"
                + "    fare float,\n"
                + "    cabin varchar(255),\n"
                + "    embarked varchar(255)\n"
                + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Creating table with test data...");
            cache.query(new SqlFieldsQuery("create table titanic_test (\n"
                + "    passengerid int primary key,\n"
                + "    pclass int,\n"
                + "    survived int,\n"
                + "    name varchar(255),\n"
                + "    sex varchar(255),\n"
                + "    age float,\n"
                + "    sibsp int,\n"
                + "    parch int,\n"
                + "    ticket varchar(255),\n"
                + "    fare float,\n"
                + "    cabin varchar(255),\n"
                + "    embarked varchar(255)\n"
                + ") with \"template=partitioned\";")).getAll();
            loadTitanicDatasets(ignite, cache);
            System.out.println(">>> Prepare trainer...");
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
            System.out.println(">>> Perform training...");
            // "sex" is a string column, so it is mapped to a number: "male" -> 1.0, anything else -> 0.0.
            DecisionTreeModel mdl = trainer.fit(
                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"),
                new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
                    .withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
                    .labeled("survived"));
            System.out.println("Tree is here: " + mdl.toString(true));
            System.out.println(">>> Perform inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select "
                + "pclass, "
                + "sex, "
                + "age, "
                + "sibsp, "
                + "parch, "
                + "fare from titanic_test"))) {
                for (List<?> passenger : cursor) {
                    // Feature order follows the select list; "sex" is encoded the same way as in training.
                    Vector input = VectorUtils.of(new Double[] {
                        asDouble(passenger.get(0)),
                        "male".equals(passenger.get(1)) ? 1.0 : 0.0,
                        asDouble(passenger.get(2)),
                        asDouble(passenger.get(3)),
                        asDouble(passenger.get(4)),
                        asDouble(passenger.get(5))
                    });
                    double prediction = mdl.predict(input);
                    System.out.printf("Passenger %s will %s.\n", passenger, prediction == 0 ? "die" : "survive");
                }
            }
            System.out.println(">>> Example completed.");
        } finally {
            // cache is still null when getOrCreateCache itself failed; dereferencing it here
            // would replace the real failure with a NullPointerException.
            if (cache != null) {
                try {
                    // IF EXISTS keeps cleanup from failing when table creation did not get that far.
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_train")).getAll();
                    cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS titanic_test")).getAll();
                } finally {
                    // Destroy the dummy cache even if dropping a table failed.
                    cache.destroy();
                }
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used : DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) SqlFieldsQuery(org.apache.ignite.cache.query.SqlFieldsQuery) SqlDatasetBuilder(org.apache.ignite.ml.sql.SqlDatasetBuilder) DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) Ignite(org.apache.ignite.Ignite) List(java.util.List) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Aggregations

SqlFieldsQuery (org.apache.ignite.cache.query.SqlFieldsQuery)4 SqlDatasetBuilder (org.apache.ignite.ml.sql.SqlDatasetBuilder)4 List (java.util.List)3 Ignite (org.apache.ignite.Ignite)3 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)3 Serializable (java.io.Serializable)2 DecisionTreeClassificationTrainer (org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer)2 DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel)2 Random (java.util.Random)1 LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder)1 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)1 RecommendationTrainer (org.apache.ignite.ml.recommendation.RecommendationTrainer)1 GridCommonAbstractTest (org.apache.ignite.testframework.junits.common.GridCommonAbstractTest)1 Test (org.junit.Test)1