Search in sources :

Example 1 with DenseSparseMatrix

Use of org.tribuo.math.la.DenseSparseMatrix in the Tribuo project by Oracle.

From class CRFParameters, method merge.

/**
 * Merges a minibatch of per-example CRF gradients into a single gradient array.
 * <p>
 * Each {@code gradients[j]} is read as {bias gradient, feature-label gradient,
 * label-label gradient} (indices 0, 1 and 2). Bias and label-label gradients are
 * summed into freshly allocated dense tensors. Feature-label gradients are split by
 * representation: sparse ({@link DenseSparseMatrix}) ones are combined with
 * {@code merger}, and dense ones are summed in place into the first dense gradient
 * encountered. The two partial sums are then added together.
 *
 * @param gradients The per-example gradient arrays, indexed [example][parameter].
 * @param size The number of leading entries of {@code gradients} to merge.
 * @return {bias update, feature-label update, label-label update}. The feature-label
 *         update is {@code null} if no feature-label gradients were merged
 *         (e.g. {@code size == 0}).
 */
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    DenseVector biasUpdate = new DenseVector(biases.size());
    List<DenseSparseMatrix> updates = new ArrayList<>(size);
    DenseMatrix denseUpdates = null;
    DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    // Only the first `size` entries are valid, per the Parameters.merge contract and
    // consistent with FMParameters.merge. Iterating to gradients.length would also fold
    // in any stale entries left in a reused, partially filled minibatch buffer.
    for (int j = 0; j < size; j++) {
        biasUpdate.intersectAndAddInPlace(gradients[j][0]);
        Matrix tmpUpdate = (Matrix) gradients[j][1];
        if (tmpUpdate instanceof DenseSparseMatrix) {
            updates.add((DenseSparseMatrix) tmpUpdate);
        } else {
            // Dense gradient: accumulate in place into the first dense gradient seen.
            // Note that this mutates that caller-supplied tensor.
            if (denseUpdates == null) {
                denseUpdates = (DenseMatrix) tmpUpdate;
            } else {
                denseUpdates.intersectAndAddInPlace(tmpUpdate);
            }
        }
        labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
    }
    // Merge the combination of any dense and sparse feature-label updates.
    Matrix featureLabelUpdate;
    if (updates.size() > 0) {
        featureLabelUpdate = merger.merge(updates.toArray(new DenseSparseMatrix[0]));
        if (denseUpdates != null) {
            denseUpdates.intersectAndAddInPlace(featureLabelUpdate);
            featureLabelUpdate = denseUpdates;
        }
    } else {
        // Either only dense updates were seen, or none at all (null result).
        featureLabelUpdate = denseUpdates;
    }
    return new Tensor[] { biasUpdate, featureLabelUpdate, labelLabelUpdate };
}
Also used: DenseMatrix (org.tribuo.math.la.DenseMatrix), DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix), Matrix (org.tribuo.math.la.Matrix), Tensor (org.tribuo.math.la.Tensor), ArrayList (java.util.ArrayList), DenseVector (org.tribuo.math.la.DenseVector)

Example 2 with DenseSparseMatrix

Use of org.tribuo.math.la.DenseSparseMatrix in the Tribuo project by Oracle.

From class MergerTest, method testMerger.

/**
 * Exercises {@code merger} against precomputed expected results for several
 * combinations of {@link DenseSparseMatrix} inputs: A+B, B+B, the "zip" pair, and a
 * four-way A+B+A+B merge.
 *
 * @param merger The merger implementation under test.
 */
public void testMerger(Merger merger) {
    // The two-input cases share one array. The B+B case replaces only slot 0, so the
    // B instance from the first case is merged again in slot 1.
    DenseSparseMatrix[] pair = new DenseSparseMatrix[2];
    pair[0] = generateA();
    pair[1] = generateB();
    checkMerge(merger, pair, generateAB(), "Merge A - B unsuccessful");

    pair[0] = generateB();
    checkMerge(merger, pair, generateBB(), "Merge B - B unsuccessful");

    pair[0] = generateZipA();
    pair[1] = generateZipB();
    checkMerge(merger, pair, generateZip(), "Merge zip unsuccessful");

    DenseSparseMatrix[] quad = { generateA(), generateB(), generateA(), generateB() };
    checkMerge(merger, quad, generateAABB(), "Merge A - B - A - B unsuccessful");
}

/**
 * Merges {@code inputs} with {@code merger} and asserts the result equals {@code expected}.
 */
private void checkMerge(Merger merger, DenseSparseMatrix[] inputs, DenseSparseMatrix expected, String message) {
    DenseSparseMatrix actual = merger.merge(inputs);
    assertEquals(expected, actual, message);
}
Also used: DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix)

Example 3 with DenseSparseMatrix

Use of org.tribuo.math.la.DenseSparseMatrix in the Tribuo project by Oracle.

From class FMParameters, method gradients.

/**
 * Generate the gradients for a particular feature vector given
 * the loss and the per output gradients.
 * <p>
 * This method returns a {@link Tensor} array with numLabels + 2 elements.
 *
 * @param score    The Pair returned by the objective.
 * @param features The feature vector.
 * @return A {@link Tensor} array containing all the gradients.
 */
@Override
public Tensor[] gradients(Pair<Double, SGDVector> score, SGDVector features) {
    Tensor[] gradients = new Tensor[weights.length];
    SGDVector outputGradient = score.getB();
    // Bias gradient
    if (outputGradient instanceof SparseVector) {
        gradients[0] = ((SparseVector) outputGradient).densify();
    } else {
        gradients[0] = outputGradient.copy();
    }
    // Feature gradients
    gradients[1] = outputGradient.outer(features);
    // per label
    for (int i = 2; i < weights.length; i++) {
        double curOutputGradient = outputGradient.get(i - 2);
        DenseMatrix curFactors = (DenseMatrix) weights[i];
        if (curOutputGradient != 0.0) {
            // compute /sum_j v_{j,f}x_j
            SGDVector factorSum = curFactors.leftMultiply(features);
            // grad_f: dy/d0 * (x_i * factorSum_f - v_{i,f} * x_i * x_i)
            Matrix factorGradMatrix;
            if (features instanceof SparseVector) {
                List<SparseVector> vectors = new ArrayList<>(numFactors);
                for (int j = 0; j < numFactors; j++) {
                    vectors.add(((SparseVector) features).copy());
                }
                factorGradMatrix = new DenseSparseMatrix(vectors);
            } else {
                factorGradMatrix = new DenseMatrix(numFactors, features.size());
                for (int j = 0; j < numFactors; j++) {
                    for (int k = 0; k < features.size(); k++) {
                        factorGradMatrix.set(j, k, features.get(k));
                    }
                }
            }
            for (int j = 0; j < numFactors; j++) {
                // This gets a mutable view of the row
                SGDVector curFactorGrad = factorGradMatrix.getRow(j);
                double curFactorSum = factorSum.get(j);
                final int jFinal = j;
                // Compute the gradient for this element of the factor vector
                curFactorGrad.foreachIndexedInPlace((Integer idx, Double a) -> a * curFactorSum - curFactors.get(jFinal, idx) * a * a);
                // Multiply by the output gradient
                curFactorGrad.scaleInPlace(curOutputGradient);
            }
            gradients[i] = factorGradMatrix;
        } else {
            // If the output gradient is 0.0 then all the factor gradients are zero.
            // Technically with regularization we should shrink the weights for the specified features.
            gradients[i] = new DenseSparseMatrix(numFactors, features.size());
        }
    }
    return gradients;
}
Also used: Tensor (org.tribuo.math.la.Tensor), ArrayList (java.util.ArrayList), DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix), SparseVector (org.tribuo.math.la.SparseVector), DenseMatrix (org.tribuo.math.la.DenseMatrix), Matrix (org.tribuo.math.la.Matrix), SGDVector (org.tribuo.math.la.SGDVector)

Example 4 with DenseSparseMatrix

Use of org.tribuo.math.la.DenseSparseMatrix in the Tribuo project by Oracle.

From class FMParameters, method merge.

/**
 * Merges a minibatch of per-example gradients into a single gradient array.
 * <p>
 * For each parameter index {@code i}, the first {@code size} gradients are combined.
 * Dense vector gradients are summed in place into {@code gradients[0][i]}. Matrix
 * gradients may mix {@link DenseMatrix} and {@link DenseSparseMatrix} representations
 * for the same parameter; for example, an empty {@link DenseSparseMatrix} is produced
 * for a label whose output gradient is zero even when the features are dense. Both
 * kinds are therefore accumulated and then added together.
 *
 * @param gradients The per-example gradient arrays, indexed [example][parameter].
 * @param size The number of leading entries of {@code gradients} to merge.
 * @return The merged gradient, one tensor per parameter.
 * @throws IllegalStateException If a gradient has an unexpected tensor type.
 */
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    Tensor[] output = new Tensor[weights.length];
    for (int i = 0; i < weights.length; i++) {
        Tensor first = gradients[0][i];
        if (first instanceof DenseVector) {
            // Accumulate in place into the first example's gradient.
            for (int j = 1; j < size; j++) {
                first.intersectAndAddInPlace(gradients[j][i]);
            }
            output[i] = first;
        } else if (first instanceof DenseMatrix || first instanceof DenseSparseMatrix) {
            output[i] = mergeMatrixGradients(gradients, size, i);
        } else {
            throw new IllegalStateException("Unexpected gradient type, expected DenseVector, DenseMatrix or DenseSparseMatrix, received " + first.getClass().getName());
        }
    }
    return output;
}

/**
 * Sums the matrix-valued gradients for parameter {@code paramIdx} over the first
 * {@code size} examples, accepting any mix of {@link DenseMatrix} and
 * {@link DenseSparseMatrix} gradients.
 * <p>
 * Sparse gradients are combined with {@code merger}. Dense gradients are summed in
 * place into the first dense gradient encountered, which is mutated.
 *
 * @param gradients The per-example gradient arrays, indexed [example][parameter].
 * @param size The number of leading entries of {@code gradients} to merge.
 * @param paramIdx The parameter index to merge.
 * @return The merged gradient, or {@code null} if {@code size} is zero.
 * @throws IllegalStateException If a gradient is neither dense nor dense-sparse.
 */
private Tensor mergeMatrixGradients(Tensor[][] gradients, int size, int paramIdx) {
    List<DenseSparseMatrix> sparseUpdates = new ArrayList<>(size);
    DenseMatrix denseUpdate = null;
    for (int j = 0; j < size; j++) {
        Tensor cur = gradients[j][paramIdx];
        if (cur instanceof DenseSparseMatrix) {
            sparseUpdates.add((DenseSparseMatrix) cur);
        } else if (cur instanceof DenseMatrix) {
            if (denseUpdate == null) {
                denseUpdate = (DenseMatrix) cur;
            } else {
                denseUpdate.intersectAndAddInPlace(cur);
            }
        } else {
            throw new IllegalStateException("Unexpected gradient type for parameter " + paramIdx + ", expected DenseMatrix or DenseSparseMatrix, received " + cur.getClass().getName());
        }
    }
    if (sparseUpdates.isEmpty()) {
        return denseUpdate;
    }
    DenseSparseMatrix sparseSum = merger.merge(sparseUpdates.toArray(new DenseSparseMatrix[0]));
    if (denseUpdate == null) {
        return sparseSum;
    }
    denseUpdate.intersectAndAddInPlace(sparseSum);
    return denseUpdate;
}
Also used: Tensor (org.tribuo.math.la.Tensor), DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix), DenseVector (org.tribuo.math.la.DenseVector), DenseMatrix (org.tribuo.math.la.DenseMatrix)

Example 5 with DenseSparseMatrix

Use of org.tribuo.math.la.DenseSparseMatrix in the Tribuo project by Oracle.

From class HeapMerger, method merge.

/**
 * Merges several {@link DenseSparseMatrix} inputs row by row.
 * <p>
 * The shape is taken from {@code inputs[0]}. For each row, the non-empty rows of all
 * inputs are passed to the list-based {@code merge} overload. That overload shares a
 * pair of scratch buffers sized to the largest total active-element count of any row.
 *
 * @param inputs The matrices to merge; must be non-empty.
 * @return A new matrix whose rows are the merged input rows.
 */
@Override
public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) {
    final int numRows = inputs[0].getDimension1Size();
    final int numCols = inputs[0].getDimension2Size();
    // Size the scratch buffers for the worst-case row: the largest sum of active
    // elements for one row across all inputs.
    int bufferSize = 0;
    for (int row = 0; row < numRows; row++) {
        int rowTotal = 0;
        for (DenseSparseMatrix input : inputs) {
            rowTotal += input.numActiveElements(row);
        }
        bufferSize = Math.max(bufferSize, rowTotal);
    }
    int[] indexScratch = new int[bufferSize];
    double[] valueScratch = new double[bufferSize];
    SparseVector[] mergedRows = new SparseVector[numRows];
    // Reused across rows to avoid reallocating the list.
    List<SparseVector> nonEmpty = new ArrayList<>();
    for (int row = 0; row < numRows; row++) {
        nonEmpty.clear();
        for (DenseSparseMatrix input : inputs) {
            SparseVector candidate = input.getRow(row);
            if (candidate.numActiveElements() > 0) {
                nonEmpty.add(candidate);
            }
        }
        mergedRows[row] = merge(nonEmpty, numCols, indexScratch, valueScratch);
    }
    return DenseSparseMatrix.createFromSparseVectors(mergedRows);
}
Also used: ArrayList (java.util.ArrayList), DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix), SparseVector (org.tribuo.math.la.SparseVector)

Aggregations

DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix): 7; DenseMatrix (org.tribuo.math.la.DenseMatrix): 4; Tensor (org.tribuo.math.la.Tensor): 4; ArrayList (java.util.ArrayList): 3; DenseVector (org.tribuo.math.la.DenseVector): 3; Matrix (org.tribuo.math.la.Matrix): 3; SparseVector (org.tribuo.math.la.SparseVector): 3; Pair (com.oracle.labs.mlrg.olcut.util.Pair): 1; HashMap (java.util.HashMap): 1; Map (java.util.Map): 1; Feature (org.tribuo.Feature): 1; ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 1; Label (org.tribuo.classification.Label): 1; SGDVector (org.tribuo.math.la.SGDVector): 1; ModelProvenance (org.tribuo.provenance.ModelProvenance): 1; TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 1