Search in sources:

Example 11 with Dataset

Use of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project.

From class LogisticRegressionSGDTrainer, method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Logistic regression is trained as a perceptron with a single output neuron using a sigmoid
 * activation and L2 loss. The feature count is derived from the dataset itself. When {@code mdl}
 * is non-null, its state is restored into a perceptron and training continues from it; otherwise a
 * fresh perceptron is fitted. The resulting perceptron parameters are unpacked so that all but the
 * last one become the model weights and the last one is passed as the second constructor argument
 * of {@link LogisticRegressionModel} (the intercept).
 *
 * @throws IllegalStateException If no partition of the dataset contains any rows.
 */
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        // Each partition reports its column count, or null if it holds no rows.
        Integer cols = dataset.compute(data -> {
            // An empty partition tells us nothing about the column count. Guarding getRows() == 0
            // also prevents an ArithmeticException from the division below.
            if (data.getFeatures() == null || data.getRows() == 0)
                return null;
            return data.getFeatures().length / data.getRows();
        }, (a, b) -> {
            // Keep the first non-null count. The result is null only when every partition was
            // empty, which is reported as an error below.
            if (a == null)
                return b;
            return a;
        });
        if (cols == null)
            throw new IllegalStateException("Cannot train on empty dataset");
        MLPArchitecture architecture = new MLPArchitecture(cols);
        // One output neuron with sigmoid activation. NOTE(review): the `true` argument is presumably
        // the bias flag of MLPArchitecture.withAddedLayer; confirm against its Javadoc.
        architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);
        return architecture;
    };
    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.L2, updatesStgy, maxIterations, batchSize, locIterations, seed).withEnvironmentBuilder(envBuilder);
    MultilayerPerceptron mlp;
    // The perceptron expects a vector label, so wrap the scalar label in a one-element array.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> func = lv -> new LabeledVector<>(lv.features(), new double[] { lv.label() });
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(func, extractor);
    if (mdl != null) {
        mlp = restoreMLPState(mdl);
        mlp = trainer.update(mlp, datasetBuilder, patchedPreprocessor);
    } else
        mlp = trainer.fit(datasetBuilder, patchedPreprocessor);
    double[] params = mlp.parameters().getStorage().data();
    // All parameters but the last form the weight vector; the last one is the intercept.
    return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, params.length - 1)), params[params.length - 1]);
}
Also used : SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) Arrays(java.util.Arrays) Activators(org.apache.ignite.ml.nn.Activators) SimpleGDParameterUpdate(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate) UpdatesStrategy(org.apache.ignite.ml.nn.UpdatesStrategy) SimpleLabeledDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) LossFunctions(org.apache.ignite.ml.optimization.LossFunctions) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) NotNull(org.jetbrains.annotations.NotNull) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) MLPTrainer(org.apache.ignite.ml.nn.MLPTrainer) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) Dataset(org.apache.ignite.ml.dataset.Dataset) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) MultilayerPerceptron(org.apache.ignite.ml.nn.MultilayerPerceptron) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 12 with Dataset

Use of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project.

From class Deltas, method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Labels equal to {@code 0.0} are remapped to {@code -1.0} before training; every other label is
 * passed through unchanged. Without an existing model, training starts from an all-zero vector
 * whose length is the column count plus one. Index 0 holds the intercept and the remaining entries
 * hold the feature weights. With an existing model, training resumes from its state vector. Each
 * iteration adds the deltas produced by {@code calculateUpdates}. A {@code null} delta ends training
 * early by delegating to {@code getLastTrainedModelOrThrowEmptyDatasetException}. Any exception
 * raised while the dataset is open is rethrown wrapped in a {@link RuntimeException}.
 */
@Override
protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    // Map the 0.0 label onto -1.0; all other labels are kept as they are.
    IgniteFunction<Double, Double> lbTransformer = lb -> lb == 0.0 ? -1.0 : lb;

    IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> relabel =
        lv -> new LabeledVector<>(lv.features(), lbTransformer.apply(lv.label()));
    PatchedPreprocessor<K, V, Double, Double> relabelingPreprocessor = new PatchedPreprocessor<>(relabel, preprocessor);
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(relabelingPreprocessor);

    Vector weights;
    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder, learningEnvironment())) {
        if (mdl != null)
            weights = getStateVector(mdl);
        else {
            // Prefer the right-hand partition result, fall back to the left one, and use 0 when
            // neither partition reported a column count.
            int colCnt = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
                if (b != null)
                    return b;
                return a == null ? 0 : a;
            });
            // One extra slot for the intercept.
            weights = initializeWeightsWithZeros(colCnt + 1);
        }

        for (int iter = 0; iter < getAmountOfIterations(); iter++) {
            Vector delta = calculateUpdates(weights, dataset);
            if (delta == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
            // plus() creates a new vector rather than mutating weights in place.
            weights = weights.plus(delta);
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    return new SVMLinearClassificationModel(weights.copyOfRange(1, weights.size()), weights.get(0));
}
Also used : IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) Random(java.util.Random) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) NotNull(org.jetbrains.annotations.NotNull) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledDatasetPartitionDataBuilderOnHeap(org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap) Dataset(org.apache.ignite.ml.dataset.Dataset) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) PatchedPreprocessor(org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 13 with Dataset

Use of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project.

From class ImputerTrainer, method fit.

/**
 * {@inheritDoc}
 *
 * <p>Builds a dataset in which each partition scans its upstream rows once. During that scan it
 * accumulates only the statistics needed by the configured imputing strategy:
 * <ul>
 *   <li>sums and counts for MEAN;</li>
 *   <li>value frequencies for MOST_FREQUENT and LEAST_FREQUENT;</li>
 *   <li>maxima for MAX, minima for MIN, and counts for COUNT.</li>
 * </ul>
 * The partition statistics are then turned into one imputing value per feature. Those values are
 * wrapped, together with {@code basePreprocessor}, in an {@link ImputerPreprocessor}.
 *
 * @throws RuntimeException Wrapping any exception raised while building or using the dataset,
 *     including {@link UnsupportedOperationException} for an unknown strategy.
 */
@Override
public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> ctxBuilder = (env, upstream, upstreamSize) -> new EmptyContext();

    try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(envBuilder, ctxBuilder, (env, upstream, upstreamSize, ctx) -> {
        // Per-partition accumulators; only those required by the strategy are ever filled.
        double[] partSums = null;
        int[] partCounts = null;
        double[] partMaxs = null;
        double[] partMins = null;
        Map<Double, Integer>[] partFreqs = null;

        while (upstream.hasNext()) {
            UpstreamEntry<K, V> entry = upstream.next();
            LabeledVector vec = basePreprocessor.apply(entry.getKey(), entry.getValue());

            switch (imputingStgy) {
                case MEAN:
                    partSums = updateTheSums(vec, partSums);
                    partCounts = updateTheCounts(vec, partCounts);
                    break;
                case MOST_FREQUENT:
                case LEAST_FREQUENT:
                    // Both frequency-based strategies share the same accumulator.
                    partFreqs = updateFrequenciesByGivenRow(vec, partFreqs);
                    break;
                case MAX:
                    partMaxs = updateTheMaxs(vec, partMaxs);
                    break;
                case MIN:
                    partMins = updateTheMins(vec, partMins);
                    break;
                case COUNT:
                    partCounts = updateTheCounts(vec, partCounts);
                    break;
                default:
                    throw new UnsupportedOperationException("The chosen strategy is not supported");
            }
        }

        switch (imputingStgy) {
            case MEAN:
                return new ImputerPartitionData().withSums(partSums).withCounts(partCounts);
            case MOST_FREQUENT:
            case LEAST_FREQUENT:
                return new ImputerPartitionData().withValuesByFrequency(partFreqs);
            case MAX:
                return new ImputerPartitionData().withMaxs(partMaxs);
            case MIN:
                return new ImputerPartitionData().withMins(partMins);
            case COUNT:
                return new ImputerPartitionData().withCounts(partCounts);
            default:
                throw new UnsupportedOperationException("The chosen strategy is not supported");
        }
    }, learningEnvironment(basePreprocessor))) {
        Vector fillValues;

        switch (imputingStgy) {
            case MEAN:
                fillValues = VectorUtils.of(calculateImputingValuesBySumsAndCounts(dataset));
                break;
            case MOST_FREQUENT:
                fillValues = VectorUtils.of(calculateImputingValuesByTheMostFrequentValues(dataset));
                break;
            case LEAST_FREQUENT:
                fillValues = VectorUtils.of(calculateImputingValuesByTheLeastFrequentValues(dataset));
                break;
            case MAX:
                fillValues = VectorUtils.of(calculateImputingValuesByMaxValues(dataset));
                break;
            case MIN:
                fillValues = VectorUtils.of(calculateImputingValuesByMinValues(dataset));
                break;
            case COUNT:
                fillValues = VectorUtils.of(calculateImputingValuesByCounts(dataset));
                break;
            default:
                throw new UnsupportedOperationException("The chosen strategy is not supported");
        }

        return new ImputerPreprocessor<>(fillValues, basePreprocessor);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : Arrays(java.util.Arrays) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) HashMap(java.util.HashMap) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) PreprocessingTrainer(org.apache.ignite.ml.preprocessing.PreprocessingTrainer) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) PartitionContextBuilder(org.apache.ignite.ml.dataset.PartitionContextBuilder) VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Dataset(org.apache.ignite.ml.dataset.Dataset) Map(java.util.Map) Optional(java.util.Optional) Comparator(java.util.Comparator) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) HashMap(java.util.HashMap) Map(java.util.Map) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector)

Example 14 with Dataset

Use of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project.

From class ImpurityHistogramsComputer, method aggregateImpurityStatisticsOnPartition.

/**
 * Aggregates, on one partition, the statistics needed to compute impurity for the nodes in
 * {@code part} across all trees of the random forest.
 *
 * <p>For every vector and every tree index whose bootstrap counter for that vector is non-zero, the
 * method predicts the next node key using that tree's root. If the key is one of the nodes in
 * {@code part}, the vector is added to that node's impurity computer for each feature the tree
 * uses. A feature's computer is created on first use.
 *
 * @param dataset Dataset partition.
 * @param roots Trees, indexed by the same sample id as the vector counters.
 * @param histMeta Histogram bucket meta, keyed by feature id.
 * @param part Nodes whose statistics are collected.
 * @return Impurity histograms for every node key in {@code part}.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots, Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {
    Map<NodeId, NodeImpurityHistograms<S>> histsByNode = part.keySet()
        .stream()
        .collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

    dataset.forEach(vector -> {
        for (int treeIdx = 0; treeIdx < vector.counters().length; treeIdx++) {
            // A zero counter means this vector is not part of the tree's bootstrap sample.
            if (vector.counters()[treeIdx] == 0)
                continue;

            RandomForestTreeModel tree = roots.get(treeIdx);
            NodeId nodeKey = tree.getRootNode().predictNextNodeKey(vector.features());

            // The predicted node is not among the nodes taken from the learning queue.
            if (!part.containsKey(nodeKey))
                continue;

            NodeImpurityHistograms<S> nodeHists = histsByNode.get(nodeKey);

            for (Integer featureId : tree.getUsedFeatures()) {
                BucketMeta bucketMeta = histMeta.get(featureId);

                if (!nodeHists.perFeatureStatistics.containsKey(featureId))
                    nodeHists.perFeatureStatistics.put(featureId, createImpurityComputerForFeature(treeIdx, bucketMeta));

                S computer = nodeHists.perFeatureStatistics.get(featureId);
                computer.addElement(vector);
            }
        }
    });

    return histsByNode;
}
Also used : TreeNode(org.apache.ignite.ml.tree.randomforest.data.TreeNode) BootstrappedVector(org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector) NodeSplit(org.apache.ignite.ml.tree.randomforest.data.NodeSplit) HashMap(java.util.HashMap) BootstrappedDatasetPartition(org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition) Collectors(java.util.stream.Collectors) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId) Serializable(java.io.Serializable) ArrayList(java.util.ArrayList) BucketMeta(org.apache.ignite.ml.dataset.feature.BucketMeta) Stream(java.util.stream.Stream) Dataset(org.apache.ignite.ml.dataset.Dataset) Map(java.util.Map) Optional(java.util.Optional) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) Comparator(java.util.Comparator) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) RandomForestTreeModel(org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel) NodeId(org.apache.ignite.ml.tree.randomforest.data.NodeId) BucketMeta(org.apache.ignite.ml.dataset.feature.BucketMeta)

Aggregations

Dataset (org.apache.ignite.ml.dataset.Dataset)14 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)13 Preprocessor (org.apache.ignite.ml.preprocessing.Preprocessor)12 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)12 DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder)11 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)11 Arrays (java.util.Arrays)9 LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder)6 SingleLabelDatasetTrainer (org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer)6 Serializable (java.io.Serializable)5 Map (java.util.Map)5 UpstreamEntry (org.apache.ignite.ml.dataset.UpstreamEntry)5 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)5 Optional (java.util.Optional)4 PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder)4 IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction)4 PatchedPreprocessor (org.apache.ignite.ml.preprocessing.developer.PatchedPreprocessor)4 NotNull (org.jetbrains.annotations.NotNull)4 ArrayList (java.util.ArrayList)3 HashMap (java.util.HashMap)3