Search in sources:

Example 1 with DenseVector

Use of org.tribuo.math.la.DenseVector in the Tribuo project by Oracle.

Class CRFParameters, method getEmptyCopy.

/**
 * Builds a fresh set of tensors whose shapes mirror this CRF's parameters.
 *
 * Slot 0 is a {@link DenseVector} with one entry per label bias,
 * slot 1 is a {@link DenseMatrix} shaped like the feature-label weights,
 * and slot 2 is a {@link DenseMatrix} shaped like the label-label transition weights.
 * The tensors are newly constructed from sizes only; no values are copied from the parameters.
 * @return A 3 element {@link Tensor} array.
 */
@Override
public Tensor[] getEmptyCopy() {
    return new Tensor[] {
        new DenseVector(biases.size()),
        new DenseMatrix(featureLabelWeights.getDimension1Size(), featureLabelWeights.getDimension2Size()),
        new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size())
    };
}
Also used : Tensor(org.tribuo.math.la.Tensor) DenseVector(org.tribuo.math.la.DenseVector) DenseMatrix(org.tribuo.math.la.DenseMatrix)

Example 2 with DenseVector

Use of org.tribuo.math.la.DenseVector in the Tribuo project by Oracle.

Class CRFParameters, method merge.

/**
 * Sums a batch of per-example gradients into a single update.
 *
 * Only the first {@code size} rows of {@code gradients} are merged. Any rows beyond that,
 * for example stale entries left over in a reused minibatch buffer, are ignored.
 * Each row holds three tensors:
 * <ul>
 *   <li>index 0: the bias gradient;</li>
 *   <li>index 1: the feature-label gradient, either a {@link DenseSparseMatrix} or a dense matrix;</li>
 *   <li>index 2: the label-label gradient.</li>
 * </ul>
 * The input gradient tensors are not modified.
 * @param gradients The per-example gradients, indexed [example][tensor].
 * @param size The number of leading rows of {@code gradients} to merge.
 * @return A 3 element {@link Tensor} array of {bias update, feature-label update, label-label update}.
 *         The feature-label entry is {@code null} if no rows were merged.
 */
@Override
public Tensor[] merge(Tensor[][] gradients, int size) {
    // Respect the caller's row count (the array may be a larger, reused buffer),
    // but never read past the end of the array.
    int numGradients = Math.min(size, gradients.length);
    DenseVector biasUpdate = new DenseVector(biases.size());
    List<DenseSparseMatrix> sparseUpdates = new ArrayList<>(numGradients);
    DenseMatrix denseUpdates = null;
    DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
    for (int j = 0; j < numGradients; j++) {
        biasUpdate.intersectAndAddInPlace(gradients[j][0]);
        Matrix tmpUpdate = (Matrix) gradients[j][1];
        if (tmpUpdate instanceof DenseSparseMatrix) {
            // Sparse gradients are collected and merged in one pass below.
            sparseUpdates.add((DenseSparseMatrix) tmpUpdate);
        } else {
            // Dense gradient. Accumulate into a matrix we own instead of aliasing
            // (and then mutating) the caller's first dense gradient.
            if (denseUpdates == null) {
                denseUpdates = new DenseMatrix(tmpUpdate.getDimension1Size(), tmpUpdate.getDimension2Size());
            }
            denseUpdates.intersectAndAddInPlace(tmpUpdate);
        }
        labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
    }
    // Merge the combination of any dense and sparse updates.
    Matrix featureLabelUpdate;
    if (!sparseUpdates.isEmpty()) {
        featureLabelUpdate = merger.merge(sparseUpdates.toArray(new DenseSparseMatrix[0]));
        if (denseUpdates != null) {
            denseUpdates.intersectAndAddInPlace(featureLabelUpdate);
            featureLabelUpdate = denseUpdates;
        }
    } else {
        featureLabelUpdate = denseUpdates;
    }
    return new Tensor[] { biasUpdate, featureLabelUpdate, labelLabelUpdate };
}
Also used: DenseMatrix(org.tribuo.math.la.DenseMatrix) DenseSparseMatrix(org.tribuo.math.la.DenseSparseMatrix) Matrix(org.tribuo.math.la.Matrix) Tensor(org.tribuo.math.la.Tensor) ArrayList(java.util.ArrayList) DenseVector(org.tribuo.math.la.DenseVector)

Example 3 with DenseVector

Use of org.tribuo.math.la.DenseVector in the Tribuo project by Oracle.

Class ChainHelper, method viterbi.

/**
 * Runs Viterbi on a linear chain CRF, combining the per-token label scores with the
 * label-label transition scores to find the highest scoring label sequence.
 *
 * Ties are broken in favour of the lowest label index. If every candidate score for a label is
 * {@code -Infinity} (or NaN), its back pointer falls back to label 0.
 * @param scores Tuple containing the label-label transition matrix, and the per token label scores.
 * @return Tuple containing the score of the maximum path and the maximum predicted label per token.
 * @throws IllegalArgumentException If {@code scores} contains no tokens.
 */
public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
    DenseMatrix markovScores = scores.transitionValues;
    DenseVector[] localScores = scores.localValues;
    int numTokens = localScores.length;
    if (numTokens == 0) {
        // Without this guard the method fails below with an opaque ArrayIndexOutOfBoundsException.
        throw new IllegalArgumentException("Viterbi requires at least one token, received an empty sequence.");
    }
    int numLabels = markovScores.getDimension1Size();
    // costs[i].get(v): best score of any path over tokens 0..i that ends with label v.
    DenseVector[] costs = new DenseVector[numTokens];
    // backPointers[i][v]: the label at token i-1 on that best path (unused for token 0).
    int[][] backPointers = new int[numTokens][];
    for (int i = 0; i < numTokens; i++) {
        costs[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
        backPointers[i] = new int[numLabels];
        Arrays.fill(backPointers[i], -1);
    }
    costs[0].setElements(localScores[0]);
    for (int i = 1; i < numTokens; i++) {
        DenseVector curLocalScores = localScores[i];
        DenseVector curCost = costs[i];
        int[] curBackPointers = backPointers[i];
        DenseVector prevCost = costs[i - 1];
        for (int vi = 0; vi < numLabels; vi++) {
            double maxScore = Double.NEGATIVE_INFINITY;
            int maxIndex = -1;
            double curLocalScore = curLocalScores.get(vi);
            for (int vj = 0; vj < numLabels; vj++) {
                // Transition from previous label vj to current label vi. The local score is added
                // inside the loop so floating point results (and tie-breaking) match the summation order.
                double curScore = markovScores.get(vj, vi) + prevCost.get(vj) + curLocalScore;
                if (curScore > maxScore) {
                    maxScore = curScore;
                    maxIndex = vj;
                }
            }
            curCost.set(vi, maxScore);
            if (maxIndex < 0) {
                // No candidate beat -Infinity; point at label 0 so the backtrace stays in range.
                maxIndex = 0;
            }
            curBackPointers[vi] = maxIndex;
        }
    }
    // Backtrace from the best final label.
    int[] mapValues = new int[numTokens];
    mapValues[numTokens - 1] = costs[numTokens - 1].indexOfMax();
    for (int j = numTokens - 2; j >= 0; j--) {
        mapValues[j] = backPointers[j + 1][mapValues[j + 1]];
    }
    return new ChainViterbiResults(costs[numTokens - 1].maxValue(), mapValues, scores);
}
Also used : DenseVector(org.tribuo.math.la.DenseVector) DenseMatrix(org.tribuo.math.la.DenseMatrix)

Example 4 with DenseVector

Use of org.tribuo.math.la.DenseVector in the Tribuo project by Oracle.

Class HdbscanModel, method copy.

/**
 * Creates a copy of this model with a new name and provenance.
 *
 * The outlier score vector, the cluster label list and the exemplar list are copied, so the new
 * model does not share those containers with this one. The individual exemplar objects and the
 * feature/output maps are passed through and shared.
 * @param newName The name of the copy.
 * @param newProvenance The provenance of the copy.
 * @return A copy of this model.
 */
@Override
protected HdbscanModel copy(String newName, ModelProvenance newProvenance) {
    DenseVector copyOutlierScoresVector = outlierScoresVector.copy();
    // Snapshot the labels first. Wrapping clusterLabels directly would only create a read-only
    // view that still shares (and reflects changes to) the original list.
    List<Integer> copyClusterLabels = Collections.unmodifiableList(new ArrayList<>(clusterLabels));
    List<HdbscanTrainer.ClusterExemplar> copyExemplars = new ArrayList<>(clusterExemplars);
    return new HdbscanModel(newName, newProvenance, featureIDMap, outputIDInfo, copyClusterLabels, copyOutlierScoresVector, copyExemplars, distType, noisePointsOutlierScore);
}
Also used : ArrayList(java.util.ArrayList) DenseVector(org.tribuo.math.la.DenseVector)

Example 5 with DenseVector

Use of org.tribuo.math.la.DenseVector in the Tribuo project by Oracle.

Class HdbscanTrainer, method train.

/**
 * Trains an HDBSCAN* clustering model on the supplied examples.
 *
 * An example is converted to a dense vector when its feature count equals the size of the feature
 * map, and to a sparse vector otherwise. The method then:
 * <ul>
 *   <li>computes core distances and the extended minimum spanning tree;</li>
 *   <li>builds the hierarchy and cluster tree and picks the prominent clusters;</li>
 *   <li>derives per-point outlier scores, cluster exemplars and the noise-point outlier score.</li>
 * </ul>
 * @param examples The training data.
 * @param runProvenance Additional provenance recorded for this training run.
 * @return The trained {@link HdbscanModel}.
 */
@Override
public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
    // Capture the provenance and bump the invocation count under the same lock.
    final TrainerProvenance trainerProvenance;
    synchronized (this) {
        trainerProvenance = getProvenance();
        trainInvocationCounter++;
    }

    ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
    SGDVector[] data = new SGDVector[examples.size()];
    int pos = 0;
    for (Example<ClusterID> ex : examples) {
        boolean fullyPopulated = ex.size() == featureMap.size();
        data[pos++] = fullyPopulated
                ? DenseVector.createDenseVector(ex, featureMap, false)
                : SparseVector.createSparseVector(ex, featureMap, false);
    }
    int numPoints = data.length;

    DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory);
    ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType);

    // Per point: the level at which it becomes noise, and its last cluster label before that.
    double[] pointNoiseLevels = new double[numPoints];
    int[] pointLastClusters = new int[numPoints];
    // The HDBSCAN* hierarchy; passed in empty to computeHierarchyAndClusterTree.
    Map<Integer, int[]> hierarchy = new HashMap<>();
    List<HdbscanCluster> clusters = computeHierarchyAndClusterTree(emst, minClusterSize, pointNoiseLevels, pointLastClusters, hierarchy);
    propagateTree(clusters);

    List<Integer> clusterLabels = findProminentClusters(hierarchy, clusters, numPoints);
    DenseVector outlierScoresVector = calculateOutlierScores(pointNoiseLevels, pointLastClusters, clusters);
    Map<Integer, List<Pair<Double, Integer>>> clusterAssignments = generateClusterAssignments(clusterLabels, outlierScoresVector);

    // The number of points assigned to each cluster becomes the clustering output info.
    Map<Integer, MutableLong> counts = new HashMap<>();
    clusterAssignments.forEach((clusterId, members) -> counts.put(clusterId, new MutableLong(members.size())));
    ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);

    List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);
    // Outlier score reported for points predicted as noise.
    double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);
    logger.log(Level.INFO, "Hdbscan is done.");

    ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
    return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, clusterExemplars, distType, noisePointsOutlierScore);
}
Also used : ClusterID(org.tribuo.clustering.ClusterID) HashMap(java.util.HashMap) ImmutableFeatureMap(org.tribuo.ImmutableFeatureMap) SGDVector(org.tribuo.math.la.SGDVector) ArrayList(java.util.ArrayList) List(java.util.List) TrainerProvenance(org.tribuo.provenance.TrainerProvenance) ModelProvenance(org.tribuo.provenance.ModelProvenance) MutableLong(com.oracle.labs.mlrg.olcut.util.MutableLong) DenseVector(org.tribuo.math.la.DenseVector) ImmutableClusteringInfo(org.tribuo.clustering.ImmutableClusteringInfo)

Aggregations

DenseVector (org.tribuo.math.la.DenseVector)38 DenseMatrix (org.tribuo.math.la.DenseMatrix)12 Prediction (org.tribuo.Prediction)9 Label (org.tribuo.classification.Label)9 Pair (com.oracle.labs.mlrg.olcut.util.Pair)7 ArrayList (java.util.ArrayList)7 LinkedHashMap (java.util.LinkedHashMap)7 SparseVector (org.tribuo.math.la.SparseVector)7 Tensor (org.tribuo.math.la.Tensor)6 HashMap (java.util.HashMap)5 List (java.util.List)5 VectorTuple (org.tribuo.math.la.VectorTuple)5 HashSet (java.util.HashSet)4 MultiLabel (org.tribuo.multilabel.MultiLabel)4 ImmutableFeatureMap (org.tribuo.ImmutableFeatureMap)3 DenseSparseMatrix (org.tribuo.math.la.DenseSparseMatrix)3 SGDVector (org.tribuo.math.la.SGDVector)3 ModelProvenance (org.tribuo.provenance.ModelProvenance)3 TrainerProvenance (org.tribuo.provenance.TrainerProvenance)3 MutableLong (com.oracle.labs.mlrg.olcut.util.MutableLong)2