Example 1 with InstanceAverage

Use of edu.neu.ccs.pyramid.eval.InstanceAverage in project pyramid by cheng-li.

From the class GeneralF1Predictor, method showSupportPrediction.

public static Analysis showSupportPrediction(List<MultiLabel> combinations, double[] probs, MultiLabel truth, MultiLabel prediction, int numClasses) {
    // Find the position of the true label set among the candidate combinations.
    // Caveat: if truth is not among the combinations, truthIndex silently stays 0.
    int truthIndex = 0;
    for (int i = 0; i < combinations.size(); i++) {
        if (combinations.get(i).equals(truth)) {
            truthIndex = i;
            break;
        }
    }
    // One-hot "true" joint distribution, compared with the predicted joint via KL divergence.
    double[] trueJoint = new double[combinations.size()];
    trueJoint[truthIndex] = 1;
    double kl = KLDivergence.kl(trueJoint, probs);
    // Pair each combination with its predicted probability.
    List<Pair<MultiLabel, Double>> list = new ArrayList<>();
    for (int i = 0; i < combinations.size(); i++) {
        list.add(new Pair<>(combinations.get(i), probs[i]));
    }
    // Keep combinations with probability above 0.01, most probable first.
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(a -> a.getSecond());
    List<Pair<MultiLabel, Double>> sorted = list.stream()
            .filter(pair -> pair.getSecond() > 0.01)
            .sorted(comparator.reversed())
            .collect(Collectors.toList());
    // expectedF1 is a helper not shown in this excerpt.
    double expectedF1Prediction = expectedF1(combinations, probs, prediction, numClasses);
    double expectedF1Truth = expectedF1(combinations, probs, truth, numClasses);
    // F1 of the prediction against the actual truth for this single instance.
    double actualF1 = new InstanceAverage(numClasses, truth, prediction).getF1();
    // Render the filtered joint as "labels:probability, " entries.
    StringBuilder jointString = new StringBuilder();
    for (Pair<MultiLabel, Double> pair : sorted) {
        jointString.append(pair.getFirst()).append(":").append(pair.getSecond()).append(", ");
    }
    // Collect everything into the result object.
    Analysis analysis = new Analysis();
    analysis.expectedF1Prediction = expectedF1Prediction;
    analysis.expectedF1Truth = expectedF1Truth;
    analysis.actualF1 = actualF1;
    analysis.kl = kl;
    analysis.prediction = prediction;
    analysis.truth = truth;
    analysis.joint = jointString.toString();
    return analysis;
}
Also used: Arrays(java.util.Arrays), ArgSort(edu.neu.ccs.pyramid.util.ArgSort), Multiset(com.google.common.collect.Multiset), DenseVector(org.apache.mahout.math.DenseVector), DenseMatrix(org.apache.mahout.math.DenseMatrix), Matrix(org.apache.mahout.math.Matrix), Collectors(java.util.stream.Collectors), InstanceAverage(edu.neu.ccs.pyramid.eval.InstanceAverage), ArrayList(java.util.ArrayList), KLDivergence(edu.neu.ccs.pyramid.eval.KLDivergence), List(java.util.List), ConcurrentHashMultiset(com.google.common.collect.ConcurrentHashMultiset), MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel), Vector(org.apache.mahout.math.Vector), Enumerator(edu.neu.ccs.pyramid.multilabel_classification.Enumerator), Comparator(java.util.Comparator), Pair(edu.neu.ccs.pyramid.util.Pair)
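The three scores that showSupportPrediction collects can be reproduced with plain Java collections. The sketch below is illustrative only and rests on several assumptions: label sets are modeled as Set<Integer> instead of MultiLabel; expectedF1 is taken to be the probability-weighted average of instance F1 over the candidate combinations; KLDivergence.kl is assumed to use the natural logarithm; and an empty truth paired with an empty prediction is treated as F1 = 1. The class SupportAnalysisSketch and its methods are hypothetical names, not pyramid's API.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class SupportAnalysisSketch {

    // Instance-level F1 between two label sets: 2|T ∩ P| / (|T| + |P|).
    // Assumption: two empty sets count as a perfect match (F1 = 1.0);
    // pyramid's InstanceAverage may use a different convention.
    static double f1(Set<Integer> truth, Set<Integer> prediction) {
        if (truth.isEmpty() && prediction.isEmpty()) {
            return 1.0;
        }
        Set<Integer> common = new HashSet<>(truth);
        common.retainAll(prediction);
        return 2.0 * common.size() / (truth.size() + prediction.size());
    }

    // Expected F1 of a fixed prediction when the true label set is drawn
    // from the joint distribution probs over the candidate combinations.
    // Hypothetical stand-in for pyramid's expectedF1 helper.
    static double expectedF1(List<Set<Integer>> combinations, double[] probs, Set<Integer> prediction) {
        double sum = 0.0;
        for (int i = 0; i < combinations.size(); i++) {
            sum += probs[i] * f1(combinations.get(i), prediction);
        }
        return sum;
    }

    // KL(q || p) with q one-hot at the truth: the only non-zero term is
    // 1 * ln(1 / p[truth]), i.e. the negative log-probability of the truth.
    static double oneHotKl(double[] probs, int truthIndex) {
        return -Math.log(probs[truthIndex]);
    }

    public static void main(String[] args) {
        List<Set<Integer>> combinations = List.of(Set.of(0), Set.of(0, 1), Set.of(1));
        double[] probs = {0.5, 0.3, 0.2};
        Set<Integer> truth = Set.of(0, 1);
        Set<Integer> prediction = Set.of(0);
        System.out.println("actual F1:   " + f1(truth, prediction));
        System.out.println("expected F1: " + expectedF1(combinations, probs, prediction));
        System.out.println("KL:          " + oneHotKl(probs, 1));
    }
}
```

Because the "true" joint is one-hot, the KL term reduces to the negative log-probability the model assigns to the true label set, so it grows without bound as that probability approaches zero.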

Example 2 with InstanceAverage

Use of edu.neu.ccs.pyramid.eval.InstanceAverage in project pyramid by cheng-li.

From the class LossMatrixGenerator, method matrix.

public static Matrix matrix(int n, String lossName) {
    // One row and one column per possible label set over n labels: 2^n in total.
    int size = (int) Math.pow(2, n);
    String lossKey = lossName.toLowerCase();
    double[][] matrixBuilder = new double[size][size];
    for (int i = 0; i < size; i++) {
        // Row i: the true label set encoded by the n-bit binary form of i.
        // toBinary and toML are helpers not shown in this excerpt.
        MultiLabel multiLabel1 = toML(toBinary(i, n));
        for (int j = 0; j < size; j++) {
            // Column j: the predicted label set.
            MultiLabel multiLabel2 = toML(toBinary(j, n));
            MultiLabel[] trueLabels = { multiLabel1 };
            MultiLabel[] predicted = { multiLabel2 };
            // Single-instance confusion matrix, so the instance averages are this pair's scores.
            MLConfusionMatrix mlConfusionMatrix = new MLConfusionMatrix(n, trueLabels, predicted);
            InstanceAverage instanceAverage = new InstanceAverage(mlConfusionMatrix);
            double loss;
            switch (lossKey) {
                case "hamming":
                    // Hamming loss is a per-label rate; multiplying by n gives the number of wrong labels.
                    loss = instanceAverage.getHammingLoss() * n;
                    break;
                case "overlap":
                    loss = 1 - instanceAverage.getOverlap();
                    break;
                case "accuracy":
                    loss = 1 - instanceAverage.getAccuracy();
                    break;
                case "precision":
                    loss = 1 - instanceAverage.getPrecision();
                    break;
                case "recall":
                    loss = 1 - instanceAverage.getRecall();
                    break;
                case "f1":
                    loss = 1 - instanceAverage.getF1();
                    break;
                default:
                    throw new IllegalArgumentException("unknown loss: " + lossName);
            }
            matrixBuilder[i][j] = loss;
        }
    }
    return new DenseMatrix(matrixBuilder);
}
Also used: MLConfusionMatrix(edu.neu.ccs.pyramid.eval.MLConfusionMatrix), MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel), DenseMatrix(org.apache.mahout.math.DenseMatrix), Matrix(org.apache.mahout.math.Matrix), InstanceAverage(edu.neu.ccs.pyramid.eval.InstanceAverage)
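To see what such a loss matrix looks like without the pyramid and Mahout classes, here is a minimal sketch that encodes each label set over n labels as an n-bit integer and computes two of the six losses directly with Integer.bitCount. LossMatrixSketch and its methods are hypothetical names; the sketch returns a plain double[][] rather than a Mahout Matrix. The bit-to-label mapping is an assumption and may differ from toBinary/toML, but Hamming and F1 loss only count shared and differing bits, so any consistent mapping yields the same matrix. It also assumes that an empty truth with an empty prediction has F1 loss 0.

```java
import java.util.Arrays;

public class LossMatrixSketch {

    // Hamming loss times n: the number of label positions where the
    // two label sets (encoded as n-bit masks) disagree.
    static double hamming(int truth, int predicted) {
        return Integer.bitCount(truth ^ predicted);
    }

    // 1 - F1 for a single instance, with F1 = 2|T ∩ P| / (|T| + |P|).
    // Assumption: empty truth vs. empty prediction counts as F1 = 1 (loss 0).
    static double f1Loss(int truth, int predicted) {
        int denominator = Integer.bitCount(truth) + Integer.bitCount(predicted);
        if (denominator == 0) {
            return 0.0;
        }
        return 1.0 - 2.0 * Integer.bitCount(truth & predicted) / denominator;
    }

    // Builds the 2^n x 2^n matrix whose (i, j) entry is the loss of
    // predicting label set j when the true label set is i.
    static double[][] lossMatrix(int n, String lossName) {
        String loss = lossName.toLowerCase();
        if (!loss.equals("hamming") && !loss.equals("f1")) {
            throw new IllegalArgumentException("unknown loss: " + lossName);
        }
        int size = 1 << n;
        double[][] matrix = new double[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                matrix[i][j] = loss.equals("hamming") ? hamming(i, j) : f1Loss(i, j);
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        for (double[] row : lossMatrix(2, "hamming")) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

For n = 2 the Hamming matrix printed by main is symmetric with a zero diagonal: the rows are [0.0, 1.0, 1.0, 2.0], [1.0, 0.0, 2.0, 1.0], [1.0, 2.0, 0.0, 1.0] and [2.0, 1.0, 1.0, 0.0].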

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 2
InstanceAverage (edu.neu.ccs.pyramid.eval.InstanceAverage): 2
DenseMatrix (org.apache.mahout.math.DenseMatrix): 2
Matrix (org.apache.mahout.math.Matrix): 2
ConcurrentHashMultiset (com.google.common.collect.ConcurrentHashMultiset): 1
Multiset (com.google.common.collect.Multiset): 1
KLDivergence (edu.neu.ccs.pyramid.eval.KLDivergence): 1
MLConfusionMatrix (edu.neu.ccs.pyramid.eval.MLConfusionMatrix): 1
Enumerator (edu.neu.ccs.pyramid.multilabel_classification.Enumerator): 1
ArgSort (edu.neu.ccs.pyramid.util.ArgSort): 1
Pair (edu.neu.ccs.pyramid.util.Pair): 1
ArrayList (java.util.ArrayList): 1
Arrays (java.util.Arrays): 1
Comparator (java.util.Comparator): 1
List (java.util.List): 1
Collectors (java.util.stream.Collectors): 1
DenseVector (org.apache.mahout.math.DenseVector): 1
Vector (org.apache.mahout.math.Vector): 1