Search in sources:

Example 46 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project Ignite by Apache.

From class ReplicatedVectorMatrix, method plus.

/**
 * Specialized optimized version of plus for two {@code ReplicatedVectorMatrix} operands.
 *
 * <p>When both operands replicate their vector along the same axis (both report the same
 * {@code isColumnReplicated()}), the sum is computed by adding only the two replicated vectors and
 * wrapping the result in a new {@code ReplicatedVectorMatrix} with this matrix's replication count
 * and orientation.
 *
 * @param mtx Matrix to be added.
 * @return New ReplicatedVectorMatrix resulting from addition.
 * @throws UnsupportedOperationException If the two operands are replicated along different axes.
 */
public Matrix plus(ReplicatedVectorMatrix mtx) {
    // Mixed orientations cannot be summed by adding replicants alone.
    if (isColumnReplicated() != mtx.isColumnReplicated())
        throw new UnsupportedOperationException();

    checkCardinality(mtx.rowSize(), mtx.columnSize());

    Vector summedReplicant = vector.plus(mtx.replicant());

    return new ReplicatedVectorMatrix(summedReplicant, replicationCnt, asCol);
}
Also used : Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 47 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project Ignite by Apache.

From class DiscreteNaiveBayesTrainer, method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Makes one pass over the upstream data. For every label it counts the rows carrying that label
 * and, for every feature, how many of those rows fall into each bucket delimited by
 * {@code bucketThresholds} (a feature with {@code t} thresholds has {@code t + 1} buckets).
 * Per-partition sums are merged. When {@code mdl} is updateable and {@code checkSumsHolder}
 * accepts its sums, they are merged in as well. The result is then turned into per-label bucket
 * probabilities and class probabilities.
 *
 * @throws RuntimeException Wrapping any exception raised while building, computing on, or closing the dataset.
 */
@Override
protected <K, V> DiscreteNaiveBayesModel updateModel(DiscreteNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    try (Dataset<EmptyContext, DiscreteNaiveBayesSumsHolder> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), (env, upstream, upstreamSize, ctx) -> {
        DiscreteNaiveBayesSumsHolder res = new DiscreteNaiveBayesSumsHolder();
        while (upstream.hasNext()) {
            UpstreamEntry<K, V> entity = upstream.next();
            LabeledVector lv = extractor.apply(entity.getKey(), entity.getValue());
            Vector features = lv.features();
            Double lb = (Double) lv.label();
            int size = features.size();
            if (!res.valuesInBucketPerLbl.containsKey(lb)) {
                // First row with this label: allocate one counter per bucket for every feature.
                // New long arrays are zero-initialised by the JVM, so no explicit fill is needed.
                long[][] newBuckets = new long[size][];
                for (int i = 0; i < size; i++)
                    newBuckets[i] = new long[bucketThresholds[i].length + 1];
                res.valuesInBucketPerLbl.put(lb, newBuckets);
            }
            // Count of rows seen with this label.
            res.featureCountersPerLbl.merge(lb, 1, Integer::sum);
            long[][] valuesInBucket = res.valuesInBucketPerLbl.get(lb);
            for (int j = 0; j < size; j++) {
                double x = features.get(j);
                int bucketNum = toBucketNumber(x, bucketThresholds[j]);
                valuesInBucket[j][bucketNum] += 1;
            }
        }
        return res;
    }, learningEnvironment())) {
        // Either side of the reduction may be null; keep the non-null one, merge otherwise.
        DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.merge(b);
        });
        if (mdl != null && isUpdateable(mdl)) {
            if (checkSumsHolder(sumsHolder, mdl.getSumsHolder()))
                sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
        }
        List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert !sortedLabels.isEmpty() : "The dataset should contain at least one labeled row";
        int lbCnt = sortedLabels.size();
        int featureCnt = sumsHolder.valuesInBucketPerLbl.get(sortedLabels.get(0)).length;
        double[][][] probabilities = new double[lbCnt][featureCnt][];
        double[] classProbabilities = new double[lbCnt];
        double[] labels = new double[lbCnt];
        // Sum as long: each per-label count is an int, but their total can exceed Integer.MAX_VALUE.
        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToLong(Integer::longValue).sum();
        int lbl = 0;
        for (Double label : sortedLabels) {
            int cnt = sumsHolder.featureCountersPerLbl.get(label);
            long[][] sum = sumsHolder.valuesInBucketPerLbl.get(label);
            for (int i = 0; i < featureCnt; i++) {
                int bucketsCnt = sum[i].length;
                probabilities[lbl][i] = new double[bucketsCnt];
                // Relative frequency of bucket j of feature i among rows with this label.
                for (int j = 0; j < bucketsCnt; j++)
                    probabilities[lbl][i][j] = (double) sum[i][j] / cnt;
            }
            if (equiprobableClasses)
                classProbabilities[lbl] = 1. / lbCnt;
            else if (priorProbabilities != null) {
                assert classProbabilities.length == priorProbabilities.length;
                classProbabilities[lbl] = priorProbabilities[lbl];
            } else
                classProbabilities[lbl] = (double) cnt / datasetSize;
            labels[lbl] = label;
            ++lbl;
        }
        return new DiscreteNaiveBayesModel(probabilities, classProbabilities, labels, bucketThresholds, sumsHolder);
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Also used : Arrays(java.util.Arrays) List(java.util.List) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) Dataset(org.apache.ignite.ml.dataset.Dataset) SingleLabelDatasetTrainer(org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) Optional(java.util.Optional) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) ArrayList(java.util.ArrayList) UpstreamEntry(org.apache.ignite.ml.dataset.UpstreamEntry) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext)

Example 48 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project Ignite by Apache.

From class MultilayerPerceptron, method paramsAsVector.

/**
 * Flattens the parameters of the given MLP layers into a single dense vector.
 *
 * <p>Layers are written in list order. Each layer contributes its weights and then, if present,
 * its biases, each written at the offset returned by the previous write.
 *
 * @param layersParams List of layers parameters.
 * @return This MLP parameters as vector, sized by {@code architecture().parametersCount()}.
 */
private Vector paramsAsVector(List<MLPLayer> layersParams) {
    Vector flattened = new DenseVector(architecture().parametersCount());
    int pos = 0;

    for (MLPLayer layer : layersParams) {
        pos = writeToVector(flattened, layer.weights, pos);
        // Bias-free layers contribute nothing beyond their weights.
        pos = layer.biases == null ? pos : writeToVector(flattened, layer.biases, pos);
    }

    return flattened;
}
Also used : Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 49 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project Ignite by Apache.

From class MultivariateGaussianDistribution, method prob.

/**
 * {@inheritDoc}
 *
 * <p>Evaluates {@code exp(-0.5 * (x - mean)^T * invCovariance * (x - mean)) / normalizer}.
 *
 * @param x Point at which to evaluate the density.
 * @return Density value at {@code x}.
 */
@Override
public double prob(Vector x) {
    Vector delta = x.minus(mean);
    // toMatrix(true) is assumed to give the row form and toMatrix(false) the column form of delta;
    // the assert below checks that the quadratic form collapses to a 1 x 1 matrix.
    Matrix ePower = delta.toMatrix(true).times(invCovariance).times(delta.toMatrix(false)).times(-0.5);
    assert ePower.columnSize() == 1 && ePower.rowSize() == 1;
    // Math.exp is the dedicated, more accurate way to compute e^v (vs. Math.pow(Math.E, v)).
    return Math.exp(ePower.get(0, 0)) / normalizer;
}
Also used : Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Example 50 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project Ignite by Apache.

From class RPropUpdateCalculator, method calculateNewUpdate.

/**
 * {@inheritDoc}
 *
 * <p>Computes one RProp step from the gradient of {@code loss} with respect to the model
 * parameters on the given {@code inputs}/{@code groundTruth} batch:
 * <ul>
 *     <li>per parameter, the sign of (previous gradient * current gradient) decides whether the step
 *     size grows (times {@code accelerationRate}, capped at {@code UPDATE_MAX}), shrinks (times
 *     {@code deaccelerationRate}, floored at {@code UPDATE_MIN}) or stays unchanged (sign 0);</li>
 *     <li>when there is no previous gradient, every sign is treated as {@code 1.0};</li>
 *     <li>where the sign is non-negative the new update is {@code -signum(gradient) * delta};
 *     otherwise the previous iteration's update value is carried over;</li>
 *     <li>where the sign is negative the gradient component is zeroed before it is stored as the
 *     "previous gradient" of the returned update, and the mask holds {@code -1.0} (else {@code 1.0}).</li>
 * </ul>
 * NOTE(review): the {@code iteration} parameter is not used in this method.
 */
@Override
public RPropParameterUpdate calculateNewUpdate(SmoothParametrized mdl, RPropParameterUpdate updaterParams, int iteration, Matrix inputs, Matrix groundTruth) {
    Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth);
    Vector prevGradient = updaterParams.prevIterationGradient();
    // Per-parameter sign of (previous gradient * current gradient); all ones on the first step.
    Vector derSigns;
    if (prevGradient != null)
        derSigns = VectorUtils.zipWith(prevGradient, gradient, (x, y) -> Math.signum(x * y));
    else
        derSigns = gradient.like(gradient.size()).assign(1.0);
    // Adapt step sizes on a copy so that updaterParams.deltas() itself is left untouched
    // (assumes map(...) mutates its receiver — confirm).
    Vector newDeltas = updaterParams.deltas().copy().map(derSigns, (prevDelta, sign) -> {
        if (sign > 0)
            return Math.min(prevDelta * accelerationRate, UPDATE_MAX);
        else if (sign < 0)
            return Math.max(prevDelta * deaccelerationRate, UPDATE_MIN);
        else
            return prevDelta;
    });
    // NOTE(review): the step below uses the previous deltas (updaterParams.deltas()), not newDeltas;
    // confirm this is the intended RProp variant.
    // This runs before the updatesMask lambda zeroes gradient components, so it sees the raw gradient.
    Vector newPrevIterationUpdates = MatrixUtil.zipWith(gradient, updaterParams.deltas(), (der, delta, i) -> {
        if (derSigns.getX(i) >= 0)
            return -Math.signum(der) * delta;
        return updaterParams.prevIterationUpdates().getX(i);
    });
    // Side effect: zeroes gradient[i] in place wherever the sign is negative, so the gradient copy
    // stored in the result below has those components cleared.
    Vector updatesMask = MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(), (sign, upd, i) -> {
        if (sign < 0)
            gradient.setX(i, 0.0);
        if (sign >= 0)
            return 1.0;
        else
            return -1.0;
    });
    return new RPropParameterUpdate(newPrevIterationUpdates, gradient.copy(), newDeltas, updatesMask);
}
Also used : Vector(org.apache.ignite.ml.math.primitives.vector.Vector)

Aggregations

Vector (org.apache.ignite.ml.math.primitives.vector.Vector)265 DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)95 Test (org.junit.Test)94 Ignite (org.apache.ignite.Ignite)78 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)49 HashMap (java.util.HashMap)39 SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache)38 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)26 FileNotFoundException (java.io.FileNotFoundException)22 TrainerTest (org.apache.ignite.ml.common.TrainerTest)22 DecisionTreeClassificationTrainer (org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer)21 DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel)21 Serializable (java.io.Serializable)19 IgniteCache (org.apache.ignite.IgniteCache)18 EncoderTrainer (org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer)16 Cache (javax.cache.Cache)15 DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer)15 EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance)14 ArrayList (java.util.ArrayList)12 ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition)12