Example 1 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

The fit method of the EncoderTrainer class.

/**
 * {@inheritDoc}
 */
@Override
public EncoderPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    if (handledIndices.isEmpty() && encoderType != EncoderType.LABEL_ENCODER)
        throw new RuntimeException("Add indices of handled features");
    try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), (env, upstream, upstreamSize, ctx) -> {
        EncoderPartitionData partData = new EncoderPartitionData();
        if (encoderType == EncoderType.LABEL_ENCODER) {
            Map<String, Integer> lbFrequencies = null;
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                lbFrequencies = updateLabelFrequenciesForNextRow(row, lbFrequencies);
            }
            partData.withLabelFrequencies(lbFrequencies);
        } else if (encoderType == EncoderType.TARGET_ENCODER) {
            TargetCounter[] targetCounter = null;
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                targetCounter = updateTargetCountersForNextRow(row, targetCounter);
            }
            partData.withTargetCounters(targetCounter);
        } else {
            // This array will contain non-null values only for handled indices
            Map<String, Integer>[] categoryFrequencies = null;
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                categoryFrequencies = updateFeatureFrequenciesForNextRow(row, categoryFrequencies);
            }
            partData.withCategoryFrequencies(categoryFrequencies);
        }
        return partData;
    }, learningEnvironment(basePreprocessor))) {
        switch(encoderType) {
            case ONE_HOT_ENCODER:
                return new OneHotEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);
            case STRING_ENCODER:
                return new StringEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);
            case LABEL_ENCODER:
                return new LabelEncoderPreprocessor<>(calculateEncodingValuesForLabelsByFrequencies(dataset), basePreprocessor);
            case FREQUENCY_ENCODER:
                return new FrequencyEncoderPreprocessor<>(calculateEncodingFrequencies(dataset), basePreprocessor, handledIndices);
            case TARGET_ENCODER:
                return new TargetEncoderPreprocessor<>(calculateTargetEncodingFrequencies(dataset), basePreprocessor, handledIndices);
            default:
                throw new IllegalStateException("Define the type of the resulting preprocessor.");
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) StringEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) TargetEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.target.TargetEncoderPreprocessor) FrequencyEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.frequency.FrequencyEncoderPreprocessor) UndefinedLabelException(org.apache.ignite.ml.math.exceptions.preprocessing.UndefinedLabelException) OneHotEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor) LabelEncoderPreprocessor(org.apache.ignite.ml.preprocessing.encoding.label.LabelEncoderPreprocessor) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry)
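
For orientation, here is a minimal caller-side sketch of how this trainer is typically used. It is not part of the example above: the cache of Object[] rows, the vectorizer coordinates and the encoded feature indices are illustrative assumptions (label in coordinate 0, two categorical features in coordinates 1 and 2).

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.impl.ObjectArrayVectorizer;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
import org.apache.ignite.ml.preprocessing.encoding.EncoderType;

static Preprocessor<Integer, Object[]> buildStringEncoder(Ignite ignite, IgniteCache<Integer, Object[]> dataCache) {
    // fit() builds the EmptyContext-backed dataset shown above, collects per-feature
    // category frequencies and wraps them into the matching preprocessor.
    return new EncoderTrainer<Integer, Object[]>()
        .withEncoderType(EncoderType.STRING_ENCODER)
        // Indices refer to positions in the extracted feature vector; at least one is
        // required for every encoder type except LABEL_ENCODER (see the check in fit()).
        .withEncodedFeature(0)
        .withEncodedFeature(1)
        .fit(ignite, dataCache, new ObjectArrayVectorizer<Integer>(1, 2).labeled(0));
}

The returned preprocessor can then be chained in front of any downstream trainer in place of the plain vectorizer.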

Example 2 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

The updateModel method of the GaussianNaiveBayesTrainer class.

/**
 * {@inheritDoc}
 */
@Override
protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    assert datasetBuilder != null;
    try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), (env, upstream, upstreamSize, ctx) -> {
        GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder();
        while (upstream.hasNext()) {
            UpstreamEntry<K, V> entity = upstream.next();
            LabeledVector lv = extractor.apply(entity.getKey(), entity.getValue());
            Vector features = lv.features();
            Double label = (Double) lv.label();
            double[] toMeans;
            double[] sqSum;
            if (!res.featureSumsPerLbl.containsKey(label)) {
                toMeans = new double[features.size()];
                Arrays.fill(toMeans, 0.);
                res.featureSumsPerLbl.put(label, toMeans);
            }
            if (!res.featureSquaredSumsPerLbl.containsKey(label)) {
                sqSum = new double[features.size()];
                res.featureSquaredSumsPerLbl.put(label, sqSum);
            }
            if (!res.featureCountersPerLbl.containsKey(label))
                res.featureCountersPerLbl.put(label, 0);
            res.featureCountersPerLbl.put(label, res.featureCountersPerLbl.get(label) + 1);
            toMeans = res.featureSumsPerLbl.get(label);
            sqSum = res.featureSquaredSumsPerLbl.get(label);
            for (int j = 0; j < features.size(); j++) {
                double x = features.get(j);
                toMeans[j] += x;
                sqSum[j] += x * x;
            }
        }
        return res;
    }, learningEnvironment())) {
        GaussianNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.merge(b);
        });
        if (mdl != null && mdl.getSumsHolder() != null)
            sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
        List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert !sortedLabels.isEmpty() : "The dataset should contain at least one feature";
        int labelCount = sortedLabels.size();
        int featureCount = sumsHolder.featureSumsPerLbl.get(sortedLabels.get(0)).length;
        double[][] means = new double[labelCount][featureCount];
        double[][] variances = new double[labelCount][featureCount];
        double[] classProbabilities = new double[labelCount];
        double[] labels = new double[labelCount];
        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToInt(i -> i).sum();
        int lbl = 0;
        for (Double label : sortedLabels) {
            int count = sumsHolder.featureCountersPerLbl.get(label);
            double[] sum = sumsHolder.featureSumsPerLbl.get(label);
            double[] sqSum = sumsHolder.featureSquaredSumsPerLbl.get(label);
            for (int i = 0; i < featureCount; i++) {
                means[lbl][i] = sum[i] / count;
                variances[lbl][i] = (sqSum[i] - sum[i] * sum[i] / count) / count;
            }
            if (equiprobableClasses)
                classProbabilities[lbl] = 1. / labelCount;
            else if (priorProbabilities != null) {
                assert classProbabilities.length == priorProbabilities.length;
                classProbabilities[lbl] = priorProbabilities[lbl];
            } else
                classProbabilities[lbl] = (double) count / datasetSize;
            labels[lbl] = label;
            ++lbl;
        }
        return new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, sumsHolder);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : Arrays(java.util.Arrays) List(java.util.List) ArrayList(java.util.ArrayList) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext)
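
A caller-side sketch for this trainer, assuming an already-running Ignite instance and a hypothetical cache of Vector rows whose last coordinate holds the class label (the cache and its layout are assumptions, not part of the example above).

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;

static GaussianNaiveBayesModel trainGaussianNb(Ignite ignite, IgniteCache<Integer, Vector> dataCache) {
    // Treat the last coordinate of each row as the label.
    Vectorizer<Integer, Vector, Integer, Double> vectorizer =
        new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST);

    // fit() delegates to updateModel() above: per-label sums, squared sums and counters
    // are accumulated per partition, merged, and turned into means and variances.
    return new GaussianNaiveBayesTrainer().fit(ignite, dataCache, vectorizer);
}

The resulting model predicts by evaluating the Gaussian likelihood of a feature vector under each label and returning the most probable one.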

Example 3 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

The updateModel method of the MLPTrainer class.

/**
 * {@inheritDoc}
 */
@Override
protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedMdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    assert archSupplier != null;
    assert loss != null;
    assert updatesStgy != null;
    try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new SimpleLabeledDatasetDataBuilder<>(extractor), learningEnvironment())) {
        MultilayerPerceptron mdl;
        if (lastLearnedMdl != null)
            mdl = lastLearnedMdl;
        else {
            MLPArchitecture arch = archSupplier.apply(dataset);
            mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
        }
        ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();
        for (int i = 0; i < maxIterations; i += locIterations) {
            MultilayerPerceptron finalMdl = mdl;
            int finalI = i;
            List<P> totUp = dataset.compute(data -> {
                P update = updater.init(finalMdl, loss);
                MultilayerPerceptron mlp = Utils.copy(finalMdl);
                if (data.getFeatures() != null) {
                    List<P> updates = new ArrayList<>();
                    for (int locStep = 0; locStep < locIterations; locStep++) {
                        int[] rows = Utils.selectKDistinct(data.getRows(), Math.min(batchSize, data.getRows()), new Random(seed ^ (finalI * locStep)));
                        double[] inputsBatch = batch(data.getFeatures(), rows, data.getRows());
                        double[] groundTruthBatch = batch(data.getLabels(), rows, data.getRows());
                        Matrix inputs = new DenseMatrix(inputsBatch, rows.length, 0);
                        Matrix groundTruth = new DenseMatrix(groundTruthBatch, rows.length, 0);
                        update = updater.calculateNewUpdate(mlp, update, locStep, inputs.transpose(), groundTruth.transpose());
                        mlp = updater.update(mlp, update);
                        updates.add(update);
                    }
                    List<P> res = new ArrayList<>();
                    res.add(updatesStgy.locStepUpdatesReducer().apply(updates));
                    return res;
                }
                return null;
            }, (a, b) -> {
                if (a == null)
                    return b;
                else if (b == null)
                    return a;
                else {
                    a.addAll(b);
                    return a;
                }
            });
            if (totUp == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedMdl);
            P update = updatesStgy.allUpdatesReducer().apply(totUp);
            mdl = updater.update(mdl, update);
        }
        return mdl;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) SimpleLabeledDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) ArrayList(java.util.ArrayList) Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Random(java.util.Random) RandomInitializer(org.apache.ignite.ml.nn.initializers.RandomInitializer)
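
The trainer itself is configured up front with an architecture, a loss function and an update strategy. The sketch below shows one plausible configuration; the layer sizes, hyper-parameters and the SGD update classes are illustrative, and their exact package locations can differ between Ignite versions.

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.nn.Activators;
import org.apache.ignite.ml.nn.MLPTrainer;
import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.structures.LabeledVector;

static MultilayerPerceptron trainMlp(Ignite ignite, IgniteCache<Integer, LabeledVector<double[]>> trainingSet) {
    // Two inputs, one hidden layer of 10 ReLU neurons, a single sigmoid output.
    MLPArchitecture arch = new MLPArchitecture(2)
        .withAddedLayer(10, true, Activators.RELU)
        .withAddedLayer(1, false, Activators.SIGMOID);

    MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
        arch,
        LossFunctions.MSE,
        new UpdatesStrategy<>(
            new SimpleGDUpdateCalculator(0.1),
            SimpleGDParameterUpdate.SUM_LOCAL,
            SimpleGDParameterUpdate.AVG),
        3000,  // maxIterations: the outer loop in updateModel() above
        4,     // batchSize: rows drawn per local step
        50,    // locIterations: local steps per distributed round
        123L); // seed

    // fit() calls updateModel() above on a SimpleLabeledDatasetData-backed dataset;
    // the vectorizer choice assumes a cache of LabeledVector rows.
    return trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
}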

Example 4 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

The updateModel method of the DiscreteNaiveBayesTrainer class.

/**
 * {@inheritDoc}
 */
@Override
protected <K, V> DiscreteNaiveBayesModel updateModel(DiscreteNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    try (Dataset<EmptyContext, DiscreteNaiveBayesSumsHolder> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), (env, upstream, upstreamSize, ctx) -> {
        DiscreteNaiveBayesSumsHolder res = new DiscreteNaiveBayesSumsHolder();
        while (upstream.hasNext()) {
            UpstreamEntry<K, V> entity = upstream.next();
            LabeledVector lv = extractor.apply(entity.getKey(), entity.getValue());
            Vector features = lv.features();
            Double lb = (Double) lv.label();
            long[][] valuesInBucket;
            int size = features.size();
            if (!res.valuesInBucketPerLbl.containsKey(lb)) {
                valuesInBucket = new long[size][];
                for (int i = 0; i < size; i++) {
                    valuesInBucket[i] = new long[bucketThresholds[i].length + 1];
                    Arrays.fill(valuesInBucket[i], 0L);
                }
                res.valuesInBucketPerLbl.put(lb, valuesInBucket);
            }
            if (!res.featureCountersPerLbl.containsKey(lb))
                res.featureCountersPerLbl.put(lb, 0);
            res.featureCountersPerLbl.put(lb, res.featureCountersPerLbl.get(lb) + 1);
            valuesInBucket = res.valuesInBucketPerLbl.get(lb);
            for (int j = 0; j < size; j++) {
                double x = features.get(j);
                int bucketNum = toBucketNumber(x, bucketThresholds[j]);
                valuesInBucket[j][bucketNum] += 1;
            }
        }
        return res;
    }, learningEnvironment())) {
        DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.merge(b);
        });
        if (mdl != null && isUpdateable(mdl)) {
            if (checkSumsHolder(sumsHolder, mdl.getSumsHolder()))
                sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
        }
        List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert !sortedLabels.isEmpty() : "The dataset should contain at least one feature";
        int lbCnt = sortedLabels.size();
        int featureCnt = sumsHolder.valuesInBucketPerLbl.get(sortedLabels.get(0)).length;
        double[][][] probabilities = new double[lbCnt][featureCnt][];
        double[] classProbabilities = new double[lbCnt];
        double[] labels = new double[lbCnt];
        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToInt(i -> i).sum();
        int lbl = 0;
        for (Double label : sortedLabels) {
            int cnt = sumsHolder.featureCountersPerLbl.get(label);
            long[][] sum = sumsHolder.valuesInBucketPerLbl.get(label);
            for (int i = 0; i < featureCnt; i++) {
                int bucketsCnt = sum[i].length;
                probabilities[lbl][i] = new double[bucketsCnt];
                for (int j = 0; j < bucketsCnt; j++) probabilities[lbl][i][j] = (double) sum[i][j] / cnt;
            }
            if (equiprobableClasses)
                classProbabilities[lbl] = 1. / lbCnt;
            else if (priorProbabilities != null) {
                assert classProbabilities.length == priorProbabilities.length;
                classProbabilities[lbl] = priorProbabilities[lbl];
            } else
                classProbabilities[lbl] = (double) cnt / datasetSize;
            labels[lbl] = label;
            ++lbl;
        }
        return new DiscreteNaiveBayesModel(probabilities, classProbabilities, labels, bucketThresholds, sumsHolder);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : Arrays(java.util.Arrays) List(java.util.List) ArrayList(java.util.ArrayList) Optional(java.util.Optional) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext)
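
A caller-side sketch for this trainer. The bucket thresholds and the cache layout are illustrative assumptions: one threshold array per feature, so a single 0.5 threshold splits each feature into two buckets, matching the bucketThresholds[i].length + 1 buckets allocated in the code above.

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;

static DiscreteNaiveBayesModel trainDiscreteNb(Ignite ignite, IgniteCache<Integer, Vector> dataCache) {
    // One threshold array per feature; a single threshold of 0.5 yields two buckets each.
    double[][] thresholds = {{.5}, {.5}, {.5}};

    // The first coordinate of each row is treated as the class label.
    return new DiscreteNaiveBayesTrainer()
        .setBucketThresholds(thresholds)
        .fit(ignite, dataCache,
            new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));
}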

Example 5 with EmptyContext

Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.

The fit method of the MaxAbsScalerTrainer class.

/**
 * {@inheritDoc}
 */
@Override
public MaxAbsScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    try (Dataset<EmptyContext, MaxAbsScalerPartitionData> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), (env, upstream, upstreamSize, ctx) -> {
        double[] maxAbs = null;
        while (upstream.hasNext()) {
            UpstreamEntry<K, V> entity = upstream.next();
            LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());
            if (maxAbs == null) {
                maxAbs = new double[row.size()];
                Arrays.fill(maxAbs, .0);
            } else
                assert maxAbs.length == row.size() : "Base preprocessor must return exactly " + maxAbs.length + " features";
            for (int i = 0; i < row.size(); i++) {
                if (Math.abs(row.get(i)) > Math.abs(maxAbs[i]))
                    maxAbs[i] = Math.abs(row.get(i));
            }
        }
        return new MaxAbsScalerPartitionData(maxAbs);
    }, learningEnvironment(basePreprocessor))) {
        double[] maxAbs = dataset.compute(MaxAbsScalerPartitionData::getMaxAbs, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            double[] res = new double[a.length];
            for (int i = 0; i < res.length; i++) res[i] = Math.max(Math.abs(a[i]), Math.abs(b[i]));
            return res;
        });
        return new MaxAbsScalerPreprocessor<>(maxAbs, basePreprocessor);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledVector(org.apache.ignite.ml.structures.LabeledVector)
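
A caller-side sketch of how this scaler is usually produced; the cache layout (Vector rows with the label in the last coordinate) is an assumption, not part of the example above.

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.maxabsscaling.MaxAbsScalerTrainer;

static Preprocessor<Integer, Vector> buildMaxAbsScaler(Ignite ignite, IgniteCache<Integer, Vector> dataCache) {
    // fit() scans each partition once (the EmptyContext dataset above), keeps the
    // per-feature maximum absolute values, reduces them with Math.max and returns a
    // preprocessor that divides every feature by its max-abs value.
    return new MaxAbsScalerTrainer<Integer, Vector>()
        .fit(ignite, dataCache,
            new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));
}

The returned preprocessor can then be passed as the extractor of a downstream trainer so that training runs on the scaled features.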

Aggregations

EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext): 23 usages
LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 16 usages
Dataset (org.apache.ignite.ml.dataset.Dataset): 12 usages
Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 12 usages
DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder): 11 usages
Preprocessor (org.apache.ignite.ml.preprocessing.Preprocessor): 11 usages
Arrays (java.util.Arrays): 9 usages
LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder): 9 usages
UpstreamEntry (org.apache.ignite.ml.dataset.UpstreamEntry): 6 usages
SingleLabelDatasetTrainer (org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer): 6 usages
LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment): 5 usages
ArrayList (java.util.ArrayList): 4 usages
Map (java.util.Map): 4 usages
PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder): 4 usages
FeatureMatrixWithLabelsOnHeapData (org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData): 4 usages
DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector): 4 usages
NotNull (org.jetbrains.annotations.NotNull): 4 usages
Serializable (java.io.Serializable): 3 usages
List (java.util.List): 3 usages
Optional (java.util.Optional): 3 usages