Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid (by cheng-li).
Class CalibrationDataGenerator, method createInstance:
/**
 * Builds a calibration instance for a single prediction.
 *
 * <p>The raw label scores are calibrated with {@code labelCalibrator}. The top 50 label sets
 * returned by {@link DynamicProgramming#topK(int)} over those calibrated probabilities are
 * attached as the candidate's sparse joint. The result of the three-argument
 * {@code createInstance} is then returned.
 *
 * @param x feature vector of the data point
 * @param uncalibratedLabelScores raw per-label scores to be calibrated
 * @param prediction the predicted label set being calibrated
 * @param groundtruth the true label set
 * @param calibrateTarget forwarded unchanged to the three-argument {@code createInstance}
 * @return the calibration instance built for this prediction
 */
public CalibrationInstance createInstance(Vector x, double[] uncalibratedLabelScores, MultiLabel prediction, MultiLabel groundtruth, String calibrateTarget) {
    PredictionCandidate candidate = new PredictionCandidate();
    candidate.x = x;
    candidate.multiLabel = prediction;

    double[] calibratedProbs = labelCalibrator.calibratedClassProbs(uncalibratedLabelScores);
    candidate.labelProbs = calibratedProbs;
    candidate.sparseJoint = new DynamicProgramming(calibratedProbs).topK(50);

    return createInstance(groundtruth, candidate, calibrateTarget);
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid (by cheng-li).
Class CalibrationDataGenerator, method expand:
/**
 * Expands one query into calibration instances, one per candidate label set.
 *
 * <p>The candidates are assembled in two steps:
 * <ol>
 *   <li>Up to {@code numSupportCandidates} label sets are taken from {@code support}. The pool
 *       is shuffled with a {@link Random} seeded by {@code queryId}, so the draw is
 *       reproducible per query.</li>
 *   <li>The label sets returned by {@link DynamicProgramming#topK(int)} over the calibrated
 *       marginals follow.</li>
 * </ol>
 * Every instance shares the same calibrated marginals and uses the same top-K list as its
 * sparse joint. No de-duplication is performed between the two candidate sources.
 *
 * @param x feature vector of the query
 * @param groundTruth true label set of the query
 * @param uncalibratedLabelScores raw per-label scores, calibrated with {@code labelCalibrator}
 * @param queryId query index stored on each instance; also the shuffle seed
 * @param numCandidates number of label sets requested from {@code topK}
 * @param calibrateTarget forwarded unchanged to {@code createInstance}
 * @param support pool of label sets from which extra candidates are drawn
 * @param numSupportCandidates number of support candidates to add. It is capped at
 *     {@code support.size()}; non-positive values add none.
 * @return one instance per candidate, each with weight 1 and {@code queryIndex == queryId}
 */
private List<CalibrationInstance> expand(Vector x, MultiLabel groundTruth, double[] uncalibratedLabelScores, int queryId, int numCandidates, String calibrateTarget, List<MultiLabel> support, int numSupportCandidates) {
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibratedLabelScores);
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    List<Pair<MultiLabel, Double>> topK = dynamicProgramming.topK(numCandidates);

    // Add a few random candidates from support. Seeding with queryId keeps the draw reproducible.
    List<MultiLabel> supportList = new ArrayList<>(support);
    Collections.shuffle(supportList, new Random(queryId));
    // Never ask for more support candidates than exist (previously an IndexOutOfBoundsException).
    // Negative requests add nothing, as the original counting loop did.
    int numFromSupport = Math.max(0, Math.min(numSupportCandidates, supportList.size()));
    List<MultiLabel> candidates = new ArrayList<>(numFromSupport + topK.size());
    candidates.addAll(supportList.subList(0, numFromSupport));

    // Add the top K from DP.
    for (Pair<MultiLabel, Double> pair : topK) {
        candidates.add(pair.getFirst());
    }

    List<CalibrationInstance> calibrationInstances = new ArrayList<>(candidates.size());
    for (MultiLabel multiLabel : candidates) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.multiLabel = multiLabel;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.x = x;
        predictionCandidate.sparseJoint = topK;
        CalibrationInstance calibrationInstance = createInstance(groundTruth, predictionCandidate, calibrateTarget);
        calibrationInstance.weight = 1;
        calibrationInstance.queryIndex = queryId;
        calibrationInstances.add(calibrationInstance);
    }
    return calibrationInstances;
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid (by cheng-li).
Class IMLGBInspector, method analyzePrediction:
/**
 * Builds a {@link MultiLabelPredictionAnalysis} explaining the prediction for one data point.
 *
 * <p>The analysis records:
 * <ul>
 *   <li>the point's internal and external ids;</li>
 *   <li>its true labels and the predicted labels;</li>
 *   <li>a per-class score calculation (from {@code decisionProcess}) and a class ranking entry
 *       for every class that meets any of these conditions: its predicted probability is at
 *       least {@code classProbThreshold}, it is a true label, or it is a predicted label;</li>
 *   <li>a ranking of label sets.</li>
 * </ul>
 *
 * <p>Set probabilities depend on the predictor type:
 * <ul>
 *   <li>{@link SubsetAccPredictor} and {@link InstanceF1Predictor} use the
 *       {@code ...WithConstraint} methods. They rank all of {@code boosting.getAssignments()} by
 *       descending probability, keeping at most {@code labelSetLimit}.</li>
 *   <li>{@link HammingPredictor} and {@link MacroF1Predictor} use the
 *       {@code ...WithoutConstraint} methods. They rank the first {@code labelSetLimit} sets
 *       produced by {@link DynamicProgramming#nextHighest()} over the class probabilities.</li>
 *   <li>For any other predictor, neither set probability is set and the label-set ranking is
 *       {@code null}.</li>
 * </ul>
 *
 * <p>TODO: speed up.
 *
 * @param boosting the model being inspected
 * @param pluginPredictor predictor whose output is analyzed
 * @param dataSet data set containing the point
 * @param dataPointIndex internal index of the point in {@code dataSet}
 * @param ruleLimit forwarded to {@code decisionProcess}
 * @param labelSetLimit maximum number of label sets in the ranking
 * @param classProbThreshold class-probability cut-off for including a class in the analysis
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(IMLGradientBoosting boosting, PluginPredictor<IMLGradientBoosting> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    // Each predictor family is tested once. The two families are handled by independent ifs below.
    boolean withConstraint = pluginPredictor instanceof SubsetAccPredictor
            || pluginPredictor instanceof InstanceF1Predictor;
    boolean withoutConstraint = pluginPredictor instanceof HammingPredictor
            || pluginPredictor instanceof MacroF1Predictor;

    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];

    analysis.setInternalId(dataPointIndex);
    analysis.setId(idTranslator.toExtId(dataPointIndex));
    analysis.setInternalLabels(trueLabels.getMatchedLabelsOrdered());
    analysis.setLabels(trueLabels.getMatchedLabelsOrdered().stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList()));
    if (withConstraint) {
        analysis.setProbForTrueLabels(boosting.predictAssignmentProbWithConstraint(dataSet.getRow(dataPointIndex), trueLabels));
    }
    if (withoutConstraint) {
        analysis.setProbForTrueLabels(boosting.predictAssignmentProbWithoutConstraint(dataSet.getRow(dataPointIndex), trueLabels));
    }

    MultiLabel predictedLabels = pluginPredictor.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(internalPrediction);
    analysis.setPrediction(internalPrediction.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList()));
    if (withConstraint) {
        analysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithConstraint(dataSet.getRow(dataPointIndex), predictedLabels));
    }
    if (withoutConstraint) {
        analysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithoutConstraint(dataSet.getRow(dataPointIndex), predictedLabels));
    }

    // Classes to explain: probability at or above the threshold, or part of the truth or the prediction.
    double[] classProbs = boosting.predictClassProbs(dataSet.getRow(dataPointIndex));
    List<Integer> classes = IntStream.range(0, boosting.getNumClasses())
            .filter(k -> classProbs[k] >= classProbThreshold
                    || trueLabels.matchClass(k)
                    || predictedLabels.matchClass(k))
            .boxed()
            .collect(Collectors.toList());

    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>(classes.size());
    for (int classIndex : classes) {
        classScoreCalculations.add(decisionProcess(boosting, labelTranslator, dataSet.getRow(dataPointIndex), classIndex, ruleLimit));
    }
    analysis.setClassScoreCalculations(classScoreCalculations);

    List<MultiLabelPredictionAnalysis.ClassRankInfo> labelRanking = classes.stream().map(classIndex -> {
        MultiLabelPredictionAnalysis.ClassRankInfo info = new MultiLabelPredictionAnalysis.ClassRankInfo();
        info.setClassIndex(classIndex);
        info.setClassName(labelTranslator.toExtLabel(classIndex));
        info.setProb(boosting.predictClassProb(dataSet.getRow(dataPointIndex), classIndex));
        return info;
    }).collect(Collectors.toList());
    analysis.setPredictedRanking(labelRanking);

    // Stays null when the predictor belongs to neither family.
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = null;
    if (withConstraint) {
        double[] assignmentProbs = boosting.predictAllAssignmentProbsWithConstraint(dataSet.getRow(dataPointIndex));
        labelSetRanking = IntStream.range(0, boosting.getAssignments().size())
                .mapToObj(i -> new MultiLabelPredictionAnalysis.LabelSetProbInfo(boosting.getAssignments().get(i), assignmentProbs[i], labelTranslator))
                .sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed())
                .limit(labelSetLimit)
                .collect(Collectors.toList());
    }
    if (withoutConstraint) {
        labelSetRanking = new ArrayList<>();
        DynamicProgramming dp = new DynamicProgramming(classProbs);
        for (int rank = 0; rank < labelSetLimit; rank++) {
            DynamicProgramming.Candidate next = dp.nextHighest();
            labelSetRanking.add(new MultiLabelPredictionAnalysis.LabelSetProbInfo(next.getMultiLabel(), next.getProbability(), labelTranslator));
        }
    }
    analysis.setPredictedLabelSetRanking(labelSetRanking);
    return analysis;
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid (by cheng-li).
Class Reranker, method predict:
/**
 * Predicts a label set by reranking dynamic-programming candidates with the regressor.
 *
 * <p>Prediction proceeds in three steps:
 * <ol>
 *   <li>Raw scores are calibrated with {@code labelCalibrator}.</li>
 *   <li>{@code numCandidate} label sets are requested from {@link DynamicProgramming#topK(int)}.
 *       Only those whose size lies in [{@code minPredictionSize}, {@code maxPredictionSize}]
 *       are kept. If none qualify, a single fallback candidate is used instead: the
 *       {@code minPredictionSize} labels with the highest calibrated marginals, bounded by the
 *       number of labels.</li>
 *   <li>Each remaining candidate is scored with
 *       {@code regressor.predict(predictionFeatureExtractor.extractFeatures(...))}, and the
 *       highest-scoring one is returned.</li>
 * </ol>
 *
 * @param vector feature vector of the instance
 * @param uncalibratedLabelScores raw per-label scores
 * @return the candidate label set with the highest regressor score
 */
public MultiLabel predict(Vector vector, double[] uncalibratedLabelScores) {
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibratedLabelScores);
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    List<Pair<MultiLabel, Double>> sparseJoint = dynamicProgramming.topK(numCandidate);
    // Collect into an explicit ArrayList: the fallback below appends to this list, and
    // Collectors.toList() makes no guarantee that its result is mutable.
    List<MultiLabel> multiLabels = sparseJoint.stream()
            .map(pair -> pair.getFirst())
            .filter(candidate -> candidate.getNumMatchedLabels() >= minPredictionSize
                    && candidate.getNumMatchedLabels() <= maxPredictionSize)
            .collect(Collectors.toCollection(ArrayList::new));
    if (multiLabels.isEmpty()) {
        // No DP candidate has an admissible size. Fall back to the highest-marginal labels,
        // never reading past the end of the sorted index array.
        int[] sorted = ArgSort.argSortDescending(marginals);
        int fallbackSize = Math.min(minPredictionSize, sorted.length);
        MultiLabel multiLabel = new MultiLabel();
        for (int i = 0; i < fallbackSize; i++) {
            multiLabel.addLabel(sorted[i]);
        }
        multiLabels.add(multiLabel);
    }

    List<Pair<MultiLabel, Double>> candidates = new ArrayList<>(multiLabels.size());
    for (MultiLabel candidate : multiLabels) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.x = vector;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.multiLabel = candidate;
        predictionCandidate.sparseJoint = sparseJoint;
        Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
        double score = regressor.predict(feature);
        candidates.add(new Pair<>(candidate, score));
    }
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(pair -> pair.getSecond());
    // candidates is never empty here: multiLabels always holds at least the fallback.
    return candidates.stream()
            .max(comparator)
            .map(Pair::getFirst)
            .orElseThrow(() -> new IllegalStateException("no reranking candidates"));
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid (by cheng-li).
Class BRPrediction, method simplePredictionAnalysisCalibrated:
/**
 * Renders a tab-separated, calibrated prediction report for one data point.
 *
 * <p>The report has two kinds of lines.
 *
 * <p>One "single" line is written per class that is either a true label or a predicted label,
 * ordered by calibrated probability, highest first:
 * {@code id \t label \t single \t prob \t match \t NA \t NA \t NA \t NA}. Classes at or beyond
 * {@code classProbEstimator.getNumClasses()} get probability 0.0.
 *
 * <p>One "set" line follows:
 * {@code id \t predictedLabels \t set \t probability \t setMatch \t trueLabels \t precision \t recall \t f1}.
 * Its probability is {@code setCalibrator.calibrate} applied to the features of the predicted
 * set. Those features are built with a sparse joint from {@code DynamicProgramming.topK(50)}.
 *
 * @param classProbEstimator source of uncalibrated per-class probabilities
 * @param labelCalibrator calibrates per-class probabilities
 * @param setCalibrator calibrates the predicted set's feature vector into a probability
 * @param dataSet data set containing the point
 * @param dataPointIndex index of the point in {@code dataSet}
 * @param classifier produces the predicted label set
 * @param predictionFeatureExtractor builds the feature vector for the predicted set
 * @return the report, each line terminated by {@code "\n"}
 */
public static String simplePredictionAnalysisCalibrated(MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, int dataPointIndex, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor) {
    StringBuilder report = new StringBuilder();
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    String id = dataSet.getIdTranslator().toExtId(dataPointIndex);
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    double[] rawProbs = classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(rawProbs);
    MultiLabel predicted = classifier.predict(dataSet.getRow(dataPointIndex));

    // Classes worth reporting individually: anything in the truth or in the prediction.
    List<Integer> classes = new ArrayList<Integer>();
    for (int k = 0; k < dataSet.getNumClasses(); k++) {
        if (trueLabels.matchClass(k) || predicted.matchClass(k)) {
            classes.add(k);
        }
    }

    Comparator<Pair<Integer, Double>> byProb = Comparator.comparing(Pair::getSecond);
    List<Pair<Integer, Double>> ranked = classes.stream()
            .map(l -> new Pair<>(l, l < classProbEstimator.getNumClasses() ? calibratedClassProbs[l] : 0.0))
            .sorted(byProb.reversed())
            .collect(Collectors.toList());
    for (Pair<Integer, Double> entry : ranked) {
        int label = entry.getFirst();
        double prob = entry.getSecond();
        int match = trueLabels.matchClass(label) ? 1 : 0;
        report.append(id).append("\t")
                .append(labelTranslator.toExtLabel(label)).append("\t")
                .append("single").append("\t")
                .append(prob).append("\t")
                .append(match).append("\t")
                .append("NA").append("\t")
                .append("NA").append("\t")
                .append("NA").append("\t")
                .append("NA").append("\n");
    }

    // Calibrated probability of the predicted set as a whole.
    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.multiLabel = predicted;
    predictedCandidate.labelProbs = calibratedClassProbs;
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.sparseJoint = new DynamicProgramming(calibratedClassProbs).topK(50);
    Vector feature = predictionFeatureExtractor.extractFeatures(predictedCandidate);
    double probability = setCalibrator.calibrate(feature);

    List<Integer> predictedList = predicted.getMatchedLabelsOrdered();
    MultiLabel prediction = new MultiLabel();
    for (Integer labelIndex : predictedList) {
        prediction.addLabel(labelIndex);
    }
    String predictedNames = predictedList.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.joining(","));
    String truthNames = trueLabels.getMatchedLabels().stream()
            .sorted()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.joining(","));
    int setMatch = predicted.equals(trueLabels) ? 1 : 0;

    double precision = Precision.precision(trueLabels, prediction);
    double recall = Recall.recall(trueLabels, prediction);
    double f1 = FMeasure.f1(precision, recall);
    report.append(id).append("\t")
            .append(predictedNames).append("\t")
            .append("set").append("\t")
            .append(probability).append("\t")
            .append(setMatch).append("\t")
            .append(truthNames).append("\t")
            .append(precision).append("\t")
            .append(recall).append("\t")
            .append(f1).append("\n");
    return report.toString();
}
Aggregations