use of org.tribuo.math.la.DenseSparseMatrix in project tribuo by oracle.
the class CRFParameters method merge.
/**
 * Merges a minibatch of per-example gradients into a single update.
 * <p>
 * Each gradient array is read as three tensors: index 0 is added into the bias update,
 * index 1 is the feature-label gradient (either a {@link DenseSparseMatrix} or a
 * {@link DenseMatrix}), and index 2 is added into the label-label update.
 * <p>
 * Only the first {@code size} entries of {@code gradients} are read, so the array may
 * carry unused trailing slots.
 * <p>
 * The first dense feature-label gradient encountered is used as the accumulator and is
 * therefore mutated in place.
 *
 * @param gradients The per-example gradient arrays.
 * @param size The number of leading entries of {@code gradients} to merge.
 * @return A three element array holding the bias update, the feature-label update and
 * the label-label update. The feature-label element is {@code null} when no gradients were merged.
 */
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    DenseVector biasUpdate = new DenseVector(biases.size());
    List<DenseSparseMatrix> updates = new ArrayList<>(size);
    DenseMatrix denseUpdates = null;
    DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    // Bound the loop by size rather than gradients.length so unused trailing slots are never read.
    for (int j = 0; j < size; j++) {
        biasUpdate.intersectAndAddInPlace(gradients[j][0]);
        Matrix tmpUpdate = (Matrix) gradients[j][1];
        if (tmpUpdate instanceof DenseSparseMatrix) {
            // Sparse gradients are collected and merged in one pass below.
            updates.add((DenseSparseMatrix) tmpUpdate);
        } else if (denseUpdates == null) {
            // Anything which is not a DenseSparseMatrix is treated as dense;
            // the first one becomes the running dense sum.
            denseUpdates = (DenseMatrix) tmpUpdate;
        } else {
            denseUpdates.intersectAndAddInPlace(tmpUpdate);
        }
        labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
    }
    // Merge the combination of any dense and sparse updates
    Matrix featureLabelUpdate;
    if (!updates.isEmpty()) {
        featureLabelUpdate = merger.merge(updates.toArray(new DenseSparseMatrix[0]));
        if (denseUpdates != null) {
            // Fold the merged sparse update into the dense accumulator.
            denseUpdates.intersectAndAddInPlace(featureLabelUpdate);
            featureLabelUpdate = denseUpdates;
        }
    } else {
        featureLabelUpdate = denseUpdates;
    }
    return new Tensor[] { biasUpdate, featureLabelUpdate, labelLabelUpdate };
}
use of org.tribuo.math.la.DenseSparseMatrix in project tribuo by oracle.
the class MergerTest method testMerger.
/**
 * Runs the supplied merger over the fixture matrices and checks each result
 * against the matching expected fixture.
 *
 * @param merger The merger implementation under test.
 */
public void testMerger(Merger merger) {
    // Two-way merges share one array; slots are overwritten between cases.
    DenseSparseMatrix[] pair = { generateA(), generateB() };
    assertEquals(generateAB(), merger.merge(pair), "Merge A - B unsuccessful");

    // Slot 1 still holds the B instance from the previous case.
    pair[0] = generateB();
    assertEquals(generateBB(), merger.merge(pair), "Merge B - B unsuccessful");

    pair[0] = generateZipA();
    pair[1] = generateZipB();
    assertEquals(generateZip(), merger.merge(pair), "Merge zip unsuccessful");

    // Four-way merge of alternating A and B.
    DenseSparseMatrix[] quad = { generateA(), generateB(), generateA(), generateB() };
    assertEquals(generateAABB(), merger.merge(quad), "Merge A - B - A - B unsuccessful");
}
use of org.tribuo.math.la.DenseSparseMatrix in project tribuo by oracle.
the class FMParameters method gradients.
/**
 * Builds the gradient tensors for one feature vector from the objective's output.
 * <p>
 * The returned array has the same length as {@code weights}: element 0 is a dense copy of
 * the output gradient (the bias gradient), element 1 is the outer product of the output
 * gradient with the features, and each later element is the factor gradient matrix for
 * the output at position {@code index - 2}.
 *
 * @param score The Pair returned by the objective; its second element is the per output gradient.
 * @param features The feature vector.
 * @return A {@link Tensor} array containing all the gradients.
 */
@Override
public Tensor[] gradients(Pair<Double, SGDVector> score, SGDVector features) {
    final SGDVector lossGradient = score.getB();
    final Tensor[] result = new Tensor[weights.length];

    // Bias gradient, always stored densely
    result[0] = (lossGradient instanceof SparseVector)
            ? ((SparseVector) lossGradient).densify()
            : lossGradient.copy();

    // Feature gradients
    result[1] = lossGradient.outer(features);

    // One factor gradient matrix per output
    for (int slot = 2; slot < weights.length; slot++) {
        final double labelGradient = lossGradient.get(slot - 2);
        final DenseMatrix factors = (DenseMatrix) weights[slot];

        if (labelGradient == 0.0) {
            // A zero output gradient makes every factor gradient zero.
            // Technically with regularization we should shrink the weights for the specified features.
            result[slot] = new DenseSparseMatrix(numFactors, features.size());
            continue;
        }

        // \sum_j v_{j,f} x_j
        final SGDVector factorSum = factors.leftMultiply(features);

        // Start with every row equal to the feature vector, then transform in place into
        // grad_f: dy/d0 * (x_i * factorSum_f - v_{i,f} * x_i * x_i)
        final Matrix factorGrads;
        if (features instanceof SparseVector) {
            final SparseVector sparseFeatures = (SparseVector) features;
            final List<SparseVector> rows = new ArrayList<>(numFactors);
            for (int f = 0; f < numFactors; f++) {
                rows.add(sparseFeatures.copy());
            }
            factorGrads = new DenseSparseMatrix(rows);
        } else {
            final int numFeatures = features.size();
            factorGrads = new DenseMatrix(numFactors, numFeatures);
            for (int f = 0; f < numFactors; f++) {
                for (int col = 0; col < numFeatures; col++) {
                    factorGrads.set(f, col, features.get(col));
                }
            }
        }

        for (int f = 0; f < numFactors; f++) {
            // This gets a mutable view of the row
            final SGDVector rowGrad = factorGrads.getRow(f);
            final double summedFactor = factorSum.get(f);
            final int factorIdx = f;
            // Gradient for each element of this factor vector
            rowGrad.foreachIndexedInPlace((Integer featureIdx, Double x) -> x * summedFactor - factors.get(factorIdx, featureIdx) * x * x);
            // Chain through the output gradient
            rowGrad.scaleInPlace(labelGradient);
        }
        result[slot] = factorGrads;
    }
    return result;
}
use of org.tribuo.math.la.DenseSparseMatrix in project tribuo by oracle.
the class FMParameters method merge.
/**
 * Merges a minibatch of per-example gradients into a single update, one output
 * tensor per element of {@code weights}.
 * <p>
 * The type of the first example's gradient at each position selects the strategy:
 * dense vectors and dense matrices are summed in place into the first example's tensor
 * (which is therefore mutated and returned), while {@link DenseSparseMatrix} gradients
 * are collected and handed to the merger.
 * <p>
 * Only the first {@code size} entries of {@code gradients} are read.
 *
 * @param gradients The per-example gradient arrays.
 * @param size The number of leading entries of {@code gradients} to merge.
 * @return The merged gradient for each parameter tensor.
 * @throws IllegalStateException If the first example's gradient at some position is not a
 * DenseVector, DenseMatrix or DenseSparseMatrix.
 */
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    Tensor[] output = new Tensor[weights.length];
    for (int i = 0; i < weights.length; i++) {
        Tensor first = gradients[0][i];
        if (first instanceof DenseVector || first instanceof DenseMatrix) {
            // Dense tensors accumulate directly into the first example's gradient.
            for (int j = 1; j < size; j++) {
                first.intersectAndAddInPlace(gradients[j][i]);
            }
            output[i] = first;
        } else if (first instanceof DenseSparseMatrix) {
            DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
            for (int j = 0; j < size; j++) {
                // Read position i of each example; position 0 holds the bias gradient
                // and must not be used here.
                updates[j] = (DenseSparseMatrix) gradients[j][i];
            }
            output[i] = merger.merge(updates);
        } else {
            throw new IllegalStateException("Unexpected gradient type, expected DenseVector, DenseMatrix or DenseSparseMatrix, received " + first.getClass().getName());
        }
    }
    return output;
}
use of org.tribuo.math.la.DenseSparseMatrix in project tribuo by oracle.
the class HeapMerger method merge.
/**
 * Merges an array of {@link DenseSparseMatrix} into a single matrix, one row at a time.
 * <p>
 * The output dimensions are taken from the first input. For each row, the non-empty
 * rows of every input are gathered and passed to the list based merge overload, along
 * with two scratch buffers sized to the largest per-row total of active elements.
 *
 * @param inputs The matrices to merge; the first element supplies the dimensions.
 * @return The merged matrix.
 */
@Override
public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) {
    final DenseSparseMatrix first = inputs[0];
    final int rowCount = first.getDimension1Size();
    final int columnCount = first.getDimension2Size();

    // Total active elements per row across all inputs.
    final int[] activePerRow = new int[rowCount];
    for (DenseSparseMatrix input : inputs) {
        for (int row = 0; row < rowCount; row++) {
            activePerRow[row] += input.numActiveElements(row);
        }
    }

    // The scratch buffers must fit the busiest row.
    int bufferSize = 0;
    for (int count : activePerRow) {
        bufferSize = Math.max(bufferSize, count);
    }
    final int[] indicesBuffer = new int[bufferSize];
    final double[] valuesBuffer = new double[bufferSize];

    final SparseVector[] mergedRows = new SparseVector[rowCount];
    final List<SparseVector> nonEmptyRows = new ArrayList<>();
    for (int row = 0; row < rowCount; row++) {
        nonEmptyRows.clear();
        for (DenseSparseMatrix input : inputs) {
            SparseVector candidate = input.getRow(row);
            // Rows with no active elements contribute nothing, so leave them out.
            if (candidate.numActiveElements() > 0) {
                nonEmptyRows.add(candidate);
            }
        }
        mergedRows[row] = merge(nonEmptyRows, columnCount, indicesBuffer, valuesBuffer);
    }
    return DenseSparseMatrix.createFromSparseVectors(mergedRows);
}
Aggregations