Use of org.tribuo.math.la.Tensor in project tribuo by oracle.
In class CRFParameters, the method getEmptyCopy:
/**
 * Returns a 3-element {@link Tensor} array.
 *
 * The first element is a {@link DenseVector} of label biases.
 * The second element is a {@link DenseMatrix} of feature-label weights.
 * The third element is a {@link DenseMatrix} of label-label transition weights.
 * @return A {@link Tensor} array.
 */
@Override
public Tensor[] getEmptyCopy() {
    Tensor[] output = new Tensor[3];
    output[0] = new DenseVector(biases.size());
    output[1] = new DenseMatrix(featureLabelWeights.getDimension1Size(), featureLabelWeights.getDimension2Size());
    output[2] = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    return output;
}
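Since getEmptyCopy returns zeroed tensors with the same shapes as the CRF parameters, it is a convenient way to allocate gradient accumulators. Below is a minimal sketch of that pattern; the GradientAccumulator class and gradientStream argument are illustrative names, not Tribuo API.

import org.tribuo.math.Parameters;
import org.tribuo.math.la.Tensor;

// Illustrative helper (not part of Tribuo): sums a stream of gradient arrays
// into zeroed accumulators allocated by getEmptyCopy.
final class GradientAccumulator {
    static Tensor[] accumulate(Parameters params, Iterable<Tensor[]> gradientStream) {
        Tensor[] accumulator = params.getEmptyCopy();
        for (Tensor[] gradients : gradientStream) {
            for (int i = 0; i < accumulator.length; i++) {
                // Adds the overlapping elements of gradients[i] into accumulator[i].
                accumulator[i].intersectAndAddInPlace(gradients[i]);
            }
        }
        return accumulator;
    }
}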
Use of org.tribuo.math.la.Tensor in project tribuo by oracle.
In class CRFParameters, the method merge:
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    DenseVector biasUpdate = new DenseVector(biases.size());
    List<DenseSparseMatrix> updates = new ArrayList<>(size);
    DenseMatrix denseUpdates = null;
    DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    for (int j = 0; j < gradients.length; j++) {
        biasUpdate.intersectAndAddInPlace(gradients[j][0]);
        Matrix tmpUpdate = (Matrix) gradients[j][1];
        if (tmpUpdate instanceof DenseSparseMatrix) {
            updates.add((DenseSparseMatrix) tmpUpdate);
        } else {
            // is dense
            if (denseUpdates == null) {
                denseUpdates = (DenseMatrix) tmpUpdate;
            } else {
                denseUpdates.intersectAndAddInPlace(tmpUpdate);
            }
        }
        labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
    }
    // Merge the combination of any dense and sparse updates
    Matrix featureLabelUpdate;
    if (updates.size() > 0) {
        featureLabelUpdate = merger.merge(updates.toArray(new DenseSparseMatrix[0]));
        if (denseUpdates != null) {
            denseUpdates.intersectAndAddInPlace(featureLabelUpdate);
            featureLabelUpdate = denseUpdates;
        }
    } else {
        featureLabelUpdate = denseUpdates;
    }
    return new Tensor[] { biasUpdate, featureLabelUpdate, labelLabelUpdate };
}
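merge collapses a minibatch of per-example gradient triples into a single {bias, feature-label, label-label} update, combining sparse feature-label gradients through the configured merger and folding in any dense ones. A small sketch of the calling pattern follows; the MinibatchUtil helper is illustrative, and it assumes every row of the array holds a gradient triple produced by valueAndGradient.

import org.tribuo.classification.sgd.crf.CRFParameters;
import org.tribuo.math.la.Tensor;

// Illustrative helper: merges a full minibatch of per-example gradient
// triples and rescales the sum into a mean.
final class MinibatchUtil {
    static Tensor[] averageMinibatch(CRFParameters params, Tensor[][] perExampleGradients) {
        Tensor[] merged = params.merge(perExampleGradients, perExampleGradients.length);
        for (Tensor t : merged) {
            // Divide the summed gradients by the batch size.
            t.scaleInPlace(1.0 / perExampleGradients.length);
        }
        return merged;
    }
}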
Use of org.tribuo.math.la.Tensor in project tribuo by oracle.
In class CRFTrainer, the method train:
@Override
public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
    if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Creates a new RNG, adds one to the invocation count, generates a local optimiser.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    StochasticGradientOptimiser localOptimiser;
    synchronized (this) {
        localRNG = rng.split();
        localOptimiser = optimiser.copy();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
    ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();
    SGDVector[][] sgdFeatures = new SGDVector[sequenceExamples.size()][];
    int[][] sgdLabels = new int[sequenceExamples.size()][];
    double[] weights = new double[sequenceExamples.size()];
    int n = 0;
    for (SequenceExample<Label> example : sequenceExamples) {
        weights[n] = example.getWeight();
        Pair<int[], SGDVector[]> pair = CRFModel.convertToVector(example, featureIDMap, labelIDMap);
        sgdFeatures[n] = pair.getB();
        sgdLabels[n] = pair.getA();
        n++;
    }
    logger.info(String.format("Training SGD CRF with %d examples", n));
    CRFParameters crfParameters = new CRFParameters(featureIDMap.size(), labelIDMap.size());
    localOptimiser.initialise(crfParameters);
    double loss = 0.0;
    int iteration = 0;
    for (int i = 0; i < epochs; i++) {
        if (shuffle) {
            Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
        }
        if (minibatchSize == 1) {
            /*
             * Special case a minibatch of size 1. Directly updates the parameters after each
             * example rather than aggregating.
             */
            for (int j = 0; j < sgdFeatures.length; j++) {
                Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j], sgdLabels[j]);
                loss += output.getA() * weights[j];
                // Update the gradient with the current learning rates
                Tensor[] updates = localOptimiser.step(output.getB(), weights[j]);
                // Apply the update to the current parameters.
                crfParameters.update(updates);
                iteration++;
                if ((iteration % loggingInterval == 0) && (loggingInterval != -1)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        } else {
            Tensor[][] gradients = new Tensor[minibatchSize][];
            for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                double tempWeight = 0.0;
                int curSize = 0;
                // Aggregate the gradient updates for each example in the minibatch
                for (int k = j; k < j + minibatchSize && k < sgdFeatures.length; k++) {
                    Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k], sgdLabels[k]);
                    loss += output.getA() * weights[k];
                    tempWeight += weights[k];
                    gradients[k - j] = output.getB();
                    curSize++;
                }
                // Merge the values into a single gradient update
                Tensor[] updates = crfParameters.merge(gradients, curSize);
                for (Tensor update : updates) {
                    update.scaleInPlace(minibatchSize);
                }
                tempWeight /= minibatchSize;
                // Update the gradient with the current learning rates
                updates = localOptimiser.step(updates, tempWeight);
                // Apply the gradient.
                crfParameters.update(updates);
                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        }
    }
    localOptimiser.finalise();
    // public CRFModel(String name, String description, ImmutableInfoMap featureIDMap, ImmutableInfoMap outputIDInfo, CRFParameters parameters) {
    ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(), OffsetDateTime.now(), sequenceExamples.getProvenance(), trainerProvenance, runProvenance);
    CRFModel model = new CRFModel("crf-sgd-model", provenance, featureIDMap, labelIDMap, crfParameters);
    localOptimiser.reset();
    return model;
}
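For context, here is a minimal sketch of how this trainer might be invoked. The four-argument CRFTrainer constructor (optimiser, epochs, logging interval, seed) and the AdaGrad learning rate of 0.1 are assumptions, and the sequence dataset is assumed to be built by the caller.

import java.util.Collections;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.crf.CRFModel;
import org.tribuo.classification.sgd.crf.CRFTrainer;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.sequence.SequenceDataset;

// Illustrative usage: train a CRF over a caller-supplied sequence dataset.
final class TrainCRFExample {
    static CRFModel trainCRF(SequenceDataset<Label> trainingData) {
        // Assumed constructor arguments: optimiser, epochs, logging interval, RNG seed.
        CRFTrainer trainer = new CRFTrainer(new AdaGrad(0.1), 5, 100, 1L);
        // Calls the train method shown above with no extra run provenance.
        return trainer.train(trainingData, Collections.emptyMap());
    }
}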
Use of org.tribuo.math.la.Tensor in project tribuo by oracle.
In class LinearParameters, the method getEmptyCopy:
/**
 * Returns a {@link DenseMatrix} of the same size as the parameter weight matrix.
 * @return A {@link Tensor} array containing a single {@link DenseMatrix}.
 */
@Override
public Tensor[] getEmptyCopy() {
    DenseMatrix matrix = new DenseMatrix(weightMatrix.getDimension1Size(), weightMatrix.getDimension2Size());
    Tensor[] output = new Tensor[1];
    output[0] = matrix;
    return output;
}
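Because the empty copy mirrors the shape of the weight matrix, an optimiser can use it to hold per-parameter state. The sketch below keeps an AdaGrad-style running sum of squared gradients; it is an illustration of the pattern, not Tribuo's AdaGrad, and it assumes the intersectAndAddInPlace overload that applies a function to the added values.

import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.Tensor;

// Illustrative per-parameter state: a running sum of squared gradients,
// allocated with the same shape as the linear model's weight matrix.
final class SquaredGradientState {
    private final Tensor[] gradSquaredSum;

    SquaredGradientState(LinearParameters params) {
        // A single zeroed DenseMatrix matching the weight matrix.
        this.gradSquaredSum = params.getEmptyCopy();
    }

    void observe(Tensor[] gradients) {
        for (int i = 0; i < gradients.length; i++) {
            // Add the element-wise square of each gradient into the running sum.
            gradSquaredSum[i].intersectAndAddInPlace(gradients[i], d -> d * d);
        }
    }
}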
Use of org.tribuo.math.la.Tensor in project tribuo by oracle.
In class Pegasos, the method step:
@Override
public Tensor[] step(Tensor[] updates, double weight) {
    double eta_t = baseRate / (lambda * iteration);
    for (Tensor t : updates) {
        t.scaleInPlace(eta_t * weight);
    }
    iteration++;
    return updates;
}
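Pegasos scales each gradient by a step size that decays with the iteration count, eta_t = baseRate / (lambda * iteration), before the example weight is applied. The standalone sketch below just prints that schedule for arbitrary values, assuming the counter starts at 1.

// Prints the first few Pegasos step sizes for baseRate = 1.0, lambda = 0.01.
final class PegasosRateSchedule {
    public static void main(String[] args) {
        double baseRate = 1.0;
        double lambda = 0.01;
        for (int iteration = 1; iteration <= 5; iteration++) {
            double etaT = baseRate / (lambda * iteration);
            // e.g. iteration 1 -> 100.0, iteration 2 -> 50.0, iteration 5 -> 20.0
            System.out.printf("iteration %d: eta_t = %.1f%n", iteration, etaT);
        }
    }
}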