
Example 71 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.

Source: class RecommendationDatasetData, method calculateGradient.

/**
 * Calculates the gradient of the loss function used in SGD training of a recommendation system. Details of the
 * gradient calculation can be found here: https://tinyurl.com/y6cku9hr.
 *
 * @param objMatrix Object matrix obtained by factorizing the rating matrix.
 * @param subjMatrix Subject matrix obtained by factorizing the rating matrix.
 * @param batchSize Batch size of stochastic gradient descent, i.e. the number of ratings used on each SGD step.
 * @param seed Seed (required to make the randomized part of the computation repeatable).
 * @param regParam Regularization parameter.
 * @param learningRate Learning rate.
 * @return Gradient of the matrix factorization loss function.
 */
public MatrixFactorizationGradient<O, S> calculateGradient(Map<O, Vector> objMatrix, Map<S, Vector> subjMatrix,
    int batchSize, int seed, double regParam, double learningRate) {
    Map<O, Vector> objGrads = new HashMap<>();
    Map<S, Vector> subjGrads = new HashMap<>();

    // Indices of the rows (rating triplets) that make up this SGD batch.
    int[] rows = getRows(batchSize, seed);

    for (int row : rows) {
        ObjectSubjectRatingTriplet<O, S> triplet = ratings.get(row);

        Vector objVector = objMatrix.get(triplet.getObj());
        Vector subjVector = subjMatrix.get(triplet.getSubj());

        // Rating error for this (object, subject) pair.
        double error = calculateError(objVector, subjVector, triplet.getRating());

        // Regularized gradients of both factor vectors, scaled by the learning rate.
        Vector objGrad = (subjVector.times(error).plus(objVector.times(regParam))).times(learningRate);
        Vector subjGrad = (objVector.times(error).plus(subjVector.times(regParam))).times(learningRate);

        // put() keeps only the last gradient if the same object or subject occurs more than once in the batch.
        objGrads.put(triplet.getObj(), objGrad);
        subjGrads.put(triplet.getSubj(), subjGrad);
    }

    return new MatrixFactorizationGradient<>(objGrads, subjGrads, rows.length);
}
Also used: HashMap (java.util.HashMap), Vector (org.apache.ignite.ml.math.primitives.vector.Vector)
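The per-rating update in calculateGradient can be reproduced without Ignite. The sketch below is a minimal, dependency-free illustration on plain double arrays; the class MfGradientSketch and its helpers are hypothetical, and it assumes that calculateError returns the estimated rating (the dot product of the two factor vectors) minus the actual rating, since that method is not part of this excerpt.

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free sketch of the per-rating gradient computed in calculateGradient,
 * written with plain double arrays instead of Ignite's Vector type.
 */
public class MfGradientSketch {
    /** Dot product of two vectors of equal length. */
    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++)
            sum += a[i] * b[i];
        return sum;
    }

    /** Returns learningRate * (error * other + regParam * self), element by element. */
    static double[] scaledGradient(double[] self, double[] other, double error, double regParam,
        double learningRate) {
        double[] grad = new double[self.length];
        for (int i = 0; i < self.length; i++)
            grad[i] = learningRate * (error * other[i] + regParam * self[i]);
        return grad;
    }

    public static void main(String[] args) {
        double[] obj = {1.0, 0.0};
        double[] subj = {0.5, 0.5};
        double rating = 2.0;

        // Assumption: error = estimated rating (dot product) - actual rating.
        double error = dot(obj, subj) - rating; // 0.5 - 2.0 = -1.5

        double[] objGrad = scaledGradient(obj, subj, error, 0.5, 0.5);
        double[] subjGrad = scaledGradient(subj, obj, error, 0.5, 0.5);

        System.out.println(Arrays.toString(objGrad));  // [-0.125, -0.375]
        System.out.println(Arrays.toString(subjGrad)); // [-0.625, 0.125]
    }
}
```

In plain SGD each factor vector would then be moved against its gradient (obj - objGrad, subj - subjGrad); how the returned MatrixFactorizationGradient is applied is not shown in this excerpt.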

Example 72 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.

Source: class Deltas, method getStateVector.

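/**
 * Builds the state vector of the model: the intercept at index 0, followed by the model weights.
 *
 * @param mdl Model.
 * @return Vector of model weights prefixed with the intercept.
 */
private Vector getStateVector(SVMLinearClassificationModel mdl) {
    double intercept = mdl.intercept();
    Vector weights = mdl.weights();

    // One extra slot for the intercept; keep the storage type (dense or sparse) of the weights.
    int stateVectorSize = weights.size() + 1;
    Vector res = weights.isDense() ? new DenseVector(stateVectorSize) : new SparseVector(stateVectorSize);

    res.set(0, intercept);

    // Weight i goes to index i + 1, so the intercept at index 0 is not overwritten.
    weights.nonZeroes().forEach(ith -> res.set(ith.index() + 1, ith.get()));

    return res;
}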
Also used: SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector), Vector (org.apache.ignite.ml.math.primitives.vector.Vector), LabeledVector (org.apache.ignite.ml.structures.LabeledVector), DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)
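The intercept-first layout can be illustrated on plain arrays. This is a minimal sketch, assuming the layout implied by res.set(0, intercept) and by makeVectorWithInterceptElement in Example 73 (intercept at index 0, weight i at index i + 1); StateVectorSketch, pack and unpackWeights are hypothetical helpers, not Ignite methods.

```java
import java.util.Arrays;

/** Hypothetical sketch of the [intercept, w_0, ..., w_{n-1}] state-vector layout on plain arrays. */
public class StateVectorSketch {
    /** Packs the intercept at index 0 and weight i at index i + 1. */
    static double[] pack(double intercept, double[] weights) {
        double[] state = new double[weights.length + 1];
        state[0] = intercept;
        for (int i = 0; i < weights.length; i++)
            state[i + 1] = weights[i];
        return state;
    }

    /** Inverse of pack: the weights are everything after index 0. */
    static double[] unpackWeights(double[] state) {
        return Arrays.copyOfRange(state, 1, state.length);
    }

    public static void main(String[] args) {
        double[] state = pack(0.5, new double[] {2.0, -1.0});
        System.out.println(Arrays.toString(state));                // [0.5, 2.0, -1.0]
        System.out.println(state[0]);                              // 0.5
        System.out.println(Arrays.toString(unpackWeights(state))); // [2.0, -1.0]
    }
}
```

Because pack and unpackWeights are exact inverses, converting weights to a state vector and back never touches the intercept slot.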

Example 73 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.

Source: class Deltas, method makeVectorWithInterceptElement.

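/**
 * Copies the features of the given row into a new vector prefixed with a constant intercept element.
 *
 * @param row Labeled row.
 * @return Vector with 1 at index 0 and feature j at index j + 1.
 */
private Vector makeVectorWithInterceptElement(LabeledVector row) {
    // like() creates an empty vector of the same class as the features, one element longer.
    Vector vec = row.features().like(row.features().size() + 1);

    // Set the intercept element.
    vec.set(0, 1);

    // Shift every feature one position to the right.
    for (int j = 0; j < row.features().size(); j++)
        vec.set(j + 1, row.features().get(j));

    return vec;
}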
Also used: Vector (org.apache.ignite.ml.math.primitives.vector.Vector), SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector), LabeledVector (org.apache.ignite.ml.structures.LabeledVector), DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)
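Prepending a constant 1 is the usual trick that folds the intercept into a single dot product: [b, w] · [1, x] = b + w · x. The sketch below shows this on plain arrays; InterceptSketch, withIntercept and dot are hypothetical helpers, and the [intercept, weights...] layout of the state vector is an assumption matching Example 72.

```java
import java.util.Arrays;

/** Hypothetical sketch: prepending a constant 1 lets a single dot product include the intercept. */
public class InterceptSketch {
    /** Returns [1, x_0, ..., x_{n-1}]. */
    static double[] withIntercept(double[] features) {
        double[] vec = new double[features.length + 1];
        vec[0] = 1.0; // intercept element
        System.arraycopy(features, 0, vec, 1, features.length);
        return vec;
    }

    /** Dot product of two vectors of equal length. */
    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++)
            sum += a[i] * b[i];
        return sum;
    }

    public static void main(String[] args) {
        double[] x = {3.0, 4.0};
        double[] state = {0.5, 2.0, -1.0}; // [intercept, w_0, w_1]

        double[] v = withIntercept(x);
        System.out.println(Arrays.toString(v)); // [1.0, 3.0, 4.0]

        // 0.5 * 1 + 2.0 * 3.0 + (-1.0) * 4.0 = 2.5
        System.out.println(dot(state, v));      // 2.5
    }
}
```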

Example 74 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.

Source: class Deltas, method calcDeltas.

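/**
 * Computes the change of the dual variable (alpha) of one observation and the matching change of the weights.
 *
 * @param lb Label of the observation.
 * @param v Feature vector of the observation, including the intercept element.
 * @param alpha Current value of the dual variable.
 * @param gradient Gradient value computed for this observation.
 * @param vectorSize Size of the weight vector.
 * @param amountOfObservation Number of observations.
 * @return Deltas of alpha and of the weights.
 */
private Deltas calcDeltas(double lb, Vector v, double alpha, double gradient, int vectorSize, int amountOfObservation) {
    if (gradient != 0.0) {
        // Squared norm of the observation vector, used to compute the new alpha.
        double qii = v.dot(v);
        double newAlpha = calcNewAlpha(alpha, gradient, qii);

        // Weight change proportional to the change of alpha, scaled by 1 / (lambda * amountOfObservation).
        Vector deltaWeights = v.times(lb * (newAlpha - alpha) / (this.getLambda() * amountOfObservation));

        return new Deltas(newAlpha - alpha, deltaWeights);
    }

    // Zero gradient: neither alpha nor the weights change.
    return new Deltas(0.0, initializeWeightsWithZeros(vectorSize));
}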
Also used: Vector (org.apache.ignite.ml.math.primitives.vector.Vector), SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector), LabeledVector (org.apache.ignite.ml.structures.LabeledVector), DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)
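The arithmetic of calcDeltas can be checked in isolation. This is a hypothetical sketch on plain arrays, not Ignite code: because calcNewAlpha is not part of the excerpt, the new alpha is passed in directly, lambda stands in for getLambda(), and the zero-gradient branch assumes vectorSize equals v.length.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the arithmetic in calcDeltas on plain arrays. The new alpha is passed in,
 * because calcNewAlpha is not shown in the excerpt.
 */
public class DeltasSketch {
    final double deltaAlpha;
    final double[] deltaWeights;

    DeltasSketch(double deltaAlpha, double[] deltaWeights) {
        this.deltaAlpha = deltaAlpha;
        this.deltaWeights = deltaWeights;
    }

    /** deltaWeights = v * lb * (newAlpha - alpha) / (lambda * n); all zeros when the gradient is 0. */
    static DeltasSketch compute(double lb, double[] v, double alpha, double newAlpha, double gradient,
        double lambda, int n) {
        if (gradient == 0.0)
            return new DeltasSketch(0.0, new double[v.length]); // new arrays are zero-filled

        double deltaAlpha = newAlpha - alpha;
        double scale = lb * deltaAlpha / (lambda * n);

        double[] deltaWeights = new double[v.length];
        for (int i = 0; i < v.length; i++)
            deltaWeights[i] = v[i] * scale;

        return new DeltasSketch(deltaAlpha, deltaWeights);
    }

    public static void main(String[] args) {
        DeltasSketch d = compute(1.0, new double[] {1.0, 2.0, 0.0}, 0.0, 0.5, -1.0, 0.5, 2);
        System.out.println(d.deltaAlpha);                    // 0.5
        System.out.println(Arrays.toString(d.deltaWeights)); // [0.5, 1.0, 0.0]
    }
}
```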

Example 75 with Vector

Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.

Source: class Deltas, method getDeltas.

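/**
 * Computes the deltas for a single observation selected by index.
 *
 * @param data Training data.
 * @param copiedWeights Copy of the current weights.
 * @param amountOfObservation Number of observations.
 * @param tmpAlphas Current values of the dual variables.
 * @param randomIdx Index of the randomly selected observation.
 * @return Deltas for the selected observation.
 */
private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
    int randomIdx) {
    LabeledVector row = (LabeledVector) data.getRow(randomIdx);
    Double lb = (Double) row.label();

    // Feature vector with the intercept element prepended.
    Vector v = makeVectorWithInterceptElement(row);

    double alpha = tmpAlphas.get(randomIdx);

    return maximize(lb, v, alpha, copiedWeights, amountOfObservation);
}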
Also used: LabeledVector (org.apache.ignite.ml.structures.LabeledVector), Vector (org.apache.ignite.ml.math.primitives.vector.Vector), SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector), DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)
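getDeltas only gathers the inputs for one coordinate step: the label, the intercept-augmented features, and the current alpha of the selected observation. The sketch below mirrors that on plain arrays and additionally assumes the index is drawn uniformly with java.util.Random; CoordinatePickSketch and pick are hypothetical names, and how Ignite actually chooses randomIdx is not shown in this excerpt.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of the inputs getDeltas collects for one randomly chosen observation. */
public class CoordinatePickSketch {
    final int idx;
    final double label;
    final double[] v;     // features with the intercept element prepended
    final double alpha;   // current dual variable of the observation

    CoordinatePickSketch(int idx, double label, double[] v, double alpha) {
        this.idx = idx;
        this.label = label;
        this.v = v;
        this.alpha = alpha;
    }

    /** Picks a uniformly random observation and returns its label, augmented features and alpha. */
    static CoordinatePickSketch pick(double[][] features, double[] labels, double[] alphas, Random rnd) {
        int idx = rnd.nextInt(features.length); // uniform in [0, features.length)

        double[] v = new double[features[idx].length + 1];
        v[0] = 1.0; // intercept element
        System.arraycopy(features[idx], 0, v, 1, features[idx].length);

        return new CoordinatePickSketch(idx, labels[idx], v, alphas[idx]);
    }

    public static void main(String[] args) {
        double[][] features = {{2.0, 3.0}, {-1.0, 0.5}};
        double[] labels = {1.0, -1.0};
        double[] alphas = {0.0, 0.25};

        // The same seed yields the same sequence of picks, which makes runs repeatable.
        CoordinatePickSketch a = pick(features, labels, alphas, new Random(42));
        CoordinatePickSketch b = pick(features, labels, alphas, new Random(42));
        System.out.println(a.idx == b.idx); // true

        System.out.println(a.idx + " " + a.label + " " + Arrays.toString(a.v) + " " + a.alpha);
    }
}
```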

Aggregations

Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 265
DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector): 95
Test (org.junit.Test): 94
Ignite (org.apache.ignite.Ignite): 78
LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 49
HashMap (java.util.HashMap): 39
SandboxMLCache (org.apache.ignite.examples.ml.util.SandboxMLCache): 38
DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer): 26
FileNotFoundException (java.io.FileNotFoundException): 22
TrainerTest (org.apache.ignite.ml.common.TrainerTest): 22
DecisionTreeClassificationTrainer (org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer): 21
DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel): 21
Serializable (java.io.Serializable): 19
IgniteCache (org.apache.ignite.IgniteCache): 18
EncoderTrainer (org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer): 16
Cache (javax.cache.Cache): 15
DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer): 15
EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance): 14
ArrayList (java.util.ArrayList): 12
ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition): 12