Usage of org.apache.ignite.ml.recommendation.RecommendationTrainer in the Apache Ignite project.
Class MovieLensSQLExample, method main.
/**
 * Run example.
 * <p>
 * Starts an Ignite node, creates a {@code ratings} SQL table through a dummy cache, loads MovieLens
 * data into it, trains a {@link RecommendationTrainer} model over the table, saves the model under
 * the name {@code movielens_model} and prints each stored rating next to the model prediction
 * obtained through the {@code predictRecommendation} SQL function.
 *
 * @param args Command line arguments (unused).
 * @throws IOException If loading the dataset fails (presumably thrown by {@code loadMovieLensDataset}).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
            .setSqlSchema("PUBLIC")
            .setSqlFunctionClasses(SQLFunctions.class);
        IgniteCache<?, ?> cache = null;
        try {
            cache = ignite.getOrCreateCache(cacheCfg);
            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table ratings (\n" + " rating_id int primary key,\n" + " movie_id int,\n" + " user_id int,\n" + " rating float\n" + ") with \"template=partitioned\";")).getAll();
            System.out.println(">>> Loading data...");
            // NOTE(review): 10_000 is presumably the number of rows to load; confirm in loadMovieLensDataset.
            loadMovieLensDataset(ignite, cache, 10_000);
            // Fixed RNG seed keeps training reproducible between runs.
            LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);
            RecommendationTrainer trainer = new RecommendationTrainer()
                .withMaxIterations(100)
                .withBatchSize(10)
                .withLearningRate(10)
                .withLearningEnvironmentBuilder(envBuilder)
                .withTrainerEnvironment(envBuilder.buildForTrainer());
            System.out.println(">>> Training model...");
            RecommendationModel<Serializable, Serializable> mdl =
                trainer.fit(new SqlDatasetBuilder(ignite, "SQL_PUBLIC_RATINGS"), "movie_id", "user_id", "rating");
            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "movielens_model");
            System.out.println("Inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " + "rating, " + "predictRecommendation('movielens_model', movie_id, user_id) as prediction " + "from ratings"))) {
                for (List<?> row : cursor) {
                    // Read through Number: the concrete boxed type depends on the SQL type mapping
                    // and the SQL function's return type, so a hard (Double) cast is brittle.
                    double rating = ((Number) row.get(0)).doubleValue();
                    double prediction = ((Number) row.get(1)).doubleValue();
                    System.out.println("Rating: " + rating + ", prediction: " + prediction);
                }
            }
        } finally {
            // Cache creation itself may have failed; dereferencing null here would hide that error.
            if (cache != null) {
                // IF EXISTS: when CREATE TABLE failed, a plain DROP would throw and mask the original cause.
                cache.query(new SqlFieldsQuery("DROP TABLE IF EXISTS ratings")).getAll();
                cache.destroy();
            }
        }
    } finally {
        System.out.flush();
    }
}
Usage of org.apache.ignite.ml.recommendation.RecommendationTrainer in the Apache Ignite project.
Class MovieLensExample, method main.
/**
 * Run example.
 * <p>
 * Starts an Ignite node, loads MovieLens ratings into a cache, trains a {@link RecommendationTrainer}
 * model over that cache and prints the coefficient of determination (R2) of the model evaluated on
 * the same cached ratings.
 *
 * @param args Command line arguments (unused).
 * @throws IOException If loading the dataset fails (presumably thrown by {@code loadMovieLensDataset}).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        // NOTE(review): 10_000 is presumably the number of ratings to load; confirm in loadMovieLensDataset.
        IgniteCache<Integer, RatingPoint> movielensCache = loadMovieLensDataset(ignite, 10_000);
        try {
            // Fixed RNG seed keeps training reproducible between runs.
            LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);
            RecommendationTrainer trainer = new RecommendationTrainer()
                .withMaxIterations(-1)
                .withMinMdlImprovement(10)
                .withBatchSize(10)
                .withLearningRate(10)
                .withLearningEnvironmentBuilder(envBuilder)
                .withTrainerEnvironment(envBuilder.buildForTrainer());
            RecommendationModel<Integer, Integer> mdl = trainer.fit(new CacheBasedDatasetBuilder<>(ignite, movielensCache));

            // First pass: mean rating. The divisor is the number of entries actually scanned, so the
            // sum and the count always describe the same set of entries (no separate size() call).
            double sum = 0;
            long cnt = 0;
            try (QueryCursor<Cache.Entry<Integer, RatingPoint>> cursor = movielensCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, RatingPoint> e : cursor) {
                    ObjectSubjectRatingTriplet<Integer, Integer> triplet = e.getValue();
                    sum += triplet.getRating();
                    cnt++;
                }
            }

            if (cnt == 0) {
                // Nothing to evaluate: the mean (and therefore R2) would be NaN.
                System.out.println("No ratings found in the cache, R2 score is undefined.");
            }
            else {
                double mean = sum / cnt;

                // Second pass: total and residual sums of squares for R2 = 1 - RSS / TSS.
                double tss = 0;
                double rss = 0;
                try (QueryCursor<Cache.Entry<Integer, RatingPoint>> cursor = movielensCache.query(new ScanQuery<>())) {
                    for (Cache.Entry<Integer, RatingPoint> e : cursor) {
                        ObjectSubjectRatingTriplet<Integer, Integer> triplet = e.getValue();
                        double dev = triplet.getRating() - mean;
                        double err = triplet.getRating() - mdl.predict(triplet);
                        tss += dev * dev;
                        rss += err * err;
                    }
                }
                double r2 = 1.0 - rss / tss;
                System.out.println("R2 score: " + r2);
            }
        } finally {
            movielensCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Aggregations