Search in sources:

Example 1 with RecommendationTrainer

Use of org.apache.ignite.ml.recommendation.RecommendationTrainer in the Apache Ignite project.

Class MovieLensSQLExample, method main.

/**
 * Runs the example: creates an SQL table with a MovieLens sample, trains a recommendation model on it,
 * saves the model into the model storage and prints every stored rating next to the model's prediction.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If an I/O error occurs (presumably while reading the MovieLens dataset).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
            .setSqlSchema("PUBLIC")
            .setSqlFunctionClasses(SQLFunctions.class);
        // Obtained outside the try/finally: if creation fails there is nothing to clean up, and the
        // cleanup code below can rely on a non-null cache instead of failing with an NPE that would
        // hide the real error.
        IgniteCache<?, ?> cache = ignite.getOrCreateCache(cacheCfg);
        try {
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table ratings (\n"
                + "    rating_id int primary key,\n"
                + "    movie_id int,\n"
                + "    user_id int,\n"
                + "    rating float\n"
                + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Loading data...");
            loadMovieLensDataset(ignite, cache, 10_000);
            LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);
            RecommendationTrainer trainer = new RecommendationTrainer()
                .withMaxIterations(100)
                .withBatchSize(10)
                .withLearningRate(10)
                .withLearningEnvironmentBuilder(envBuilder)
                .withTrainerEnvironment(envBuilder.buildForTrainer());
            System.out.println(">>> Training model...");
            RecommendationModel<Serializable, Serializable> mdl = trainer.fit(
                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_RATINGS"), "movie_id", "user_id", "rating");
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "movielens_model");
            System.out.println("Inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select "
                + "rating, "
                + "predictRecommendation('movielens_model', movie_id, user_id) as prediction "
                + "from ratings"))) {
                for (List<?> row : cursor) {
                    // Read through Number so the example does not depend on the exact boxed type
                    // the SQL engine chooses for these columns.
                    double rating = ((Number) row.get(0)).doubleValue();
                    double prediction = ((Number) row.get(1)).doubleValue();
                    System.out.println("Rating: " + rating + ", prediction: " + prediction);
                }
            }
        } finally {
            try {
                // IF EXISTS keeps cleanup from throwing (and masking the original failure) when the
                // table was never created; getAll() drains the cursor, like the CREATE statement above.
                cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS ratings")).getAll();
            } finally {
                // Destroy the dummy cache even if dropping the table failed.
                cache.destroy();
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used: Serializable (java.io.Serializable), LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder), RecommendationTrainer (org.apache.ignite.ml.recommendation.RecommendationTrainer), SqlFieldsQuery (org.apache.ignite.cache.query.SqlFieldsQuery), SqlDatasetBuilder (org.apache.ignite.ml.sql.SqlDatasetBuilder), Ignite (org.apache.ignite.Ignite), List (java.util.List), CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)

Example 2 with RecommendationTrainer

Use of org.apache.ignite.ml.recommendation.RecommendationTrainer in the Apache Ignite project.

Class MovieLensExample, method main.

/**
 * Entry point of the example: trains a recommendation model on a MovieLens sample held in an Ignite
 * cache and prints the model's R2 score, measured on that same cached data.
 *
 * @param args Command line arguments (not used).
 * @throws IOException If an I/O error occurs (presumably while reading the MovieLens dataset).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, RatingPoint> dataCache = loadMovieLensDataset(ignite, 10_000);

        try {
            LearningEnvironmentBuilder env = LearningEnvironmentBuilder.defaultBuilder()
                .withRNGSeed(1);

            RecommendationTrainer trainer = new RecommendationTrainer()
                .withMaxIterations(-1)
                .withMinMdlImprovement(10)
                .withBatchSize(10)
                .withLearningRate(10)
                .withLearningEnvironmentBuilder(env)
                .withTrainerEnvironment(env.buildForTrainer());

            RecommendationModel<Integer, Integer> model =
                trainer.fit(new CacheBasedDatasetBuilder<>(ignite, dataCache));

            // Pass 1: average rating, dividing the scanned sum by the cache size.
            double avgRating = 0;
            try (QueryCursor<Cache.Entry<Integer, RatingPoint>> scan = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, RatingPoint> entry : scan) {
                    ObjectSubjectRatingTriplet<Integer, Integer> point = entry.getValue();
                    avgRating += point.getRating();
                }
                avgRating /= dataCache.size();
            }

            // Pass 2: total sum of squares (around the average) and residual sum of squares
            // (against the model's predictions).
            double totalSq = 0;
            double residualSq = 0;
            try (QueryCursor<Cache.Entry<Integer, RatingPoint>> scan = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, RatingPoint> entry : scan) {
                    ObjectSubjectRatingTriplet<Integer, Integer> point = entry.getValue();
                    totalSq += Math.pow(point.getRating() - avgRating, 2);
                    residualSq += Math.pow(point.getRating() - model.predict(point), 2);
                }
            }

            double r2 = 1.0 - residualSq / totalSq;
            System.out.println("R2 score: " + r2);
        } finally {
            // The example owns the cache it loaded, so it removes it when done.
            dataCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Also used: LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder), RecommendationTrainer (org.apache.ignite.ml.recommendation.RecommendationTrainer), Ignite (org.apache.ignite.Ignite), IgniteCache (org.apache.ignite.IgniteCache), SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache), Cache (javax.cache.Cache)

Aggregations

Ignite (org.apache.ignite.Ignite): 2, LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder): 2, RecommendationTrainer (org.apache.ignite.ml.recommendation.RecommendationTrainer): 2, Serializable (java.io.Serializable): 1, List (java.util.List): 1, Cache (javax.cache.Cache): 1, IgniteCache (org.apache.ignite.IgniteCache): 1, SqlFieldsQuery (org.apache.ignite.cache.query.SqlFieldsQuery): 1, CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration): 1, SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache): 1, SqlDatasetBuilder (org.apache.ignite.ml.sql.SqlDatasetBuilder): 1