use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache.
the class LocalDatasetBuilder method build.
/**
* {@inheritDoc}
*/
@Override
public <C extends Serializable, D extends AutoCloseable> LocalDataset<C, D> build(
    LearningEnvironmentBuilder envBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder,
    PartitionDataBuilder<K, V, C, D> partDataBuilder, LearningEnvironment learningEnvironment) {
    List<C> ctxList = new ArrayList<>();
    List<D> dataList = new ArrayList<>();

    // Materialize the filtered upstream entries.
    List<UpstreamEntry<K, V>> entriesList = new ArrayList<>();
    upstreamMap.entrySet().stream()
        .filter(en -> filter.apply(en.getKey(), en.getValue()))
        .map(en -> new UpstreamEntry<>(en.getKey(), en.getValue()))
        .forEach(entriesList::add);

    int partSize = Math.max(1, entriesList.size() / partitions);
    Iterator<UpstreamEntry<K, V>> firstKeysIter = entriesList.iterator();
    Iterator<UpstreamEntry<K, V>> secondKeysIter = entriesList.iterator();
    Iterator<UpstreamEntry<K, V>> thirdKeysIter = entriesList.iterator();
    int ptr = 0;

    // One learning environment per partition.
    List<LearningEnvironment> envs = IntStream.range(0, partitions).boxed()
        .map(envBuilder::buildForWorker).collect(Collectors.toList());

    for (int part = 0; part < partitions; part++) {
        // The last partition absorbs the remainder of the entries.
        int cntBeforeTransform = part == partitions - 1
            ? entriesList.size() - ptr
            : Math.min(partSize, entriesList.size() - ptr);
        LearningEnvironment env = envs.get(part);

        // Three copies of the transformer, one per pass over the same window of entries.
        UpstreamTransformer transformer1 = upstreamTransformerBuilder.build(env);
        UpstreamTransformer transformer2 = Utils.copy(transformer1);
        UpstreamTransformer transformer3 = Utils.copy(transformer1);

        // First pass: count the entries after transformation.
        int cnt = (int)transformer1.transform(
            Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count();

        // Second pass: build the partition context.
        Iterator<UpstreamEntry> iter = transformer2.transform(
            Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform))
                .map(x -> (UpstreamEntry)x)).iterator();
        Iterator<UpstreamEntry<K, V>> convertedBack = Utils.asStream(iter).map(x -> (UpstreamEntry<K, V>)x).iterator();
        C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, convertedBack, cnt) : null;

        // Third pass: build the partition data.
        Iterator<UpstreamEntry> iter1 = transformer3.transform(
            Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator();
        Iterator<UpstreamEntry<K, V>> convertedBack1 = Utils.asStream(iter1).map(x -> (UpstreamEntry<K, V>)x).iterator();
        D data = cntBeforeTransform > 0 ? partDataBuilder.build(env, convertedBack1, cnt, ctx) : null;

        ctxList.add(ctx);
        dataList.add(data);
        ptr += cntBeforeTransform;
    }
    return new LocalDataset<>(envs, ctxList, dataList);
}
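The build method above splits the filtered upstream entries into partition-sized windows and makes three passes over each window (counting, context building, data building), with a separate worker LearningEnvironment per partition produced by envBuilder.buildForWorker. Below is a minimal sketch, not taken from the Ignite sources, of calling this method directly on a toy upstream map; the DoubleSumData helper type, the lambda builders and the compute call at the end are assumptions made for illustration.

static void localDatasetSketch() {
    Map<Integer, Double> upstream = new HashMap<>();
    for (int i = 0; i < 100; i++)
        upstream.put(i, (double)i);

    LocalDatasetBuilder<Integer, Double> builder = new LocalDatasetBuilder<>(upstream, 4);
    LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);

    try (LocalDataset<EmptyContext, DoubleSumData> dataset = builder.build(
        envBuilder,
        (env, entries, cnt) -> new EmptyContext(),
        (env, entries, cnt, ctx) -> {
            // Sum the values in the window of upstream entries handed to this partition.
            double sum = 0;
            while (entries.hasNext())
                sum += entries.next().getValue();
            return new DoubleSumData(sum);
        },
        envBuilder.buildForTrainer())) {
        // One context/data pair per partition; reduce the per-partition sums.
        double total = dataset.compute(data -> data.sum, (a, b) -> a == null ? b : a + b);
        System.out.println("Sum over all partitions: " + total);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}

// Hypothetical partition data type for the sketch above (partition data must be AutoCloseable).
static class DoubleSumData implements AutoCloseable {
    final double sum;
    DoubleSumData(double sum) { this.sum = sum; }
    @Override public void close() { /* nothing to release */ }
}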
use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache.
the class MovieLensExample method main.
/**
* Run example.
*/
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> Recommendation system over cache based dataset usage example started.");

    // Start Ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, RatingPoint> movielensCache = loadMovieLensDataset(ignite, 10_000);
        try {
            LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1);

            RecommendationTrainer trainer = new RecommendationTrainer()
                .withMaxIterations(-1)
                .withMinMdlImprovement(10)
                .withBatchSize(10)
                .withLearningRate(10)
                .withLearningEnvironmentBuilder(envBuilder)
                .withTrainerEnvironment(envBuilder.buildForTrainer());

            RecommendationModel<Integer, Integer> mdl = trainer.fit(new CacheBasedDatasetBuilder<>(ignite, movielensCache));

            // Mean rating over the whole cache.
            double mean = 0;
            try (QueryCursor<Cache.Entry<Integer, RatingPoint>> cursor = movielensCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, RatingPoint> e : cursor) {
                    ObjectSubjectRatingTriplet<Integer, Integer> triplet = e.getValue();
                    mean += triplet.getRating();
                }
                mean /= movielensCache.size();
            }

            // Total and residual sums of squares for the R2 score.
            double tss = 0, rss = 0;
            try (QueryCursor<Cache.Entry<Integer, RatingPoint>> cursor = movielensCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, RatingPoint> e : cursor) {
                    ObjectSubjectRatingTriplet<Integer, Integer> triplet = e.getValue();
                    tss += Math.pow(triplet.getRating() - mean, 2);
                    rss += Math.pow(triplet.getRating() - mdl.predict(triplet), 2);
                }
            }

            double r2 = 1.0 - rss / tss;
            System.out.println("R2 score: " + r2);
        }
        finally {
            movielensCache.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
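The two scan loops above compute the coefficient of determination, R2 = 1 - RSS / TSS, over the whole ratings cache. Purely as an illustration (not part of the original example), the same computation extracted into a standalone helper over parallel arrays of actual and predicted ratings would look like this:

// Illustrative helper, not in the original example: R2 = 1 - RSS / TSS.
static double r2Score(double[] actual, double[] predicted) {
    // Mean of the actual ratings.
    double mean = 0;
    for (double y : actual)
        mean += y;
    mean /= actual.length;

    // Total and residual sums of squares.
    double tss = 0, rss = 0;
    for (int i = 0; i < actual.length; i++) {
        tss += Math.pow(actual[i] - mean, 2);
        rss += Math.pow(actual[i] - predicted[i], 2);
    }
    return 1.0 - rss / tss;
}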
use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache.
the class SparkModelParser method parse.
/**
* Load a model from Parquet (presented as a directory).
*
* @param pathToMdl Path to the directory with the saved model.
* @param parsedSparkMdl One of the supported Spark models to parse.
* @return Parsed model.
*/
public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl) throws IllegalArgumentException {
    LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder();
    LearningEnvironment environment = envBuilder.buildForTrainer();
    return extractModel(pathToMdl, parsedSparkMdl, environment);
}
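A hedged usage sketch for the parse method above; the directory path and the SupportedSparkModels constant are assumptions chosen for illustration and should be replaced with the actual saved model directory and the constant matching the model's type.

// Assumed path and model-type constant (illustrative only).
Model mdl = SparkModelParser.parse(
    "/path/to/saved/spark/model/data",
    SupportedSparkModels.LOG_REGRESSION
);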
use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache.
the class MedianOfMedianConvergenceCheckerTest method testConvergenceChecking.
/**
*/
@Test
public void testConvergenceChecking() {
    data.put(666, VectorUtils.of(10, 11).labeled(100000.0));

    LocalDatasetBuilder<Integer, LabeledVector<Double>> datasetBuilder = new LocalDatasetBuilder<>(data, 1);
    ConvergenceChecker<Integer, LabeledVector<Double>> checker =
        createChecker(new MedianOfMedianConvergenceCheckerFactory(0.1), datasetBuilder);

    double error = checker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);
    Assert.assertEquals(1.9, error, 0.01);

    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder();
    Assert.assertFalse(checker.isConverged(envBuilder, datasetBuilder, notConvergedMdl));
    Assert.assertTrue(checker.isConverged(envBuilder, datasetBuilder, convergedMdl));

    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer),
        TestUtils.testEnvBuilder().buildForTrainer())) {
        double onDSError = checker.computeMeanErrorOnDataset(dataset, notConvergedMdl);
        Assert.assertEquals(1.6, onDSError, 0.01);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
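TestUtils.testEnvBuilder() is a helper from the Ignite test suite. Outside of tests, a deterministic builder can be assembled from the public API already shown on this page; whether this matches the exact configuration of the test helper is an assumption, and the seed below is arbitrary.

// Deterministic environment builder; the seed is chosen arbitrarily for reproducibility.
LearningEnvironmentBuilder envBuilder = LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(2);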
use of org.apache.ignite.ml.environment.LearningEnvironmentBuilder in project ignite by apache.
the class ImputerTrainer method fit.
/**
* {@inheritDoc}
*/
@Override
public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> builder = (env, upstream, upstreamSize) -> new EmptyContext();

    try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(envBuilder, builder,
        (env, upstream, upstreamSize, ctx) -> {
            double[] sums = null;
            int[] counts = null;
            double[] maxs = null;
            double[] mins = null;
            Map<Double, Integer>[] valuesByFreq = null;

            // Accumulate per-partition statistics for the chosen imputing strategy.
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                switch (imputingStgy) {
                    case MEAN:
                        sums = updateTheSums(row, sums);
                        counts = updateTheCounts(row, counts);
                        break;
                    case MOST_FREQUENT:
                        valuesByFreq = updateFrequenciesByGivenRow(row, valuesByFreq);
                        break;
                    case LEAST_FREQUENT:
                        valuesByFreq = updateFrequenciesByGivenRow(row, valuesByFreq);
                        break;
                    case MAX:
                        maxs = updateTheMaxs(row, maxs);
                        break;
                    case MIN:
                        mins = updateTheMins(row, mins);
                        break;
                    case COUNT:
                        counts = updateTheCounts(row, counts);
                        break;
                    default:
                        throw new UnsupportedOperationException("The chosen strategy is not supported");
                }
            }

            // Package the accumulated statistics into partition data.
            ImputerPartitionData partData;
            switch (imputingStgy) {
                case MEAN:
                    partData = new ImputerPartitionData().withSums(sums).withCounts(counts);
                    break;
                case MOST_FREQUENT:
                    partData = new ImputerPartitionData().withValuesByFrequency(valuesByFreq);
                    break;
                case LEAST_FREQUENT:
                    partData = new ImputerPartitionData().withValuesByFrequency(valuesByFreq);
                    break;
                case MAX:
                    partData = new ImputerPartitionData().withMaxs(maxs);
                    break;
                case MIN:
                    partData = new ImputerPartitionData().withMins(mins);
                    break;
                case COUNT:
                    partData = new ImputerPartitionData().withCounts(counts);
                    break;
                default:
                    throw new UnsupportedOperationException("The chosen strategy is not supported");
            }
            return partData;
        }, learningEnvironment(basePreprocessor))) {

        // Reduce the per-partition statistics into the final imputing values.
        Vector imputingValues;
        switch (imputingStgy) {
            case MEAN:
                imputingValues = VectorUtils.of(calculateImputingValuesBySumsAndCounts(dataset));
                break;
            case MOST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheMostFrequentValues(dataset));
                break;
            case LEAST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheLeastFrequentValues(dataset));
                break;
            case MAX:
                imputingValues = VectorUtils.of(calculateImputingValuesByMaxValues(dataset));
                break;
            case MIN:
                imputingValues = VectorUtils.of(calculateImputingValuesByMinValues(dataset));
                break;
            case COUNT:
                imputingValues = VectorUtils.of(calculateImputingValuesByCounts(dataset));
                break;
            default:
                throw new UnsupportedOperationException("The chosen strategy is not supported");
        }
        return new ImputerPreprocessor<>(imputingValues, basePreprocessor);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
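A hedged sketch of how the fit method above is typically invoked. The withImputingStrategy method and the ImputingStrategy.MEAN constant are assumptions based on recent Ignite versions, and the dataset builder and base preprocessor are expected to come from the surrounding code.

// Assumptions: withImputingStrategy(...) and ImputingStrategy.MEAN as in recent Ignite
// versions; datasetBuilder and basePreprocessor are supplied by the caller.
static <K, V> Preprocessor<K, V> buildMeanImputer(DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    return new ImputerTrainer<K, V>()
        .withImputingStrategy(ImputingStrategy.MEAN)
        .fit(LearningEnvironmentBuilder.defaultBuilder(), datasetBuilder, basePreprocessor);
}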