Usage of `org.apache.ignite.ml.dataset.primitive.context.EmptyContext` in the Apache Ignite project.
Class `LogisticRegressionSGDTrainer`, method `updateModel`:
/**
 * {@inheritDoc}
 *
 * <p>The model is trained as a perceptron with a single sigmoid output neuron on top of the input features.
 * After training, the last perceptron parameter becomes the intercept and all preceding parameters become
 * the weights of the returned {@link LogisticRegressionModel}.
 *
 * @throws IllegalStateException If no partition of the dataset contains any rows.
 */
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        Integer cols = dataset.compute(data -> {
            // A partition without features or without rows says nothing about the column count; report
            // "unknown" (null) instead of dividing by a zero row count.
            if (data.getFeatures() == null || data.getRows() == 0)
                return null;
            // getFeatures() is treated as a flat array of rows * cols values.
            return data.getFeatures().length / data.getRows();
        }, (a, b) -> {
            // Keep any known column count; the result stays null only if every partition was empty,
            // which is reported right below.
            return a == null ? b : a;
        });

        if (cols == null)
            throw new IllegalStateException("Cannot train on empty dataset");

        // cols inputs -> one sigmoid output neuron (the boolean flag presumably enables a bias — TODO confirm).
        MLPArchitecture architecture = new MLPArchitecture(cols);
        return architecture.withAddedLayer(1, true, Activators.SIGMOID);
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.L2, updatesStgy, maxIterations, batchSize, locIterations, seed).withEnvironmentBuilder(envBuilder);

    // The perceptron trainer works with array-valued labels: wrap the scalar label into a one-element array.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> func = lv -> new LabeledVector<>(lv.features(), new double[] { lv.label() });
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(func, extractor);

    MultilayerPerceptron mlp;
    if (mdl != null) {
        // Continue training from the state of the given model.
        mlp = restoreMLPState(mdl);
        mlp = trainer.update(mlp, datasetBuilder, patchedPreprocessor);
    }
    else
        mlp = trainer.fit(datasetBuilder, patchedPreprocessor);

    // Parameter layout: weights first, intercept in the last position.
    double[] params = mlp.parameters().getStorage().data();
    return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, params.length - 1)), params[params.length - 1]);
}
Usage of `org.apache.ignite.ml.dataset.primitive.context.EmptyContext` in the Apache Ignite project.
Class `Deltas`, method `updateModel`:
/**
 * {@inheritDoc}
 *
 * <p>Before training, labels equal to {@code 0.0} are rewritten to {@code -1.0}; all other labels are kept.
 * The weight vector keeps the intercept at index 0 followed by the feature coefficients. It starts from
 * the state of {@code mdl} when a model is given, or from zeros otherwise. It is then refined for
 * {@code getAmountOfIterations()} rounds of additive updates.
 */
@Override
protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    // Zero labels become -1.0; any other label passes through untouched.
    IgniteFunction<Double, Double> lbTransformer = lb -> lb == 0.0 ? -1.0 : lb;
    IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> relabel =
        lv -> new LabeledVector<>(lv.features(), lbTransformer.apply(lv.label()));

    PatchedPreprocessor<K, V, Double, Double> relabelingPreprocessor = new PatchedPreprocessor<>(relabel, preprocessor);
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(relabelingPreprocessor);

    Vector weights;

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder, learningEnvironment())) {
        if (mdl != null)
            weights = getStateVector(mdl);
        else {
            // Keep the latest non-null column count reported by a partition; fall back to 0 if none reported one.
            int featureCnt = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
                if (b != null)
                    return b;
                return a == null ? 0 : a;
            });

            // One extra slot at index 0 for the intercept.
            weights = initializeWeightsWithZeros(featureCnt + 1);
        }

        for (int iter = 0; iter < getAmountOfIterations(); iter++) {
            Vector delta = calculateUpdates(weights, dataset);

            if (delta == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

            // plus() yields a new vector rather than updating weights in place.
            weights = weights.plus(delta);
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    double intercept = weights.get(0);
    Vector coefficients = weights.copyOfRange(1, weights.size());

    return new SVMLinearClassificationModel(coefficients, intercept);
}
Usage of `org.apache.ignite.ml.dataset.primitive.context.EmptyContext` in the Apache Ignite project.
Class `GDBOnTreesLearningStrategy`, method `update`:
/**
* {@inheritDoc}
*
* <p>Gradient boosting loop over decision trees. Starting from the models returned by
* {@code initLearningState(mdlToUpdate)}, each iteration does the following:
* <ul>
* <li>builds the current weighted composition;</li>
* <li>stops early if the convergence checker reports convergence;</li>
* <li>otherwise overwrites the partition labels with the negative loss gradient of the composition's
* prediction and fits one more decision tree on that dataset.</li>
* </ul>
* At most {@code cntOfIterations} trees are added.
*
* <p>NOTE(review): the trainer supplied by {@code baseMdlTrainerBuilder} must be a {@link DecisionTreeTrainer}.
* This is checked only by an {@code assert} (active with {@code -ea}); otherwise the cast throws
* {@code ClassCastException}.
*
* @param mdlToUpdate Model whose state seeds the learning state (null handling is up to
* {@code initLearningState} — TODO confirm).
* @param datasetBuilder Dataset builder.
* @param vectorizer Upstream preprocessor, used for both the training dataset and the convergence checker.
* @return Models of the composition: the initial ones plus one tree per completed iteration.
* @throws RuntimeException Wrapping any exception thrown while building, using or closing the dataset.
*/
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
LearningEnvironment environment = envBuilder.buildForTrainer();
environment.initDeployingContext(vectorizer);
DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
assert trainer instanceof DecisionTreeTrainer;
DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer) trainer;
List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);
try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
for (int i = 0; i < cntOfIterations; i++) {
// Align the weights with the current number of models (Arrays.copyOf pads missing entries with 0.0).
double[] weights = Arrays.copyOf(compositionWeights, models.size());
WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
ModelsComposition currComposition = new ModelsComposition(models, aggregator);
if (convCheck.isConverged(dataset, currComposition))
break;
dataset.compute(part -> {
// On the first pass, back up the original labels. They are overwritten below, but every
// iteration must compute the gradient against the original values.
if (part.getCopiedOriginalLabels() == null)
part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));
for (int j = 0; j < part.getLabels().length; j++) {
double mdlAnswer = currComposition.predict(VectorUtils.of(part.getFeatures()[j]));
double originalLbVal = externalLbToInternalMapping.apply(part.getCopiedOriginalLabels()[j]);
// The next tree is fit to the negative loss gradient instead of the raw label.
part.getLabels()[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
}
});
long startTs = System.currentTimeMillis();
// Fits on the same dataset whose labels were just rewritten above.
models.add(decisionTreeTrainer.fit(dataset));
double learningTime = (double) (System.currentTimeMillis() - startTs) / 1000.0;
// NOTE(review): logs through the trainerEnvironment field, not the local `environment` built above — confirm intended.
trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
// Persist the weights resized to the final number of models (new entries padded with 0.0).
compositionWeights = Arrays.copyOf(compositionWeights, models.size());
return models;
}
Usage of `org.apache.ignite.ml.dataset.primitive.context.EmptyContext` in the Apache Ignite project.
Class `DataStreamGeneratorFillCacheTest`, method `testCacheFilling`:
/**
 * Fills a cache from a Gaussian random data stream and reads it back through a cache-based dataset.
 * Checks that the dataset sees exactly the requested number of rows and that the mean of all generated
 * values is close to zero.
 */
@Test
public void testCacheFilling() {
    TcpDiscoveryVmIpFinder ipFinder = new TcpDiscoveryVmIpFinder();
    ipFinder.setAddresses(Arrays.asList("127.0.0.1:47500..47509"));
    IgniteConfiguration igniteCfg = new IgniteConfiguration().setDiscoverySpi(new TcpDiscoverySpi().setIpFinder(ipFinder));

    String cacheNm = "TEST_CACHE";
    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheNm).setAffinity(new RendezvousAffinityFunction(false, 10));

    int expRowsCnt = 5000;

    try (Ignite ignite = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheCfg);

        DataStreamGenerator gen = new GaussRandomProducer(0).vectorize(1).asDataStream();
        gen.fillCacheWithVecUUIDAsKey(expRowsCnt, cache);

        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> builder = new CacheBasedDatasetBuilder<>(ignite, cache);
        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();

        // Per partition: the sum of all feature values together with the number of rows.
        IgniteFunction<SimpleDatasetData, StatPair> sumAndCnt =
            data -> new StatPair(DoubleStream.of(data.getFeatures()).sum(), data.getRows());

        LearningEnvironment trainerEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        trainerEnv.deployingContext().initByClientObject(sumAndCnt);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset = builder.build(
            LearningEnvironmentBuilder.defaultBuilder(), new EmptyContextBuilder<>(), new SimpleDatasetDataBuilder<>(vectorizer), trainerEnv)) {
            StatPair total = dataset.compute(sumAndCnt, StatPair::sum);

            assertEquals(expRowsCnt, total.cntOfRows);
            // The mean over all generated values should be close to zero.
            assertEquals(0.0, total.elementsSum / total.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheNm);
    }
}
Usage of `org.apache.ignite.ml.dataset.primitive.context.EmptyContext` in the Apache Ignite project.
Class `ConvergenceChecker`, method `isConverged`:
/**
 * Checks convergence of the given model composition on a dataset built from {@code datasetBuilder}.
 * The dataset exists only for the duration of the check and is closed afterwards.
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Dataset builder used to build the dataset the check runs on.
 * @param currMdl Current model.
 * @return True if GDB is converged.
 * @throws RuntimeException Wrapping any exception thrown while building, checking or closing the dataset.
 */
public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(preprocessor);

    try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(preprocessor), environment)) {
        return isConverged(dataset, currMdl);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Aggregations