Use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache: the class MovieLensSQLExample, method main.
/**
* Run example.
*/
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Dummy cache is required to perform SQL queries.
        CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
            .setSqlSchema("PUBLIC")
            .setSqlFunctionClasses(SQLFunctions.class);

        IgniteCache<?, ?> cache = null;

        try {
            cache = ignite.getOrCreateCache(cacheCfg);

            System.out.println(">>> Creating table with training data...");
            cache.query(new SqlFieldsQuery("create table ratings (\n" +
                " rating_id int primary key,\n" +
                " movie_id int,\n" +
                " user_id int,\n" +
                " rating float\n" +
                ") with \"template=partitioned\";")).getAll();

            System.out.println(">>> Loading data...");
            loadMovieLensDataset(ignite, cache, 10_000);

            LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);

            RecommendationTrainer trainer = new RecommendationTrainer()
                .withMaxIterations(100)
                .withBatchSize(10)
                .withLearningRate(10)
                .withLearningEnvironmentBuilder(envBuilder)
                .withTrainerEnvironment(envBuilder.buildForTrainer());

            System.out.println(">>> Training model...");
            RecommendationModel<Serializable, Serializable> mdl = trainer.fit(
                new SqlDatasetBuilder(ignite, "SQL_PUBLIC_RATINGS"), "movie_id", "user_id", "rating");

            System.out.println("Saving model into model storage...");
            IgniteModelStorageUtil.saveModel(ignite, mdl, "movielens_model");

            System.out.println("Inference...");
            try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
                "rating, " +
                "predictRecommendation('movielens_model', movie_id, user_id) as prediction " +
                "from ratings"))) {
                for (List<?> row : cursor) {
                    double rating = (Double)row.get(0);
                    double prediction = (Double)row.get(1);

                    System.out.println("Rating: " + rating + ", prediction: " + prediction);
                }
            }
        }
        finally {
            cache.query(new SqlFieldsQuery("DROP TABLE ratings"));
            cache.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
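Every snippet on this page follows the same pattern: a LearningEnvironmentBuilder is configured once and then used to produce the trainer-side environment (buildForTrainer) and per-partition worker environments (buildForWorker). A minimal sketch of that pattern, using only calls that appear in the snippets on this page (the partition index 0 is a placeholder):

// Seeded builder shared by the trainer and the dataset machinery.
LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder()
    .withRNGSeed(1); // fixed seed for reproducible training

LearningEnvironment trainerEnv = envBuilder.buildForTrainer(); // environment used on the trainer node
LearningEnvironment workerEnv = envBuilder.buildForWorker(0);  // environment for a worker handling partition 0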
Use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache: the class KNNUtils, method buildDataset.
/**
* Builds dataset.
*
* @param envBuilder Learning environment builder.
* @param datasetBuilder Dataset builder.
* @param vectorizer Upstream vectorizer.
* @return Dataset.
*/
@Nullable
public static <K, V, C extends Serializable> Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> buildDataset(
    LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = null;

    if (datasetBuilder != null)
        dataset = datasetBuilder.build(envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder, environment);

    return dataset;
}
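A minimal sketch of how a caller might invoke this helper, assuming a local upstream map and a vectorizer that treats the last array element as the label; LocalDatasetBuilder appears in the test below, while DoubleArrayVectorizer and the exact labeling call are assumptions rather than something shown on this page:

// Two labeled rows; the last element of each array is assumed to be the label.
Map<Integer, double[]> upstream = new HashMap<>();
upstream.put(0, new double[] {1.0, 2.0, 0.0});
upstream.put(1, new double[] {3.0, 4.0, 1.0});

DatasetBuilder<Integer, double[]> datasetBuilder = new LocalDatasetBuilder<>(upstream, 2);
Preprocessor<Integer, double[]> vectorizer =
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST);

Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset =
    KNNUtils.buildDataset(LearningEnvironmentBuilder.defaultBuilder(), datasetBuilder, vectorizer);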
Use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache: the class MeanAbsValueConvergenceCheckerTest, method testConvergenceChecking.
/** */
@Test
public void testConvergenceChecking() {
    LocalDatasetBuilder<Integer, LabeledVector<Double>> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
    ConvergenceChecker<Integer, LabeledVector<Double>> checker =
        createChecker(new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder);

    double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder();

    Assert.assertEquals(1.9, error, 0.01);
    Assert.assertFalse(checker.isConverged(envBuilder, datasetBuilder, notConvergedMdl));
    Assert.assertTrue(checker.isConverged(envBuilder, datasetBuilder, convergedMdl));

    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
        envBuilder, new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer),
        envBuilder.buildForTrainer())) {
        double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
        Assert.assertEquals(1.55, onDSError, 0.01);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
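The 0.1 passed to MeanAbsValueConvergenceCheckerFactory is the precision threshold the checker compares the mean absolute error against; this reading of the API is an assumption based on the assertions above rather than something stated on this page. A sketch of a looser configuration:

// Assumption: with a threshold above the ~1.55 mean error measured in the test,
// isConverged would report true even for notConvergedMdl.
ConvergenceCheckerFactory relaxedFactory = new MeanAbsValueConvergenceCheckerFactory(2.0);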
Use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache: the class MinMaxScalerTrainer, method fit.
/**
* {@inheritDoc}
*/
@Override
public MinMaxScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> ctxBuilder = (env, upstream, upstreamSize) -> new EmptyContext();

    try (Dataset<EmptyContext, MinMaxScalerPartitionData> dataset = datasetBuilder.build(
        envBuilder,
        ctxBuilder,
        (env, upstream, upstreamSize, ctx) -> {
            double[] min = null;
            double[] max = null;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                if (min == null) {
                    min = new double[row.size()];
                    Arrays.fill(min, Double.MAX_VALUE);
                }
                else
                    assert min.length == row.size() : "Base preprocessor must return exactly " + min.length + " features";

                if (max == null) {
                    max = new double[row.size()];
                    Arrays.fill(max, -Double.MAX_VALUE);
                }
                else
                    assert max.length == row.size() : "Base preprocessor must return exactly " + min.length + " features";

                for (int i = 0; i < row.size(); i++) {
                    if (row.get(i) < min[i])
                        min[i] = row.get(i);
                    if (row.get(i) > max[i])
                        max[i] = row.get(i);
                }
            }

            return new MinMaxScalerPartitionData(min, max);
        },
        learningEnvironment(basePreprocessor))) {
        double[][] minMax = dataset.compute(
            data -> data.getMin() != null ? new double[][] {data.getMin(), data.getMax()} : null,
            (a, b) -> {
                if (a == null)
                    return b;
                if (b == null)
                    return a;

                double[][] res = new double[2][];

                res[0] = new double[a[0].length];
                for (int i = 0; i < res[0].length; i++)
                    res[0][i] = Math.min(a[0][i], b[0][i]);

                res[1] = new double[a[1].length];
                for (int i = 0; i < res[1].length; i++)
                    res[1][i] = Math.max(a[1][i], b[1][i]);

                return res;
            });

        return new MinMaxScalerPreprocessor<>(minMax[0], minMax[1], basePreprocessor);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
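The reduce in dataset.compute produces one global minimum and maximum per feature, and the returned MinMaxScalerPreprocessor uses them to rescale each feature into [0, 1]. A minimal sketch of the per-feature transformation it is expected to apply (this is the standard min-max scaling formula; the preprocessor's handling of the degenerate min == max case is not shown here):

// x is a raw feature value; min and max are the statistics collected by fit() for that feature.
static double minMaxScale(double x, double min, double max) {
    return (x - min) / (max - min); // assumed formula; min == max handling omitted
}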
Use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache: the class ComputeUtils, method initContext.
/**
 * Initializes partition {@code context} by loading it from a partition {@code upstream}.
 *
 * @param ignite Ignite instance.
 * @param upstreamCacheName Name of an {@code upstream} cache.
 * @param transformerBuilder Upstream transformer builder.
 * @param filter Filter for {@code upstream} data.
 * @param datasetCacheName Name of a partition {@code context} cache.
 * @param ctxBuilder Partition {@code context} builder.
 * @param envBuilder Environment builder.
 * @param retries Number of retries for the case when one of the partitions is not available.
 * @param interval Interval between retries.
 * @param isKeepBinary Support of binary objects.
 * @param deployingCtx Deploy context.
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 * @param <C> Type of a partition {@code context}.
 */
public static <K, V, C extends Serializable> void initContext(Ignite ignite, String upstreamCacheName,
    UpstreamTransformerBuilder transformerBuilder, IgniteBiPredicate<K, V> filter, String datasetCacheName,
    PartitionContextBuilder<K, V, C> ctxBuilder, LearningEnvironmentBuilder envBuilder, int retries,
    int interval, boolean isKeepBinary, DeployingContext deployingCtx) {
    affinityCallWithRetries(ignite, Arrays.asList(datasetCacheName, upstreamCacheName), part -> {
        Ignite locIgnite = Ignition.localIgnite();
        LearningEnvironment env = envBuilder.buildForWorker(part);

        IgniteCache<K, V> locUpstreamCache = locIgnite.cache(upstreamCacheName);
        if (isKeepBinary)
            locUpstreamCache = locUpstreamCache.withKeepBinary();

        ScanQuery<K, V> qry = new ScanQuery<>();
        qry.setLocal(true);
        qry.setPartition(part);
        qry.setFilter(filter);

        C ctx;

        UpstreamTransformer transformer = transformerBuilder.build(env);
        UpstreamTransformer transformerCp = Utils.copy(transformer);

        long cnt = computeCount(locUpstreamCache, qry, transformer);

        try (QueryCursor<UpstreamEntry<K, V>> cursor = locUpstreamCache.query(qry,
            e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
            Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
            Stream<UpstreamEntry> transformedStream =
                transformerCp.transform(Utils.asStream(it, cnt).map(x -> (UpstreamEntry)x));
            it = Utils.asStream(transformedStream.iterator()).map(x -> (UpstreamEntry<K, V>)x).iterator();

            Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(it, cnt,
                "Cache expected to be not modified during dataset data building [partition=" + part + ']');

            ctx = ctxBuilder.build(env, iter, cnt);
        }

        IgniteCache<Integer, C> datasetCache = locIgnite.cache(datasetCacheName);

        datasetCache.put(part, ctx);

        return part;
    }, retries, interval, deployingCtx);
}
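The ctxBuilder argument is the same functional interface used in the trainer snippets above: when no per-partition state is needed it is just a lambda returning an EmptyContext, and initContext stores whatever it builds in the dataset cache under the partition index. A sketch of such a builder (the key and value types are illustrative):

// No per-partition state to keep, so an EmptyContext is built for every partition.
PartitionContextBuilder<Integer, double[], EmptyContext> ctxBuilder =
    (env, upstream, upstreamSize) -> new EmptyContext();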