Search in sources:

Example 6 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in the Apache Ignite project.

Class OneVsRestTrainer, method extractClassLabels.

/**
 * Walks over the dataset partitions and gathers the distinct class labels found in them.
 *
 * @param datasetBuilder Dataset builder; asserted to be non-null.
 * @param preprocessor Preprocessor handed to the label partition data builder.
 * @return List holding each distinct label once; empty when the computation yields {@code null}.
 */
private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> labelDataBuilder = new LabelPartitionDataBuilderOnHeap<>(preprocessor);

    try (Dataset<EmptyContext, LabelPartitionDataOnHeap> ds = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), labelDataBuilder, learningEnvironment())) {
        // Map step: distinct labels of a single partition.
        // Reduce step: union of two partial results, tolerating null on either side.
        Set<Double> distinct = ds.compute(part -> {
            Set<Double> found = new HashSet<>();

            for (double lb : part.getY())
                found.add(lb);

            return found;
        }, (left, right) -> {
            if (left != null && right != null)
                return Stream.concat(left.stream(), right.stream()).collect(Collectors.toSet());

            if (left != null)
                return left;

            return right != null ? right : new HashSet<>();
        });

        List<Double> labels = new ArrayList<>();

        if (distinct != null)
            labels.addAll(distinct);

        return labels;
    } catch (Exception e) {
        // Any failure (including one raised while closing the dataset) is rethrown unchecked.
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) ArrayList(java.util.ArrayList) LabelPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap) LabelPartitionDataOnHeap(org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap) HashSet(java.util.HashSet)

Example 7 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in the Apache Ignite project.

Class ANNClassificationTrainer, method getCentroidStat.

/**
 * Gathers label statistics per centroid across the whole dataset.
 * <p>
 * Inside each partition every row is matched to its closest centroid via {@code findClosestCentroid};
 * the row's label is added to the label set, the (centroid, label) counter is incremented and the
 * per-centroid row counter is incremented. Partition results are combined with {@code CentroidStat.merge}.
 *
 * @param datasetBuilder Dataset builder.
 * @param vectorizer Preprocessor handed to the labeled partition data builder.
 * @param centers Centroids the rows are matched against.
 * @return Statistics combined over all partitions.
 */
private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer, List<Vector> centers) {
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> labeledDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> ds = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), labeledDataBuilder, learningEnvironment())) {
        return ds.compute(part -> {
            CentroidStat partStat = new CentroidStat();

            for (int row = 0; row < part.rowSize(); row++) {
                final IgniteBiTuple<Integer, Double> nearest = findClosestCentroid(centers, part.getRow(row));
                int nearestIdx = nearest.get1();
                double rowLb = part.label(row);

                // Remember that this label occurs in the data.
                partStat.labels().add(rowLb);

                ConcurrentHashMap<Double, Integer> lbCounters = partStat.centroidStat.get(nearestIdx);

                if (lbCounters != null) {
                    // Centroid already seen: bump the counter of this label.
                    lbCounters.put(rowLb, lbCounters.getOrDefault(rowLb, 0) + 1);
                } else {
                    // First row assigned to this centroid: start its label counters.
                    lbCounters = new ConcurrentHashMap<>();
                    lbCounters.put(rowLb, 1);
                    partStat.centroidStat.put(nearestIdx, lbCounters);
                }

                // Count one more row for the centroid itself.
                partStat.counts.merge(nearestIdx, 1, (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
            }

            return partStat;
        }, (left, right) -> {
            if (left != null && right != null)
                return left.merge(right);

            if (left != null)
                return left;

            return right != null ? right : new CentroidStat();
        });
    } catch (Exception e) {
        // Any failure is rethrown unchecked with the original exception as the cause.
        throw new RuntimeException(e);
    }
}
Also used : Arrays(java.util.Arrays) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) KMeansTrainer(org.apache.ignite.ml.clustering.kmeans.KMeansTrainer) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) JsonIgnore(com.fasterxml.jackson.annotation.JsonIgnore) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) DistanceMeasure(org.apache.ignite.ml.math.distances.DistanceMeasure) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) MapUtil(org.apache.ignite.ml.math.util.MapUtil) Serializable(java.io.Serializable) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) List(java.util.List) TreeMap(java.util.TreeMap) ConcurrentSkipListSet(java.util.concurrent.ConcurrentSkipListSet) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) KMeansModel(org.apache.ignite.ml.clustering.kmeans.KMeansModel) Dataset(org.apache.ignite.ml.dataset.Dataset) NotNull(org.jetbrains.annotations.NotNull) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet)

Example 8 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in the Apache Ignite project.

Class KNNUtils, method buildDataset.

/**
 * Creates a dataset of labeled vectors with an empty partition context.
 * <p>
 * The trainer learning environment is built and its deploying context is initialized with the
 * vectorizer before the dataset builder is checked, so this happens even when no builder is given.
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Dataset builder; may be {@code null}.
 * @param vectorizer Upstream vectorizer.
 * @return Built dataset, or {@code null} if {@code datasetBuilder} is {@code null}.
 */
@Nullable
public static <K, V, C extends Serializable> Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> buildDataset(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment trainerEnv = envBuilder.buildForTrainer();
    trainerEnv.initDeployingContext(vectorizer);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> labeledDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    // Without a dataset builder there is nothing to build.
    if (datasetBuilder == null)
        return null;

    return datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), labeledDataBuilder, trainerEnv);
}
Also used : LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) Nullable(org.jetbrains.annotations.Nullable) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) Dataset(org.apache.ignite.ml.dataset.Dataset) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) Serializable(java.io.Serializable) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) Nullable(org.jetbrains.annotations.Nullable)

Example 9 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in the Apache Ignite project.

Class MeanAbsValueConvergenceCheckerTest, method testConvergenceChecking.

/**
 * Checks a convergence checker produced by {@code MeanAbsValueConvergenceCheckerFactory(0.1)}:
 * the single-point error, the convergence verdict for both test models, and the mean error
 * computed over a locally built dataset.
 */
@Test
public void testConvergenceChecking() {
    LocalDatasetBuilder<Integer, LabeledVector<Double>> locBuilder = new LocalDatasetBuilder<>(data, 1);
    ConvergenceChecker<Integer, LabeledVector<Double>> convChecker = createChecker(new MeanAbsValueConvergenceCheckerFactory(0.1), locBuilder);

    // Error for a single feature vector and label.
    double pointErr = convChecker.computeError(VectorUtils.of(1, 2), 4.0, notConvergedMdl);

    LearningEnvironmentBuilder testEnvBuilder = TestUtils.testEnvBuilder();

    Assert.assertEquals(1.9, pointErr, 0.01);
    Assert.assertFalse(convChecker.isConverged(testEnvBuilder, locBuilder, notConvergedMdl));
    Assert.assertTrue(convChecker.isConverged(testEnvBuilder, locBuilder, convergedMdl));

    // Mean error over the whole dataset; the dataset is closed when the block exits.
    try (LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> locDataset = locBuilder.build(testEnvBuilder, new EmptyContextBuilder<>(), new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer), testEnvBuilder.buildForTrainer())) {
        Assert.assertEquals(1.55, convChecker.computeMeanErrorOnDataset(locDataset, notConvergedMdl), 0.01);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) FeatureMatrixWithLabelsOnHeapData(org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LocalDatasetBuilder(org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder) ConvergenceCheckerTest(org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest) Test(org.junit.Test)

Example 10 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in the Apache Ignite project.

Class LinearRegressionSGDTrainer, method updateModel.

/**
 * {@inheritDoc}
 * <p>
 * Training is delegated to an {@code MLPTrainer} whose architecture has a single added layer of
 * one neuron with linear activation. When {@code mdl} is non-null (and its restored perceptron
 * state and the update result are non-null) the perceptron is updated; otherwise it is fitted
 * from scratch. The resulting parameter vector is split into weights (all but the last element)
 * and intercept (the last element).
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    assert updatesStgy != null;

    // Derives the input layer size from the data: features length divided by row count.
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        int inputSize = dataset.compute(part -> {
            if (part.getFeatures() != null)
                return part.getFeatures().length / part.getRows();

            return null;
        }, (left, right) -> {
            // Prefer the right-hand value whenever it is present; fall back to the left one, then to 0.
            if (right != null)
                return right;

            return left != null ? left : 0;
        });

        return new MLPArchitecture(inputSize).withAddedLayer(1, true, Activators.LINEAR);
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.MSE, updatesStgy, maxIterations, batchSize, locIterations, seed);

    // The perceptron trainer expects array labels, so each scalar label is wrapped into a one-element array.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> lbWrapper = lv -> new LabeledVector<>(lv.features(), new double[] { lv.label() });
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(lbWrapper, extractor);

    MultilayerPerceptron mlp = Optional.ofNullable(mdl)
        .map(this::restoreMLPState)
        .map(m -> trainer.update(m, datasetBuilder, patchedPreprocessor))
        .orElseGet(() -> trainer.fit(datasetBuilder, patchedPreprocessor));

    double[] params = mlp.parameters().getStorage().data();
    int last = params.length - 1;

    return new LinearRegressionModel(new DenseVector(Arrays.copyOf(params, last)), params[last]);
}
Also used : Arrays(java.util.Arrays) Activators(org.apache.ignite.ml.nn.Activators) UpdatesStrategy(org.apache.ignite.ml.nn.UpdatesStrategy) SimpleLabeledDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) Serializable(java.io.Serializable) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) Optional(java.util.Optional) LossFunctions(org.apache.ignite.ml.optimization.LossFunctions) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) NotNull(org.jetbrains.annotations.NotNull) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) MLPTrainer(org.apache.ignite.ml.nn.MLPTrainer) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) Dataset(org.apache.ignite.ml.dataset.Dataset) MLPTrainer(org.apache.ignite.ml.nn.MLPTrainer) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Aggregations

EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext) 23, LabeledVector (org.apache.ignite.ml.structures.LabeledVector) 16, Dataset (org.apache.ignite.ml.dataset.Dataset) 12, Vector (org.apache.ignite.ml.math.primitives.vector.Vector) 12, DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder) 11, Preprocessor (org.apache.ignite.ml.preprocessing.Preprocessor) 11, Arrays (java.util.Arrays) 9, LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder) 9, UpstreamEntry (org.apache.ignite.ml.dataset.UpstreamEntry) 6, SingleLabelDatasetTrainer (org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) 6, LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment) 5, ArrayList (java.util.ArrayList) 4, Map (java.util.Map) 4, PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder) 4, FeatureMatrixWithLabelsOnHeapData (org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData) 4, DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) 4, NotNull (org.jetbrains.annotations.NotNull) 4, Serializable (java.io.Serializable) 3, List (java.util.List) 3, Optional (java.util.Optional) 3