Usage of edu.neu.ccs.pyramid.eval.InstanceAverage in the project pyramid by cheng-li.
Class GeneralF1Predictor, method showSupportPrediction.
/**
 * Summarizes how a predicted joint distribution over label combinations supports a prediction.
 *
 * <p>The returned {@link Analysis} holds:
 * <ul>
 *   <li>the divergence {@code KLDivergence.kl(pointMassOnTruth, probs)};</li>
 *   <li>the expected F1 (under {@code probs}) of both {@code prediction} and {@code truth};</li>
 *   <li>the actual F1 of {@code prediction} against {@code truth};</li>
 *   <li>a printable listing of every combination whose probability exceeds 0.01, in descending
 *       order of probability.</li>
 * </ul>
 *
 * @param combinations candidate label combinations; {@code probs[i]} belongs to
 *     {@code combinations.get(i)}
 * @param probs probability assigned to each combination (assumed same length as
 *     {@code combinations} — not checked here)
 * @param truth the true label set
 * @param prediction the predicted label set
 * @param numClasses total number of classes
 * @return the populated analysis; its {@code kl} is {@link Double#POSITIVE_INFINITY} when
 *     {@code truth} does not appear in {@code combinations}
 */
public static Analysis showSupportPrediction(List<MultiLabel> combinations, double[] probs, MultiLabel truth, MultiLabel prediction, int numClasses) {
    int truthIndex = combinations.indexOf(truth);
    double kl;
    if (truthIndex >= 0) {
        // One-hot "true" joint distribution: all mass on the true combination.
        double[] trueJoint = new double[combinations.size()];
        trueJoint[truthIndex] = 1;
        kl = KLDivergence.kl(trueJoint, probs);
    } else {
        // The truth has no support among the candidates, so a point mass on it shares no support
        // with probs and the divergence is unbounded. Falling back to index 0 (as before) would
        // report a finite KL against an unrelated combination.
        kl = Double.POSITIVE_INFINITY;
    }

    // Keep only combinations with non-negligible probability, then order them by descending
    // probability. List.sort is stable, so ties keep their original order in combinations.
    List<Pair<MultiLabel, Double>> supported = new ArrayList<>();
    for (int i = 0; i < combinations.size(); i++) {
        if (probs[i] > 0.01) {
            supported.add(new Pair<>(combinations.get(i), probs[i]));
        }
    }
    supported.sort(Comparator.comparingDouble((Pair<MultiLabel, Double> pair) -> pair.getSecond()).reversed());

    double expectedF1Prediction = expectedF1(combinations, probs, prediction, numClasses);
    double expectedF1Truth = expectedF1(combinations, probs, truth, numClasses);
    double actualF1 = new InstanceAverage(numClasses, truth, prediction).getF1();

    // Format: "label:prob, label:prob, " (trailing separator kept for output compatibility).
    StringBuilder jointString = new StringBuilder();
    for (Pair<MultiLabel, Double> pair : supported) {
        jointString.append(pair.getFirst()).append(":").append(pair.getSecond()).append(", ");
    }

    Analysis analysis = new Analysis();
    analysis.expectedF1Prediction = expectedF1Prediction;
    analysis.expectedF1Truth = expectedF1Truth;
    analysis.actualF1 = actualF1;
    analysis.kl = kl;
    analysis.prediction = prediction;
    analysis.truth = truth;
    analysis.joint = jointString.toString();
    return analysis;
}
Usage of edu.neu.ccs.pyramid.eval.InstanceAverage in the project pyramid by cheng-li.
Class LossMatrixGenerator, method matrix.
/**
 * Builds the {@code 2^n x 2^n} loss matrix over every label subset of {@code n} labels.
 *
 * <p>Row {@code i} and column {@code j} correspond to the label sets obtained from
 * {@code toML(toBinary(i, n))} and {@code toML(toBinary(j, n))}. The entry is the chosen loss
 * with the row set as the truth and the column set as the prediction, computed from an
 * {@link InstanceAverage} over a single-instance {@link MLConfusionMatrix}.
 *
 * @param n number of labels
 * @param lossName one of "hamming", "overlap", "accuracy", "precision", "recall", "f1"
 *     (case-insensitive, locale-independent)
 * @return the dense loss matrix
 * @throws IllegalArgumentException if {@code lossName} is not a recognized loss
 */
public static Matrix matrix(int n, String lossName) {
    // Resolve the loss once instead of re-parsing the name for each of the 4^n cells.
    java.util.function.ToDoubleFunction<InstanceAverage> lossFunction = lossFunctionFor(lossName, n);
    int size = (int) Math.pow(2, n);

    // Each label set depends only on its index, so build it once per index rather than per cell.
    MultiLabel[] labelSets = new MultiLabel[size];
    for (int k = 0; k < size; k++) {
        labelSets[k] = toML(toBinary(k, n));
    }

    double[][] matrixBuilder = new double[size][size];
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            MultiLabel[] trueLabels = { labelSets[i] };
            MultiLabel[] predicted = { labelSets[j] };
            MLConfusionMatrix mlConfusionMatrix = new MLConfusionMatrix(n, trueLabels, predicted);
            matrixBuilder[i][j] = lossFunction.applyAsDouble(new InstanceAverage(mlConfusionMatrix));
        }
    }
    return new DenseMatrix(matrixBuilder);
}

/**
 * Maps a loss name to the function that extracts that loss from an {@link InstanceAverage}.
 *
 * @param lossName loss name, matched case-insensitively using {@link java.util.Locale#ROOT} so
 *     that e.g. a Turkish default locale does not break "HAMMING"
 * @param n number of labels; the Hamming loss is multiplied by it (presumably turning a
 *     per-label rate into a count of mismatched labels — confirm against InstanceAverage)
 * @return the loss extractor
 * @throws IllegalArgumentException if {@code lossName} is not recognized
 */
private static java.util.function.ToDoubleFunction<InstanceAverage> lossFunctionFor(String lossName, int n) {
    switch (lossName.toLowerCase(java.util.Locale.ROOT)) {
        case "hamming":
            return instanceAverage -> instanceAverage.getHammingLoss() * n;
        case "overlap":
            return instanceAverage -> 1 - instanceAverage.getOverlap();
        case "accuracy":
            return instanceAverage -> 1 - instanceAverage.getAccuracy();
        case "precision":
            return instanceAverage -> 1 - instanceAverage.getPrecision();
        case "recall":
            return instanceAverage -> 1 - instanceAverage.getRecall();
        case "f1":
            return instanceAverage -> 1 - instanceAverage.getF1();
        default:
            throw new IllegalArgumentException("unknown loss: " + lossName);
    }
}
Aggregations