Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
From the class LocalDataset, method compute.
/**
 * {@inheritDoc}
 *
 * Sequentially folds the map results of all non-empty local partitions with the given reducer.
 */
@Override
public <R> R compute(IgniteBiFunction<D, LearningEnvironment, R> map, IgniteBinaryOperator<R> reduce, R identity) {
    R acc = identity;

    for (int partIdx = 0; partIdx < data.size(); partIdx++) {
        D partData = data.get(partIdx);

        // Partitions without data are skipped; only populated partitions contribute to the fold.
        if (partData != null)
            acc = reduce.apply(acc, map.apply(partData, envs.get(partIdx)));
    }

    return acc;
}
Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
From the class LocalDatasetBuilder, method build.
/**
 * {@inheritDoc}
 *
 * Splits the locally filtered upstream entries into {@code partitions} contiguous chunks,
 * runs the configured upstream transformer over each chunk, and builds one partition context
 * and one partition data object per chunk.
 */
@Override
public <C extends Serializable, D extends AutoCloseable> LocalDataset<C, D> build(LearningEnvironmentBuilder envBuilder, PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder, LearningEnvironment learningEnvironment) {
List<C> ctxList = new ArrayList<>();
List<D> dataList = new ArrayList<>();
// Snapshot of all upstream entries that pass the configured filter, in map iteration order.
List<UpstreamEntry<K, V>> entriesList = new ArrayList<>();
upstreamMap.entrySet().stream().filter(en -> filter.apply(en.getKey(), en.getValue())).map(en -> new UpstreamEntry<>(en.getKey(), en.getValue())).forEach(entriesList::add);
// Base chunk size; at least 1 so the loop below always advances even with fewer entries than partitions.
int partSize = Math.max(1, entriesList.size() / partitions);
// Three independent iterators over the same snapshot — one per pass below (count, context build, data build).
Iterator<UpstreamEntry<K, V>> firstKeysIter = entriesList.iterator();
Iterator<UpstreamEntry<K, V>> secondKeysIter = entriesList.iterator();
Iterator<UpstreamEntry<K, V>> thirdKeysIter = entriesList.iterator();
// Number of raw (pre-transform) entries consumed so far.
int ptr = 0;
// One learning environment per partition, produced by the supplied builder.
List<LearningEnvironment> envs = IntStream.range(0, partitions).boxed().map(envBuilder::buildForWorker).collect(Collectors.toList());
for (int part = 0; part < partitions; part++) {
// The last partition absorbs any remainder; earlier partitions take at most partSize entries.
int cntBeforeTransform = part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
LearningEnvironment env = envs.get(part);
// The transformer is copied so each of the three passes observes an identical transformed
// stream (e.g. the same sampling decisions) — presumably Utils.copy is a deep copy; verify.
UpstreamTransformer transformer1 = upstreamTransformerBuilder.build(env);
UpstreamTransformer transformer2 = Utils.copy(transformer1);
UpstreamTransformer transformer3 = Utils.copy(transformer1);
// Pass 1: count the entries AFTER transformation (the transform may drop or duplicate entries).
int cnt = (int) transformer1.transform(Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count();
// Pass 2: feed the transformed window to the partition context builder.
Iterator<UpstreamEntry> iter = transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform)).map(x -> (UpstreamEntry) x)).iterator();
Iterator<UpstreamEntry<K, V>> convertedBack = Utils.asStream(iter).map(x -> (UpstreamEntry<K, V>) x).iterator();
// Empty chunks yield a null context/data slot rather than invoking the builders.
C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, convertedBack, cnt) : null;
// Pass 3: feed the transformed window to the partition data builder.
Iterator<UpstreamEntry> iter1 = transformer3.transform(Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator();
Iterator<UpstreamEntry<K, V>> convertedBack1 = Utils.asStream(iter1).map(x -> (UpstreamEntry<K, V>) x).iterator();
D data = cntBeforeTransform > 0 ? partDataBuilder.build(env, convertedBack1, cnt, ctx) : null;
ctxList.add(ctx);
dataList.add(data);
// Advance past the raw entries consumed by this partition.
ptr += cntBeforeTransform;
}
return new LocalDataset<>(envs, ctxList, dataList);
}
Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
From the class GDBOnTreesLearningStrategy, method update.
/**
 * {@inheritDoc}
 *
 * Gradient boosting on decision trees: on each iteration the dataset labels are replaced with
 * the negated loss gradients of the current model composition, and a new tree is fit to those
 * pseudo-labels. Stops early when the convergence checker reports convergence.
 */
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
LearningEnvironment environment = envBuilder.buildForTrainer();
environment.initDeployingContext(vectorizer);
DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
// This strategy only supports decision-tree base learners.
assert trainer instanceof DecisionTreeTrainer;
DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer) trainer;
// Seed the ensemble from the model being updated (may be empty when training from scratch).
List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);
try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
for (int i = 0; i < cntOfIterations; i++) {
// Weights for the trees trained so far; compositionWeights may be longer than models.
double[] weights = Arrays.copyOf(compositionWeights, models.size());
WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
ModelsComposition currComposition = new ModelsComposition(models, aggregator);
if (convCheck.isConverged(dataset, currComposition))
break;
// Rewrite each partition's labels with the negative loss gradient w.r.t. the current
// composition's prediction; the original labels are preserved in a one-time copy.
dataset.compute(part -> {
if (part.getCopiedOriginalLabels() == null)
part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
for (int j = 0; j < part.getLabels().length; j++) {
double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
}
});
long startTs = System.currentTimeMillis();
models.add(decisionTreeTrainer.fit(dataset));
double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
}
} catch (Exception e) {
// Dataset#close may throw a checked exception; surface everything as unchecked.
throw new RuntimeException(e);
}
// Trim (or keep) the weight array to match the final number of trained models.
compositionWeights = Arrays.copyOf(compositionWeights, models.size());
return models;
}
Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
From the class DataStreamGeneratorFillCacheTest, method testCacheFilling.
/**
 * Starts a local Ignite node, fills a cache with generated labeled Gaussian vectors and checks
 * that a dataset built on top of that cache sees the expected row count and a near-zero mean.
 */
@Test
public void testCacheFilling() {
    IgniteConfiguration cfg = new IgniteConfiguration()
        .setDiscoverySpi(new TcpDiscoverySpi()
            .setIpFinder(new TcpDiscoveryVmIpFinder()
                .setAddresses(Arrays.asList("127.0.0.1:47500..47509"))));

    String cacheName = "TEST_CACHE";
    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));
    int size = 5000;

    try (Ignite ignite = Ignition.start(cfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheCfg);

        // Stream `size` one-dimensional samples from a zero-mean Gaussian into the cache.
        DataStreamGenerator gen = new GaussRandomProducer(0).vectorize(1).asDataStream();
        gen.fillCacheWithVecUUIDAsKey(size, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> builder =
            new CacheBasedDatasetBuilder<>(ignite, cache);

        // Aggregates the feature sum and row count of a partition.
        IgniteFunction<SimpleDatasetData, StatPair> map =
            data -> new StatPair(DoubleStream.of(data.getFeatures()).sum(), data.getRows());

        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        env.deployingContext().initByClientObject(map);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
            builder.build(LearningEnvironmentBuilder.defaultBuilder(), new EmptyContextBuilder<>(),
                new SimpleDatasetDataBuilder<>(vectorizer), env)) {
            StatPair res = dataset.compute(map, StatPair::sum);

            assertEquals(size, res.cntOfRows);
            assertEquals(0.0, res.elementsSum / res.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Usage of org.apache.ignite.ml.environment.LearningEnvironment in the Apache Ignite project.
From the class SparkModelParser, method parse.
/**
 * Loads a Spark model from parquet (presented as a directory).
 *
 * @param pathToMdl Path to directory with saved model.
 * @param parsedSparkMdl Parsed spark model.
 * @return Extracted model.
 */
public static Model parse(String pathToMdl, SupportedSparkModels parsedSparkMdl) throws IllegalArgumentException {
    // A default trainer-side learning environment is sufficient for extraction.
    LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();

    return extractModel(pathToMdl, parsedSparkMdl, env);
}
Aggregations