Use of org.tribuo.math.la.DenseMatrix in project tribuo by Oracle.
In class CRFParameters, method getEmptyCopy.
/**
 * Returns a 3 element {@link Tensor} array shaped like these parameters.
 *
 * The first element is a {@link DenseVector} with the same size as the label biases.
 * The second element is a {@link DenseMatrix} with the dimensions of the feature-label weights.
 * The third element is a {@link DenseMatrix} with the dimensions of the label-label transition weights.
 * Every element is a newly constructed tensor (presumably zero-filled by the size constructors),
 * so none of them share storage with the parameters themselves.
 * @return A {@link Tensor} array.
 */
@Override
public Tensor[] getEmptyCopy() {
    DenseVector emptyBiases = new DenseVector(biases.size());
    DenseMatrix emptyFeatureLabel = new DenseMatrix(featureLabelWeights.getDimension1Size(), featureLabelWeights.getDimension2Size());
    DenseMatrix emptyLabelLabel = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    return new Tensor[] { emptyBiases, emptyFeatureLabel, emptyLabelLabel };
}
Use of org.tribuo.math.la.DenseMatrix in project tribuo by Oracle.
In class CRFParameters, method merge.
/**
 * Merges a batch of CRF gradients into a single update.
 * <p>
 * Only the first {@code size} entries of {@code gradients} are read. Each entry is a 3 element
 * array holding the bias, feature-label and label-label gradients. Bias and label-label gradients
 * are summed into freshly allocated tensors. Sparse feature-label gradients
 * ({@link DenseSparseMatrix}) are combined using {@code merger}, and dense ones are summed. If
 * both kinds are present, the merged sparse update is added into the dense sum.
 * <p>
 * Note: when any feature-label gradient is dense, the first dense gradient is used as the
 * accumulator and is modified in place.
 * @param gradients The per-example gradients.
 * @param size The number of valid entries at the front of {@code gradients}.
 * @return A 3 element {@link Tensor} array of bias, feature-label and label-label updates. The
 *         feature-label element is {@code null} when no feature-label gradients were merged.
 */
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    // Honour size: the array may be longer than the batch (presumably reused across
    // minibatches), and any trailing slots must not be merged.
    int count = Math.min(size, gradients.length);
    DenseVector biasUpdate = new DenseVector(biases.size());
    List<DenseSparseMatrix> sparseUpdates = new ArrayList<>(count);
    DenseMatrix denseUpdates = null;
    DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    for (int j = 0; j < count; j++) {
        Tensor[] gradient = gradients[j];
        biasUpdate.intersectAndAddInPlace(gradient[0]);
        Matrix featureLabelGradient = (Matrix) gradient[1];
        if (featureLabelGradient instanceof DenseSparseMatrix) {
            sparseUpdates.add((DenseSparseMatrix) featureLabelGradient);
        } else if (denseUpdates == null) {
            // The first dense gradient becomes the accumulator; later additions mutate it in place.
            denseUpdates = (DenseMatrix) featureLabelGradient;
        } else {
            denseUpdates.intersectAndAddInPlace(featureLabelGradient);
        }
        labelLabelUpdate.intersectAndAddInPlace(gradient[2]);
    }
    // Merge the combination of any dense and sparse feature-label updates.
    Matrix featureLabelUpdate;
    if (sparseUpdates.isEmpty()) {
        featureLabelUpdate = denseUpdates;
    } else {
        featureLabelUpdate = merger.merge(sparseUpdates.toArray(new DenseSparseMatrix[0]));
        if (denseUpdates != null) {
            denseUpdates.intersectAndAddInPlace(featureLabelUpdate);
            featureLabelUpdate = denseUpdates;
        }
    }
    return new Tensor[] { biasUpdate, featureLabelUpdate, labelLabelUpdate };
}
Use of org.tribuo.math.la.DenseMatrix in project tribuo by Oracle.
In class ChainHelper, method viterbi.
/**
 * Decodes the highest scoring label sequence of a linear chain CRF with the Viterbi algorithm.
 * <p>
 * For every token after the first, and for every label, the best score of any path ending in that
 * label is computed as the maximum over predecessor labels of
 * {@code transition(prev, label) + bestScore(prev) + tokenScore(label)}. The winning predecessor
 * is recorded as a back pointer. Ties keep the lowest predecessor index. If no candidate
 * beats negative infinity (e.g. all are -Infinity or NaN), the back pointer defaults to label 0.
 * The path is then recovered by following back pointers from the best label of the last token.
 * The number of labels is taken from the transition matrix's first dimension.
 * @param scores Tuple containing the label-label transition matrix, and the per token label scores.
 * @return Tuple containing the score of the maximum path and the maximum predicted label per token.
 */
public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
    final DenseMatrix transitions = scores.transitionValues;
    final DenseVector[] tokenScores = scores.localValues;
    final int numTokens = tokenScores.length;
    final int numLabels = transitions.getDimension1Size();

    // best[t] holds, per label, the score of the best path ending in that label at token t.
    final DenseVector[] best = new DenseVector[numTokens];
    // from[t][label] is the label at token t - 1 on the best path ending in label at token t.
    final int[][] from = new int[numTokens][numLabels];
    for (int t = 0; t < numTokens; t++) {
        best[t] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
        Arrays.fill(from[t], -1);
    }
    best[0].setElements(tokenScores[0]);

    for (int t = 1; t < numTokens; t++) {
        final DenseVector previous = best[t - 1];
        final DenseVector current = best[t];
        final DenseVector emissions = tokenScores[t];
        final int[] pointers = from[t];
        for (int label = 0; label < numLabels; label++) {
            final double emission = emissions.get(label);
            double winningScore = Double.NEGATIVE_INFINITY;
            int winningPrev = -1;
            for (int prev = 0; prev < numLabels; prev++) {
                final double candidate = transitions.get(prev, label) + previous.get(prev) + emission;
                if (candidate > winningScore) {
                    winningScore = candidate;
                    winningPrev = prev;
                }
            }
            current.set(label, winningScore);
            pointers[label] = (winningPrev < 0) ? 0 : winningPrev;
        }
    }

    // Walk the back pointers from the best final label to recover the full path.
    final DenseVector finalScores = best[numTokens - 1];
    final int[] path = new int[numTokens];
    int cursor = finalScores.indexOfMax();
    path[numTokens - 1] = cursor;
    for (int t = numTokens - 1; t > 0; t--) {
        cursor = from[t][cursor];
        path[t - 1] = cursor;
    }
    return new ChainViterbiResults(finalScores.maxValue(), path, scores);
}
Use of org.tribuo.math.la.DenseMatrix in project tribuo by Oracle.
In class LinearParameters, method getEmptyCopy.
/**
 * Returns a newly constructed {@link DenseMatrix} with the same dimensions as the weight matrix,
 * wrapped in a single element array.
 * <p>
 * The matrix shares no storage with the parameters (it is presumably zero-filled by the size
 * constructor).
 * @return A {@link Tensor} array containing a single {@link DenseMatrix}.
 */
@Override
public Tensor[] getEmptyCopy() {
    int rows = weightMatrix.getDimension1Size();
    int columns = weightMatrix.getDimension2Size();
    return new Tensor[] { new DenseMatrix(rows, columns) };
}
Use of org.tribuo.math.la.DenseMatrix in project tribuo by Oracle.
In class MultiLabelConfusionMatrixTest, method testTabulateSingleLabel.
/**
 * Checks the per-label 2x2 confusion matrices produced by {@code MultiLabelConfusionMatrix.tabulate}
 * when every prediction and ground truth holds a single label.
 * <p>
 * Each label's matrix is expected to count all four predictions, so its four cells sum to 4.
 * Every cell is asserted, including the zero cells, so that a miscount in any cell is caught.
 */
@Test
public void testTabulateSingleLabel() {
    MultiLabel a = label("a");
    MultiLabel b = label("b");
    MultiLabel c = label("c");
    List<Prediction<MultiLabel>> predictions = Arrays.asList(mkPrediction(a, a), mkPrediction(c, b), mkPrediction(b, b), mkPrediction(b, c));
    ImmutableOutputInfo<MultiLabel> domain = mkDomain(predictions);
    DenseMatrix[] mcm = MultiLabelConfusionMatrix.tabulate(domain, predictions).getMCM();
    int aIndex = domain.getID(a);
    int bIndex = domain.getID(b);
    int cIndex = domain.getID(c);
    assertEquals(domain.size(), mcm.length);

    // Label a only appears in the (a, a) pair: 3 cells at (0,0), 1 at (1,1), nothing off-diagonal.
    assertEquals(3d, mcm[aIndex].get(0, 0));
    assertEquals(0d, mcm[aIndex].get(0, 1));
    assertEquals(0d, mcm[aIndex].get(1, 0));
    assertEquals(1d, mcm[aIndex].get(1, 1));

    // Label b appears in one pair of each kind.
    assertEquals(1d, mcm[bIndex].get(0, 0));
    assertEquals(1d, mcm[bIndex].get(0, 1));
    assertEquals(1d, mcm[bIndex].get(1, 0));
    assertEquals(1d, mcm[bIndex].get(1, 1));

    // Label c is never matched with itself, so its (1,1) cell must stay 0.
    assertEquals(2d, mcm[cIndex].get(0, 0));
    assertEquals(1d, mcm[cIndex].get(0, 1));
    assertEquals(1d, mcm[cIndex].get(1, 0));
    assertEquals(0d, mcm[cIndex].get(1, 1));
}
Aggregations