Example 1 with Tensor

Use of org.tribuo.math.la.Tensor in project tribuo by oracle.

Class CRFParameters, method getEmptyCopy.

/**
 * Returns a 3 element {@link Tensor} array.
 *
 * The first element is a {@link DenseVector} of label biases.
 * The second element is a {@link DenseMatrix} of feature-label weights.
 * The third element is a {@link DenseMatrix} of label-label transition weights.
 * @return A {@link Tensor} array.
 */
@Override
public Tensor[] getEmptyCopy() {
    Tensor[] output = new Tensor[3];
    output[0] = new DenseVector(biases.size());
    output[1] = new DenseMatrix(featureLabelWeights.getDimension1Size(), featureLabelWeights.getDimension2Size());
    output[2] = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    return output;
}
Also used: Tensor(org.tribuo.math.la.Tensor) DenseVector(org.tribuo.math.la.DenseVector) DenseMatrix(org.tribuo.math.la.DenseMatrix)
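getEmptyCopy returns zero-filled tensors with the same shapes as the live parameters, which is useful, for instance, as per-parameter state in an adaptive optimiser. The sketch below mirrors that idea with plain Java arrays instead of Tribuo's DenseVector and DenseMatrix; the class CrfParamsSketch and its fields are hypothetical, and it assumes every matrix has at least one row.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java stand-in for CRF parameters: label biases, feature-label
 * weights and label-label transition weights. Not Tribuo's CRFParameters class.
 */
final class CrfParamsSketch {
    final double[] biases;
    final double[][] featureLabelWeights;
    final double[][] labelLabelWeights;

    CrfParamsSketch(double[] biases, double[][] featureLabelWeights, double[][] labelLabelWeights) {
        this.biases = biases;
        this.featureLabelWeights = featureLabelWeights;
        this.labelLabelWeights = labelLabelWeights;
    }

    /** Returns zero-filled parameters with the same shapes (assumes non-empty matrices). */
    CrfParamsSketch emptyCopy() {
        // Java initialises new double arrays to 0.0, so only the shapes need copying.
        return new CrfParamsSketch(
                new double[biases.length],
                new double[featureLabelWeights.length][featureLabelWeights[0].length],
                new double[labelLabelWeights.length][labelLabelWeights[0].length]);
    }

    /** True if every entry of every array is 0.0. */
    boolean isZero() {
        boolean zero = Arrays.stream(biases).allMatch(v -> v == 0.0);
        for (double[] row : featureLabelWeights) {
            zero &= Arrays.stream(row).allMatch(v -> v == 0.0);
        }
        for (double[] row : labelLabelWeights) {
            zero &= Arrays.stream(row).allMatch(v -> v == 0.0);
        }
        return zero;
    }
}
```

The same pattern, with a single matrix, is what LinearParameters.getEmptyCopy does in Example 4.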

Example 2 with Tensor

Class CRFParameters, method merge.

@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    DenseVector biasUpdate = new DenseVector(biases.size());
    List<DenseSparseMatrix> updates = new ArrayList<>(size);
    DenseMatrix denseUpdates = null;
    DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    // Only the first size slots hold gradients for this minibatch; callers may reuse the array.
    for (int j = 0; j < size; j++) {
        biasUpdate.intersectAndAddInPlace(gradients[j][0]);
        Matrix tmpUpdate = (Matrix) gradients[j][1];
        if (tmpUpdate instanceof DenseSparseMatrix) {
            updates.add((DenseSparseMatrix) tmpUpdate);
        } else {
            // is dense
            if (denseUpdates == null) {
                denseUpdates = (DenseMatrix) tmpUpdate;
            } else {
                denseUpdates.intersectAndAddInPlace(tmpUpdate);
            }
        }
        labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
    }
    // Merge the combination of any dense and sparse updates
    Matrix featureLabelUpdate;
    if (updates.size() > 0) {
        featureLabelUpdate = merger.merge(updates.toArray(new DenseSparseMatrix[0]));
        if (denseUpdates != null) {
            denseUpdates.intersectAndAddInPlace(featureLabelUpdate);
            featureLabelUpdate = denseUpdates;
        }
    } else {
        featureLabelUpdate = denseUpdates;
    }
    return new Tensor[] { biasUpdate, featureLabelUpdate, labelLabelUpdate };
}
Also used: DenseMatrix(org.tribuo.math.la.DenseMatrix) DenseSparseMatrix(org.tribuo.math.la.DenseSparseMatrix) Matrix(org.tribuo.math.la.Matrix) Tensor(org.tribuo.math.la.Tensor) ArrayList(java.util.ArrayList) DenseVector(org.tribuo.math.la.DenseVector)
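merge sums the per-example gradients of a minibatch, collecting sparse feature-label updates separately and folding them into any dense update at the end. The plain-Java sketch below follows that pattern; Grad, DenseGrad, SparseGrad and MergeSketch are hypothetical stand-ins (a sparse gradient is a map from index to value) for Tribuo's DenseMatrix and DenseSparseMatrix, and it reads only the first size slots because a reused gradient array can hold stale entries.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical gradient types for this sketch; not Tribuo classes. */
interface Grad {}

final class DenseGrad implements Grad {
    final double[] values;
    DenseGrad(double[] values) { this.values = values; }
}

final class SparseGrad implements Grad {
    final Map<Integer, Double> values; // index -> non-zero value
    SparseGrad(Map<Integer, Double> values) { this.values = values; }
}

final class MergeSketch {
    /**
     * Sums the first {@code size} gradients into one dense vector of length {@code dim}.
     * Slots at or beyond {@code size} may hold stale gradients from an earlier minibatch.
     */
    static double[] merge(Grad[] gradients, int size, int dim) {
        Map<Integer, Double> sparseSum = new HashMap<>();
        double[] denseSum = null;
        for (int j = 0; j < size; j++) {
            Grad g = gradients[j];
            if (g instanceof SparseGrad) {
                // Sparse updates only touch their non-zero entries.
                ((SparseGrad) g).values.forEach((idx, val) -> sparseSum.merge(idx, val, Double::sum));
            } else {
                double[] dense = ((DenseGrad) g).values;
                if (denseSum == null) {
                    denseSum = dense.clone(); // copy so the caller's gradient is left untouched
                } else {
                    for (int i = 0; i < dim; i++) {
                        denseSum[i] += dense[i];
                    }
                }
            }
        }
        // Fold the sparse total into the dense total, or into zeros if every update was sparse.
        double[] result = (denseSum != null) ? denseSum : new double[dim];
        sparseSum.forEach((idx, val) -> result[idx] += val);
        return result;
    }
}
```

The sketch copies the first dense gradient instead of aliasing it. The Tribuo code above accumulates into the first dense gradient in place, which is fine as long as the per-example gradients are discarded after merging.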

Example 3 with Tensor

Class CRFTrainer, method train.

@Override
public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
    if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
        throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
    }
    // Under the lock: split the RNG, copy the optimiser, capture the provenance and increment the invocation count.
    SplittableRandom localRNG;
    TrainerProvenance trainerProvenance;
    StochasticGradientOptimiser localOptimiser;
    synchronized (this) {
        localRNG = rng.split();
        localOptimiser = optimiser.copy();
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }
    ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
    ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();
    SGDVector[][] sgdFeatures = new SGDVector[sequenceExamples.size()][];
    int[][] sgdLabels = new int[sequenceExamples.size()][];
    double[] weights = new double[sequenceExamples.size()];
    int n = 0;
    for (SequenceExample<Label> example : sequenceExamples) {
        weights[n] = example.getWeight();
        Pair<int[], SGDVector[]> pair = CRFModel.convertToVector(example, featureIDMap, labelIDMap);
        sgdFeatures[n] = pair.getB();
        sgdLabels[n] = pair.getA();
        n++;
    }
    logger.info(String.format("Training SGD CRF with %d examples", n));
    CRFParameters crfParameters = new CRFParameters(featureIDMap.size(), labelIDMap.size());
    localOptimiser.initialise(crfParameters);
    double loss = 0.0;
    int iteration = 0;
    for (int i = 0; i < epochs; i++) {
        if (shuffle) {
            Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
        }
        if (minibatchSize == 1) {
            /*
             * Special-case a minibatch of size 1: update the parameters directly after each
             * example rather than aggregating.
             */
            for (int j = 0; j < sgdFeatures.length; j++) {
                Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j], sgdLabels[j]);
                loss += output.getA() * weights[j];
                // Update the gradient with the current learning rates
                Tensor[] updates = localOptimiser.step(output.getB(), weights[j]);
                // Apply the update to the current parameters.
                crfParameters.update(updates);
                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        } else {
            Tensor[][] gradients = new Tensor[minibatchSize][];
            for (int j = 0; j < sgdFeatures.length; j += minibatchSize) {
                double tempWeight = 0.0;
                int curSize = 0;
                // Aggregate the gradient updates for each example in the minibatch
                for (int k = j; k < j + minibatchSize && k < sgdFeatures.length; k++) {
                    // Use k, not j, so each example in the minibatch contributes its own gradient.
                    Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k], sgdLabels[k]);
                    loss += output.getA() * weights[k];
                    tempWeight += weights[k];
                    gradients[k - j] = output.getB();
                    curSize++;
                }
                // Merge the values into a single gradient update
                Tensor[] updates = crfParameters.merge(gradients, curSize);
                for (Tensor update : updates) {
                    update.scaleInPlace(minibatchSize);
                }
                tempWeight /= minibatchSize;
                // Update the gradient with the current learning rates
                updates = localOptimiser.step(updates, tempWeight);
                // Apply the gradient.
                crfParameters.update(updates);
                iteration++;
                if ((loggingInterval != -1) && (iteration % loggingInterval == 0)) {
                    logger.info("At iteration " + iteration + ", average loss = " + loss / loggingInterval);
                    loss = 0.0;
                }
            }
        }
    }
    localOptimiser.finalise();
    // Build the model, recording the dataset, trainer and run provenance.
    ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(), OffsetDateTime.now(), sequenceExamples.getProvenance(), trainerProvenance, runProvenance);
    CRFModel model = new CRFModel("crf-sgd-model", provenance, featureIDMap, labelIDMap, crfParameters);
    localOptimiser.reset();
    return model;
}
Also used: Tensor(org.tribuo.math.la.Tensor) ModelProvenance(org.tribuo.provenance.ModelProvenance) Label(org.tribuo.classification.Label) StochasticGradientOptimiser(org.tribuo.math.StochasticGradientOptimiser) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) SplittableRandom(java.util.SplittableRandom) TrainerProvenance(org.tribuo.provenance.TrainerProvenance)
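The minibatch branch of train gathers one gradient per example in the batch, merges them, and applies a single update. A toy version of that loop follows, assuming a one-parameter least-squares model y ≈ w * x with plain averaging in place of Tribuo's merge and optimiser step; MinibatchSgdSketch and its parameters are hypothetical.

```java
import java.util.SplittableRandom;

/** Hypothetical minibatch SGD sketch for a 1-D least-squares model y ~ w * x. */
final class MinibatchSgdSketch {

    static double train(double[] xs, double[] ys, int epochs, int minibatchSize,
                        double learningRate, long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        int[] order = new int[xs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        double w = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Fisher-Yates shuffle of the example order, once per epoch.
            for (int i = order.length - 1; i > 0; i--) {
                int r = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[r];
                order[r] = tmp;
            }
            for (int j = 0; j < order.length; j += minibatchSize) {
                double gradSum = 0.0;
                int curSize = 0;
                // Index with k, not j, so every example in the minibatch is used.
                for (int k = j; k < j + minibatchSize && k < order.length; k++) {
                    int idx = order[k];
                    double residual = w * xs[idx] - ys[idx];
                    gradSum += residual * xs[idx]; // d/dw of 0.5 * residual^2
                    curSize++;
                }
                // Average over the examples actually present in this (possibly short) minibatch.
                w -= learningRate * gradSum / curSize;
            }
        }
        return w;
    }
}
```

Dividing by curSize rather than the nominal minibatch size keeps the short final minibatch of each epoch on the same scale as the others.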

Example 4 with Tensor

Class LinearParameters, method getEmptyCopy.

/**
 * This returns a {@link DenseMatrix} the same size as the Parameters.
 * @return A {@link Tensor} array containing a single {@link DenseMatrix}.
 */
@Override
public Tensor[] getEmptyCopy() {
    DenseMatrix matrix = new DenseMatrix(weightMatrix.getDimension1Size(), weightMatrix.getDimension2Size());
    Tensor[] output = new Tensor[1];
    output[0] = matrix;
    return output;
}
Also used: Tensor(org.tribuo.math.la.Tensor) DenseMatrix(org.tribuo.math.la.DenseMatrix)

Example 5 with Tensor

Class Pegasos, method step.

@Override
public Tensor[] step(Tensor[] updates, double weight) {
    double eta_t = baseRate / (lambda * iteration);
    for (Tensor t : updates) {
        t.scaleInPlace(eta_t * weight);
    }
    iteration++;
    return updates;
}
Also used: ShrinkingTensor(org.tribuo.math.optimisers.util.ShrinkingTensor) Tensor(org.tribuo.math.la.Tensor)
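Pegasos.step scales every update by the step size eta_t = baseRate / (lambda * t) and the example weight, then advances t. A plain-Java sketch of that schedule follows; PegasosScheduleSketch is hypothetical, and it starts t at 1 as an assumption, because t = 0 would divide by zero and the initial value of iteration is not shown in the snippet.

```java
/** Hypothetical sketch of a Pegasos-style step: scale by eta_t = baseRate / (lambda * t). */
final class PegasosScheduleSketch {
    private final double baseRate;
    private final double lambda;
    private int iteration = 1; // assumption: t starts at 1, since t = 0 would divide by zero

    PegasosScheduleSketch(double baseRate, double lambda) {
        this.baseRate = baseRate;
        this.lambda = lambda;
    }

    /** Scales each update array in place by eta_t * weight, then advances t. */
    double[][] step(double[][] updates, double weight) {
        double etaT = baseRate / (lambda * iteration);
        for (double[] update : updates) {
            for (int i = 0; i < update.length; i++) {
                update[i] *= etaT * weight;
            }
        }
        iteration++;
        return updates;
    }
}
```

Because eta_t decays as 1/t, early updates are large and later ones shrink; with baseRate = 1 and lambda = 0.5 the first two step sizes are 2.0 and 1.0.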

Aggregations

Tensor (org.tribuo.math.la.Tensor): 15
DenseMatrix (org.tribuo.math.la.DenseMatrix): 8
DenseVector (org.tribuo.math.la.DenseVector): 6
DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix): 4
ArrayList (java.util.ArrayList): 3
Matrix (org.tribuo.math.la.Matrix): 3
SGDVector (org.tribuo.math.la.SGDVector): 3
SplittableRandom (java.util.SplittableRandom): 2
DoubleUnaryOperator (java.util.function.DoubleUnaryOperator): 2
ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap): 2
StochasticGradientOptimiser (org.tribuo.math.StochasticGradientOptimiser): 2
ShrinkingTensor (org.tribuo.math.optimisers.util.ShrinkingTensor): 2
ModelProvenance (org.tribuo.provenance.ModelProvenance): 2
TrainerProvenance (org.tribuo.provenance.TrainerProvenance): 2
Pair (com.oracle.labs.mlrg.olcut.util.Pair): 1
Test (org.junit.jupiter.api.Test): 1
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 1
Dataset (org.tribuo.Dataset): 1
Label (org.tribuo.classification.Label): 1
AbstractLinearSGDTrainer (org.tribuo.common.sgd.AbstractLinearSGDTrainer): 1