use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class ColumnDecisionTreeTrainerBenchmark method tstF1.
/**
 * Test decision tree regression.
 * To run this test rename this method so it starts from 'test'.
 * <p>
 * Generates training vectors with {@code f1}, loads them into a column-stored sparse distributed matrix,
 * trains a {@link ColumnDecisionTreeTrainer} on that matrix and prints both the training time and the MSE
 * measured on 20 additionally generated vectors.
 */
public void tstF1() {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    final int featCnt = 100;
    final int ptsCnt = 10000;
    // The matrix has one more column than there are features; the evaluation below reads the label at index featCnt.
    final int colsCnt = featCnt + 1;

    // The first three features get a wide range, every other feature falls back to the default range.
    Map<Integer, double[]> ranges = new HashMap<>();
    for (int featIdx = 0; featIdx < 3; featIdx++)
        ranges.put(featIdx, new double[] { -100.0, 100.0 });
    double[] defRng = { -1.0, 1.0 };

    Vector[] trainVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), ptsCnt, f1);

    SparseDistributedMatrix mtx = new SparseDistributedMatrix(ptsCnt, colsCnt, StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
    SparseDistributedMatrixStorage storage = (SparseDistributedMatrixStorage) mtx.getStorage();
    loadVectorsIntoSparseDistributedMatrixCache(storage.cache().getName(), storage.getUUID(), Arrays.stream(trainVectors).iterator(), colsCnt);

    // Region value is the mean of the labels in the region, 0.0 for an empty stream.
    IgniteFunction<DoubleStream, Double> meanCalc = labels -> labels.average().orElse(0.0);
    ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, meanCalc, ignite);

    X.println("Training started.");
    long startTs = System.currentTimeMillis();
    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(mtx, new HashMap<>()));
    X.println("Training finished in: " + (System.currentTimeMillis() - startTs) + " ms.");

    // NOTE(review): the same seed (123L) is used as for the training set, so these vectors may coincide
    // with training points - depends on how vecsFromRanges consumes the generator; confirm.
    Vector[] testVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), 20, f1);

    // Split every test vector into (features, label) before handing it to the estimator.
    Stream<IgniteBiTuple<Vector, Double>> labeledTestData = Arrays.stream(testVectors)
        .map(vec -> new IgniteBiTuple<>(vec.viewPart(0, featCnt), vec.getX(featCnt)));

    IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mseEstimator = Estimators.MSE();
    Double mseVal = mseEstimator.apply(mdl, labeledTestData, Function.identity());
    X.println("MSE: " + mseVal);
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class GDBOnTreesRegressionExportImportExample method main.
/**
 * Run example.
 * <p>
 * Trains a GDB regression model on a cache filled by {@code fillTrainingData}, exports it to a temporary
 * JSON file, imports it back and runs {@code predictOnGeneratedData} on both the trained and the imported model.
 * The training cache and the temporary file are removed before the method returns.
 *
 * @param args Command line arguments, none required.
 * @throws IOException If the temporary model file cannot be created or deleted, or if the JSON export/import fails.
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> GDB regression trainer example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        // Create cache with training data.
        CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();
        IgniteCache<Integer, double[]> trainingSet = null;
        Path jsonMdlPath = null;
        try {
            trainingSet = fillTrainingData(ignite, trainingSetCfg);
            // Create regression trainer.
            GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.).withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));
            // Train GDB regression model; the label is the last coordinate of every row.
            GDBModel mdl = trainer.fit(ignite, trainingSet, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));
            System.out.println("\n>>> Exported GDB regression model: " + mdl.toString(true));
            predictOnGeneratedData(mdl);
            // Export the model to a temporary JSON file and read it back.
            jsonMdlPath = Files.createTempFile(null, null);
            mdl.toJSON(jsonMdlPath);
            IgniteFunction<Double, Double> lbMapper = lb -> lb;
            GDBModel modelImportedFromJSON = GDBModel.fromJSON(jsonMdlPath).withLblMapping(lbMapper);
            System.out.println("\n>>> Imported GDB regression model: " + modelImportedFromJSON.toString(true));
            predictOnGeneratedData(modelImportedFromJSON);
            System.out.println(">>> GDB regression trainer example completed.");
        } finally {
            // Nested try/finally: the temporary file must be removed even if destroying the cache throws.
            try {
                if (trainingSet != null)
                    trainingSet.destroy();
            } finally {
                if (jsonMdlPath != null)
                    Files.deleteIfExists(jsonMdlPath);
            }
        }
    } finally {
        System.out.flush();
    }
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class LogisticRegressionSGDTrainer method updateModel.
/**
 * {@inheritDoc}
 */
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    // Derives the perceptron architecture from the dataset: input size = number of feature columns,
    // followed by a single sigmoid output neuron.
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        Integer featureCols = dataset.compute(
            data -> {
                if (data.getFeatures() != null)
                    return data.getFeatures().length / data.getRows();
                return null;
            },
            // Take the first non-null value; if both are null then null (not zero) is propagated on purpose.
            (left, right) -> left == null ? right : left
        );

        if (featureCols == null)
            throw new IllegalStateException("Cannot train on empty dataset");

        return new MLPArchitecture(featureCols).withAddedLayer(1, true, Activators.SIGMOID);
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.L2, updatesStgy, maxIterations, batchSize, locIterations, seed).withEnvironmentBuilder(envBuilder);

    // The perceptron trainer works with array labels, so every scalar label is wrapped into a one-element array.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> lblToArr = lv -> new LabeledVector<>(lv.features(), new double[] { lv.label() });
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(lblToArr, extractor);

    final MultilayerPerceptron mlp;
    if (mdl == null)
        mlp = trainer.fit(datasetBuilder, patchedPreprocessor);
    else
        mlp = trainer.update(restoreMLPState(mdl), datasetBuilder, patchedPreprocessor);

    // All parameters but the last one form the weights vector; the last one is passed separately.
    double[] params = mlp.parameters().getStorage().data();
    int lastIdx = params.length - 1;
    return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, lastIdx)), params[lastIdx]);
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class Deltas method updateModel.
/**
 * {@inheritDoc}
 */
@Override
protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    // Label 0.0 is remapped to -1.0, every other label is kept as is.
    IgniteFunction<Double, Double> lbTransformer = lb -> lb == 0.0 ? Double.valueOf(-1.0) : lb;
    IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> relabel = lv -> new LabeledVector<>(lv.features(), lbTransformer.apply(lv.label()));

    PatchedPreprocessor<K, V, Double, Double> patchedPreprocessor = new PatchedPreprocessor<>(relabel, preprocessor);
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(patchedPreprocessor);

    Vector weights;
    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder, learningEnvironment())) {
        if (mdl != null)
            weights = getStateVector(mdl);
        else {
            // Reduce the column sizes of the partitions: the right operand wins when present,
            // otherwise the left one, and 0 when both are missing.
            final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (left, right) -> {
                if (right != null)
                    return right;
                return left == null ? 0 : left;
            });
            // One extra slot for the intercept.
            weights = initializeWeightsWithZeros(cols + 1);
        }

        for (int iter = 0; iter < getAmountOfIterations(); iter++) {
            Vector deltaWeights = calculateUpdates(weights, dataset);
            if (deltaWeights == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
            // plus() creates a new vector
            weights = weights.plus(deltaWeights);
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // Element 0 goes in as the second constructor argument, the remaining elements as the weights vector.
    Vector featureWeights = weights.copyOfRange(1, weights.size());
    return new SVMLinearClassificationModel(featureWeights, weights.get(0));
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class DataStreamGeneratorFillCacheTest method testCacheFilling.
/**
 * Fills a cache from a data stream generator and checks, through a dataset built over that cache,
 * that the number of rows equals the requested size and that the mean of all feature values is close to zero.
 */
@Test
public void testCacheFilling() {
    final String cacheName = "TEST_CACHE";
    final int datasetSize = 5000;

    TcpDiscoveryVmIpFinder ipFinder = new TcpDiscoveryVmIpFinder();
    ipFinder.setAddresses(Arrays.asList("127.0.0.1:47500..47509"));
    TcpDiscoverySpi discoverySpi = new TcpDiscoverySpi();
    discoverySpi.setIpFinder(ipFinder);
    IgniteConfiguration configuration = new IgniteConfiguration();
    configuration.setDiscoverySpi(discoverySpi);

    CacheConfiguration<UUID, LabeledVector<Double>> cacheConfiguration = new CacheConfiguration<>(cacheName);
    cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));

    try (Ignite ignite = Ignition.start(configuration)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheConfiguration);

        DataStreamGenerator generator = new GaussRandomProducer(0).vectorize(1).asDataStream();
        generator.fillCacheWithVecUUIDAsKey(datasetSize, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder = new CacheBasedDatasetBuilder<>(ignite, cache);

        // Per-partition statistic: sum of all feature values together with the number of rows.
        IgniteFunction<SimpleDatasetData, StatPair> toStat = data -> {
            double featuresSum = DoubleStream.of(data.getFeatures()).sum();
            return new StatPair(featuresSum, data.getRows());
        };

        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        env.deployingContext().initByClientObject(toStat);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset = datasetBuilder.build(LearningEnvironmentBuilder.defaultBuilder(), new EmptyContextBuilder<>(), new SimpleDatasetDataBuilder<>(vectorizer), env)) {
            StatPair total = dataset.compute(toStat, StatPair::sum);
            assertEquals(datasetSize, total.cntOfRows);
            // Mean of the generated values is expected to be near zero.
            assertEquals(0.0, total.elementsSum / total.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Aggregations