Example 41 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class GeneralF1Predictor method predictWithPMatrix.

public MultiLabel predictWithPMatrix(double[][] pMatrix, double zeroProbability) {
    int numLabels = pMatrix.length;
    int min = Math.min(maxSize, numLabels);
    MultiLabel best = new MultiLabel();
    double bestScore = zeroProbability;
    for (int k = 1; k <= min; k++) {
        double[] deltaVector = getDeltaVector(pMatrix, k);
        Pair<MultiLabel, Double> innerBest = bestWithLengthK(deltaVector, k);
        if (innerBest.getSecond() > bestScore) {
            bestScore = innerBest.getSecond();
            best = innerBest.getFirst();
        }
    }
    return best;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)

Example 42 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class GeneralF1Predictor method showSupportPrediction.

public static Analysis showSupportPrediction(List<MultiLabel> combinations, double[] probs, MultiLabel truth, MultiLabel prediction, int numClasses) {
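    // Locate the ground-truth label set among the candidate combinations.
    // Note: if truth is not among the combinations, truthIndex silently stays 0.
    int truthIndex = 0;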
    for (int i = 0; i < combinations.size(); i++) {
        if (combinations.get(i).equals(truth)) {
            truthIndex = i;
            break;
        }
    }
    double[] trueJoint = new double[combinations.size()];
    trueJoint[truthIndex] = 1;
    double kl = KLDivergence.kl(trueJoint, probs);
    List<Pair<MultiLabel, Double>> list = new ArrayList<>();
    for (int i = 0; i < combinations.size(); i++) {
        list.add(new Pair<>(combinations.get(i), probs[i]));
    }
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(a -> a.getSecond());
    List<Pair<MultiLabel, Double>> sorted = list.stream().sorted(comparator.reversed()).filter(pair -> pair.getSecond() > 0.01).collect(Collectors.toList());
    double expectedF1Prediction = expectedF1(combinations, probs, prediction, numClasses);
    double expectedF1Truth = expectedF1(combinations, probs, truth, numClasses);
    double actualF1 = new InstanceAverage(numClasses, truth, prediction).getF1();
    StringBuilder jointString = new StringBuilder();
    for (int i = 0; i < sorted.size(); i++) {
        jointString.append(sorted.get(i).getFirst()).append(":").append(sorted.get(i).getSecond()).append(", ");
    }
    Analysis analysis = new Analysis();
    analysis.expectedF1Prediction = expectedF1Prediction;
    analysis.expectedF1Truth = expectedF1Truth;
    analysis.actualF1 = actualF1;
    analysis.kl = kl;
    analysis.prediction = prediction;
    analysis.truth = truth;
    analysis.joint = jointString.toString();
    return analysis;
}
Also used : Arrays(java.util.Arrays) ArgSort(edu.neu.ccs.pyramid.util.ArgSort) Multiset(com.google.common.collect.Multiset) DenseVector(org.apache.mahout.math.DenseVector) DenseMatrix(org.apache.mahout.math.DenseMatrix) Matrix(org.apache.mahout.math.Matrix) Collectors(java.util.stream.Collectors) InstanceAverage(edu.neu.ccs.pyramid.eval.InstanceAverage) ArrayList(java.util.ArrayList) KLDivergence(edu.neu.ccs.pyramid.eval.KLDivergence) List(java.util.List) ConcurrentHashMultiset(com.google.common.collect.ConcurrentHashMultiset) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector) Enumerator(edu.neu.ccs.pyramid.multilabel_classification.Enumerator) Comparator(java.util.Comparator) Pair(edu.neu.ccs.pyramid.util.Pair)
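
The method above compares a one-hot "true" distribution against the predicted joint via KLDivergence.kl. The sketch below is not the pyramid implementation; it assumes the standard definition KL(p || q) = sum of p_i * log(p_i / q_i) with the convention 0 * log 0 = 0, and the class name OneHotKlSketch is hypothetical. Under that assumption, a one-hot p reduces the divergence to the negative log-probability that q assigns to the true combination.

```java
final class OneHotKlSketch {
    // Standard KL divergence KL(p || q); terms with p[i] == 0 contribute nothing.
    // Assumption: this mirrors what KLDivergence.kl computes, which the source does not show.
    static double kl(double[] p, double[] q) {
        double sum = 0.0;
        for (int i = 0; i < p.length; i++) {
            if (p[i] > 0) {
                sum += p[i] * Math.log(p[i] / q[i]);
            }
        }
        return sum;
    }
}
```

With p one-hot at index t, only the term at t survives, so the result equals -log(q[t]); it grows without bound as the model puts less mass on the true label set.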

Example 43 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class GeneralF1Predictor method getPMatrix.

private double[][] getPMatrix(int numClasses, List<MultiLabel> multiLabels, List<Double> probabilities) {
    int min = Math.min(maxSize, numClasses);
    double[][] pMatrix = new double[numClasses][min];
    for (int j = 0; j < multiLabels.size(); j++) {
        MultiLabel multiLabel = multiLabels.get(j);
        double prob = probabilities.get(j);
        int s = multiLabel.getMatchedLabels().size();
        if (s <= maxSize) {
            for (int i : multiLabel.getMatchedLabels()) {
                double old = pMatrix[i][s - 1];
                pMatrix[i][s - 1] = old + prob;
            }
        }
    }
    return pMatrix;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
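
To make the accumulation in getPMatrix concrete, here is a minimal standalone sketch. It assumes label sets can be represented as plain Set<Integer> in place of MultiLabel, and it takes maxSize as a parameter instead of a field; the class and method names are hypothetical. Entry [i][s - 1] ends up holding the total probability of all label sets of size s that contain label i.

```java
import java.util.List;
import java.util.Set;

final class PMatrixSketch {
    // Hypothetical stand-in for getPMatrix: Set<Integer> replaces MultiLabel.
    static double[][] pMatrix(int numClasses, int maxSize,
                              List<Set<Integer>> labelSets, List<Double> probs) {
        int cols = Math.min(maxSize, numClasses);
        double[][] p = new double[numClasses][cols];
        for (int j = 0; j < labelSets.size(); j++) {
            Set<Integer> labels = labelSets.get(j);
            int s = labels.size();
            // Empty sets contribute to no entry; sets larger than the column count are skipped.
            if (s == 0 || s > cols) {
                continue;
            }
            for (int i : labels) {
                p[i][s - 1] += probs.get(j);
            }
        }
        return p;
    }
}
```

For example, with three classes, maxSize 2, and the distribution {0}: 0.5, {0, 1}: 0.3, {}: 0.2, label 0 receives 0.5 in the size-1 column and 0.3 in the size-2 column, label 1 receives 0.3 in the size-2 column, and the empty set's mass appears nowhere in the matrix.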

Example 44 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class GeneralF1Predictor method predict.

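/**
 * Predicts a multi-label from sampled multi-labels.
 *
 * @param numClasses total number of classes
 * @param samples sampled multi-labels; may contain duplicates; their empirical probabilities are estimated from the counts
 * @return the prediction computed from the unique multi-labels and their empirical probabilities
 */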
public MultiLabel predict(int numClasses, List<MultiLabel> samples) {
    Multiset<MultiLabel> multiset = ConcurrentHashMultiset.create();
    for (MultiLabel multiLabel : samples) {
        multiset.add(multiLabel);
    }
    int sampleSize = samples.size();
    List<MultiLabel> uniqueOnes = new ArrayList<>();
    List<Double> probs = new ArrayList<>();
    for (Multiset.Entry<MultiLabel> entry : multiset.entrySet()) {
        uniqueOnes.add(entry.getElement());
        probs.add((double) entry.getCount() / sampleSize);
    }
    return predict(numClasses, uniqueOnes, probs);
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) ArrayList(java.util.ArrayList) Multiset(com.google.common.collect.Multiset) ConcurrentHashMultiset(com.google.common.collect.ConcurrentHashMultiset)
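
The counting step above can be reproduced without Guava. The following sketch uses a LinkedHashMap in place of ConcurrentHashMultiset and works for any element type with proper equals/hashCode; the class and method names are hypothetical and not part of pyramid.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class EmpiricalDistributionSketch {
    // Maps each distinct sample to its relative frequency (count / total).
    static <T> Map<T, Double> empiricalProbabilities(List<T> samples) {
        Map<T, Integer> counts = new LinkedHashMap<>();
        for (T sample : samples) {
            counts.merge(sample, 1, Integer::sum);
        }
        Map<T, Double> probs = new LinkedHashMap<>();
        for (Map.Entry<T, Integer> entry : counts.entrySet()) {
            // Cast before dividing to avoid integer division.
            probs.put(entry.getKey(), (double) entry.getValue() / samples.size());
        }
        return probs;
    }
}
```

The keys and values of the returned map play the roles of uniqueOnes and probs in the original method. A single-threaded map is enough here because the loop is sequential.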

Example 45 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class SubsetAccPredictor method predict.

@Override
public MultiLabel predict(Vector vector) {
    if (imlGradientBoosting.getAssignments() == null) {
        throw new RuntimeException("CRF is used but legal assignments are not specified!");
    }
    double maxScore = Double.NEGATIVE_INFINITY;
    MultiLabel prediction = null;
    double[] classScores = imlGradientBoosting.predictClassScores(vector);
    for (MultiLabel assignment : imlGradientBoosting.getAssignments()) {
        double score = imlGradientBoosting.calAssignmentScore(assignment, classScores);
        if (score > maxScore) {
            maxScore = score;
            prediction = assignment;
        }
    }
    return prediction;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
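
The loop above is a plain arg-max over a fixed set of candidate assignments. A generic sketch of that pattern follows; the scoring function is passed in as a parameter because the source does not show how calAssignmentScore is computed, and the names ArgMaxSketch and argMax are hypothetical.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

final class ArgMaxSketch {
    // Returns the candidate with the highest score, or null when there are no candidates.
    // Ties keep the earliest candidate because the comparison is strict.
    static <T> T argMax(List<T> candidates, ToDoubleFunction<? super T> score) {
        double maxScore = Double.NEGATIVE_INFINITY;
        T best = null;
        for (T candidate : candidates) {
            double s = score.applyAsDouble(candidate);
            if (s > maxScore) {
                maxScore = s;
                best = candidate;
            }
        }
        return best;
    }
}
```

As in the original method, an empty candidate list yields null, and so does a list whose scores are all NaN or negative infinity, so callers may want to check the result before using it.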
