Search in sources:

Example 1 with DenseMatrix

Use of org.tribuo.math.la.DenseMatrix in project Tribuo by Oracle.

Class CRFParameters, method getEmptyCopy.

/**
 * Allocates fresh tensors shaped like this CRF's parameters.
 * <p>
 * The returned array always has three entries, in this order:
 * <ol>
 *     <li>a {@link DenseVector} sized like the label biases,</li>
 *     <li>a {@link DenseMatrix} shaped like the feature-label weights,</li>
 *     <li>a {@link DenseMatrix} shaped like the label-label transition weights.</li>
 * </ol>
 * @return A {@link Tensor} array.
 */
@Override
public Tensor[] getEmptyCopy() {
    DenseVector emptyBiases = new DenseVector(biases.size());
    int featureRows = featureLabelWeights.getDimension1Size();
    int featureCols = featureLabelWeights.getDimension2Size();
    DenseMatrix emptyFeatureLabel = new DenseMatrix(featureRows, featureCols);
    int transitionRows = labelLabelWeights.getDimension1Size();
    int transitionCols = labelLabelWeights.getDimension2Size();
    DenseMatrix emptyLabelLabel = new DenseMatrix(transitionRows, transitionCols);
    return new Tensor[] { emptyBiases, emptyFeatureLabel, emptyLabelLabel };
}
Also used: Tensor (org.tribuo.math.la.Tensor), DenseVector (org.tribuo.math.la.DenseVector), DenseMatrix (org.tribuo.math.la.DenseMatrix)

Example 2 with DenseMatrix

Use of org.tribuo.math.la.DenseMatrix in project Tribuo by Oracle.

Class CRFParameters, method merge.

/**
 * Merges a batch of per-example gradient arrays into a single update.
 * <p>
 * Each gradient array is read as: element 0, a bias gradient; element 1, a feature-label
 * gradient (a {@link DenseSparseMatrix}, or otherwise treated as a {@link DenseMatrix}); and
 * element 2, a label-label gradient. Only the first {@code size} entries of {@code gradients}
 * are merged, so callers may pass an array with unused trailing slots.
 * <p>
 * Sparse feature-label gradients are combined with {@code merger}. If any feature-label
 * gradient is dense, the first dense one is used as the accumulator and is modified in place.
 * @param gradients The gradient arrays to merge.
 * @param size The number of leading entries of {@code gradients} to merge.
 * @return A 3 element array: the merged bias update, feature-label update and label-label
 *         update. The feature-label element is {@code null} if no gradients were merged.
 */
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    DenseVector biasUpdate = new DenseVector(biases.size());
    List<DenseSparseMatrix> sparseUpdates = new ArrayList<>(size);
    DenseMatrix denseUpdates = null;
    DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    // Honour the size argument so unused trailing slots in gradients are never read,
    // clamped to the array length so an oversized size cannot index past the end.
    int count = Math.min(size, gradients.length);
    for (int j = 0; j < count; j++) {
        Tensor[] gradient = gradients[j];
        biasUpdate.intersectAndAddInPlace(gradient[0]);
        Matrix featureLabelGradient = (Matrix) gradient[1];
        if (featureLabelGradient instanceof DenseSparseMatrix) {
            sparseUpdates.add((DenseSparseMatrix) featureLabelGradient);
        } else if (denseUpdates == null) {
            // The first dense gradient becomes the accumulator; it is mutated in place below.
            denseUpdates = (DenseMatrix) featureLabelGradient;
        } else {
            denseUpdates.intersectAndAddInPlace(featureLabelGradient);
        }
        labelLabelUpdate.intersectAndAddInPlace(gradient[2]);
    }
    // Merge the combination of any dense and sparse updates.
    Matrix featureLabelUpdate;
    if (!sparseUpdates.isEmpty()) {
        featureLabelUpdate = merger.merge(sparseUpdates.toArray(new DenseSparseMatrix[0]));
        if (denseUpdates != null) {
            denseUpdates.intersectAndAddInPlace(featureLabelUpdate);
            featureLabelUpdate = denseUpdates;
        }
    } else {
        featureLabelUpdate = denseUpdates;
    }
    return new Tensor[] { biasUpdate, featureLabelUpdate, labelLabelUpdate };
}
Also used: DenseMatrix (org.tribuo.math.la.DenseMatrix), DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix), Matrix (org.tribuo.math.la.Matrix), Tensor (org.tribuo.math.la.Tensor), ArrayList (java.util.ArrayList), DenseVector (org.tribuo.math.la.DenseVector)

Example 3 with DenseMatrix

Use of org.tribuo.math.la.DenseMatrix in project Tribuo by Oracle.

Class ChainHelper, method viterbi.

/**
 * Decodes the highest scoring label sequence of a linear chain CRF with the Viterbi algorithm.
 * <p>
 * For every token and candidate label it keeps the best path score ending in that label,
 * together with the previous label that achieved it. It then walks those back pointers
 * from the best label of the final token.
 * @param scores The clique values: the label-label transition matrix and the per-token label scores.
 * @return The score of the best path, the best label index per token, and the input scores.
 */
public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
    final DenseMatrix transitions = scores.transitionValues;
    final DenseVector[] emissions = scores.localValues;
    final int numLabels = transitions.getDimension1Size();
    final int seqLength = scores.localValues.length;

    // pathScores[t][k] holds the best score of any path ending in label k at token t;
    // predecessors[t][k] holds the label at token t-1 on that path.
    DenseVector[] pathScores = new DenseVector[seqLength];
    int[][] predecessors = new int[seqLength][];
    for (int t = 0; t < seqLength; t++) {
        pathScores[t] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
        int[] row = new int[numLabels];
        Arrays.fill(row, -1);
        predecessors[t] = row;
    }
    pathScores[0].setElements(emissions[0]);

    for (int t = 1; t < seqLength; t++) {
        DenseVector previous = pathScores[t - 1];
        DenseVector current = pathScores[t];
        DenseVector emission = emissions[t];
        int[] currentPredecessors = predecessors[t];
        for (int to = 0; to < numLabels; to++) {
            double emissionScore = emission.get(to);
            double best = Double.NEGATIVE_INFINITY;
            int bestFrom = -1;
            for (int from = 0; from < numLabels; from++) {
                double candidate = transitions.get(from, to) + previous.get(from) + emissionScore;
                if (candidate > best) {
                    best = candidate;
                    bestFrom = from;
                }
            }
            current.set(to, best);
            // No candidate beat -Infinity (e.g. all -Infinity or NaN), so fall back to label 0.
            currentPredecessors[to] = Math.max(bestFrom, 0);
        }
    }

    // Trace back from the best final label to recover the full label sequence.
    int[] bestLabels = new int[seqLength];
    DenseVector finalScores = pathScores[seqLength - 1];
    int label = finalScores.indexOfMax();
    bestLabels[seqLength - 1] = label;
    for (int t = seqLength - 1; t > 0; t--) {
        label = predecessors[t][label];
        bestLabels[t - 1] = label;
    }
    return new ChainViterbiResults(finalScores.maxValue(), bestLabels, scores);
}
Also used: DenseVector (org.tribuo.math.la.DenseVector), DenseMatrix (org.tribuo.math.la.DenseMatrix)

Example 4 with DenseMatrix

Use of org.tribuo.math.la.DenseMatrix in project Tribuo by Oracle.

Class LinearParameters, method getEmptyCopy.

/**
 * Allocates a fresh {@link DenseMatrix} with the same dimensions as the weight matrix.
 * @return A single-element {@link Tensor} array holding that {@link DenseMatrix}.
 */
@Override
public Tensor[] getEmptyCopy() {
    int rows = weightMatrix.getDimension1Size();
    int cols = weightMatrix.getDimension2Size();
    return new Tensor[] { new DenseMatrix(rows, cols) };
}
Also used: Tensor (org.tribuo.math.la.Tensor), DenseMatrix (org.tribuo.math.la.DenseMatrix)

Example 5 with DenseMatrix

Use of org.tribuo.math.la.DenseMatrix in project Tribuo by Oracle.

Class MultiLabelConfusionMatrixTest, method testTabulateSingleLabel.

/**
 * Checks the per-label 2x2 confusion matrices produced for single-label predictions.
 * <p>
 * Every cell of every matrix is asserted, so a count landing in the wrong cell cannot pass
 * unnoticed. The four (a,a), (c,b), (b,b), (b,c) pairs form a symmetric multiset, so the
 * expected counts hold regardless of which {@code mkPrediction} argument is the truth.
 */
@Test
public void testTabulateSingleLabel() {
    MultiLabel a = label("a");
    MultiLabel b = label("b");
    MultiLabel c = label("c");
    List<Prediction<MultiLabel>> predictions = Arrays.asList(mkPrediction(a, a), mkPrediction(c, b), mkPrediction(b, b), mkPrediction(b, c));
    ImmutableOutputInfo<MultiLabel> domain = mkDomain(predictions);
    DenseMatrix[] mcm = MultiLabelConfusionMatrix.tabulate(domain, predictions).getMCM();
    int aIndex = domain.getID(a);
    int bIndex = domain.getID(b);
    int cIndex = domain.getID(c);
    assertEquals(domain.size(), mcm.length);
    // "a" appears on both sides in one pair and on neither side in three; it is never confused.
    assertEquals(3d, mcm[aIndex].get(0, 0));
    assertEquals(0d, mcm[aIndex].get(0, 1));
    assertEquals(0d, mcm[aIndex].get(1, 0));
    assertEquals(1d, mcm[aIndex].get(1, 1));
    // "b" has exactly one pair in each of the four cells.
    assertEquals(1d, mcm[bIndex].get(0, 0));
    assertEquals(1d, mcm[bIndex].get(0, 1));
    assertEquals(1d, mcm[bIndex].get(1, 0));
    assertEquals(1d, mcm[bIndex].get(1, 1));
    // "c" appears on only one side in two pairs and never on both sides.
    assertEquals(2d, mcm[cIndex].get(0, 0));
    assertEquals(1d, mcm[cIndex].get(0, 1));
    assertEquals(1d, mcm[cIndex].get(1, 0));
    assertEquals(0d, mcm[cIndex].get(1, 1));
}
Also used: MultiLabel (org.tribuo.multilabel.MultiLabel), Prediction (org.tribuo.Prediction), Utils.mkPrediction (org.tribuo.multilabel.Utils.mkPrediction), DenseMatrix (org.tribuo.math.la.DenseMatrix), Test (org.junit.jupiter.api.Test)

Aggregations

DenseMatrix (org.tribuo.math.la.DenseMatrix): 21, DenseVector (org.tribuo.math.la.DenseVector): 12, Tensor (org.tribuo.math.la.Tensor): 8, ArrayList (java.util.ArrayList): 5, Pair (com.oracle.labs.mlrg.olcut.util.Pair): 4, HashMap (java.util.HashMap): 4, List (java.util.List): 4, DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix): 4, MultiLabel (org.tribuo.multilabel.MultiLabel): 4, Label (org.tribuo.classification.Label): 3, Matrix (org.tribuo.math.la.Matrix): 3, PriorityQueue (java.util.PriorityQueue): 2, Test (org.junit.jupiter.api.Test): 2, Prediction (org.tribuo.Prediction): 2, SGDVector (org.tribuo.math.la.SGDVector): 2, SparseVector (org.tribuo.math.la.SparseVector): 2, Utils.mkPrediction (org.tribuo.multilabel.Utils.mkPrediction): 2, HashSet (java.util.HashSet): 1, SplittableRandom (java.util.SplittableRandom): 1, Feature (org.tribuo.Feature): 1