use of edu.neu.ccs.pyramid.dataset.IdTranslator in project pyramid by cheng-li.
the class BRLRInspector method analyzePrediction.
/**
 * Explains the prediction made for one data point when a CBM supplies the class probabilities.
 *
 * <p>The returned analysis holds the data point's ids, its true and predicted label sets
 * (internal indices and external names), the calibrated set score of each of those two label
 * sets, a score calculation and rank entry for every selected class, and a ranking of candidate
 * label sets.
 *
 * <p>A class is selected when its calibrated probability is at least {@code classProbThreshold},
 * or it is one of the true labels, or it is one of the predicted labels.
 *
 * @param cbm model whose {@code predictClassProbs} yields the uncalibrated class probabilities
 * @param labelCalibrator calibrates the class probabilities
 * @param setCalibrator calibrates the features extracted for a candidate label set
 * @param dataSet data set that contains the data point
 * @param classifier produces the predicted label set; when it is a {@code SupportPredictor},
 *        its support is handed to {@code TopKFinder.topKinSupport}
 * @param predictionFeatureExtractor builds the features passed to {@code setCalibrator}
 * @param dataPointIndex internal index of the data point inside {@code dataSet}
 * @param ruleLimit forwarded unchanged to {@code decisionProcess}
 * @param labelSetLimit forwarded to {@code TopKFinder} and used as the maximum length of the
 *        label-set ranking
 * @param classProbThreshold minimum calibrated probability for a class to be selected on
 *        probability alone
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(
        CBM cbm,
        LabelCalibrator labelCalibrator,
        VectorCalibrator setCalibrator,
        MultiLabelClfDataSet dataSet,
        MultiLabelClassifier classifier,
        PredictionFeatureExtractor predictionFeatureExtractor,
        int dataPointIndex,
        int ruleLimit,
        int labelSetLimit,
        double classProbThreshold) {
    final MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    final LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    final IdTranslator extIds = dataSet.getIdTranslator();

    // Identity of the data point.
    analysis.setInternalId(dataPointIndex);
    analysis.setId(extIds.toExtId(dataPointIndex));

    // Ground truth, as internal indices and as external names.
    final MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    List<Integer> trueLabelIndices = trueLabels.getMatchedLabelsOrdered();
    analysis.setInternalLabels(trueLabelIndices);
    List<String> trueLabelNames = trueLabelIndices.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    analysis.setLabels(trueLabelNames);

    // Calibrated class probabilities; both candidates below share this array.
    final double[] calibratedClassProbs =
            labelCalibrator.calibratedClassProbs(cbm.predictClassProbs(dataSet.getRow(dataPointIndex)));

    // Set-level score of the true label set.
    PredictionCandidate truthCandidate = new PredictionCandidate();
    truthCandidate.x = dataSet.getRow(dataPointIndex);
    truthCandidate.multiLabel = trueLabels;
    truthCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForTrueLabels(
            setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(truthCandidate)));

    // Prediction, as internal indices and as external names.
    final MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predictedLabels.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    List<String> predictedNames = predictedIndices.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    analysis.setPrediction(predictedNames);

    // Set-level score of the predicted label set.
    PredictionCandidate guessCandidate = new PredictionCandidate();
    guessCandidate.x = dataSet.getRow(dataPointIndex);
    guessCandidate.multiLabel = predictedLabels;
    guessCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForPredictedLabels(
            setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(guessCandidate)));

    // Select the classes to report: probable enough, or a true label, or a predicted label.
    List<Integer> selectedClasses = new ArrayList<>();
    for (int classIndex = 0; classIndex < cbm.getNumClasses(); classIndex++) {
        boolean probableEnough = calibratedClassProbs[classIndex] >= classProbThreshold;
        if (!probableEnough
                && !trueLabels.matchClass(classIndex)
                && !predictedLabels.matchClass(classIndex)) {
            continue;
        }
        selectedClasses.add(classIndex);
    }

    // todo (marker carried over from the original code; it gave no further detail)
    List<ClassScoreCalculation> scoreCalculations = new ArrayList<>();
    for (int classIndex : selectedClasses) {
        scoreCalculations.add(decisionProcess(cbm, labelTranslator, calibratedClassProbs[classIndex],
                dataSet.getRow(dataPointIndex), classIndex, ruleLimit));
    }
    analysis.setClassScoreCalculations(scoreCalculations);

    // One rank entry per selected class, in ascending class-index order.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> classRanking = new ArrayList<>();
    for (int classIndex : selectedClasses) {
        MultiLabelPredictionAnalysis.ClassRankInfo entry = new MultiLabelPredictionAnalysis.ClassRankInfo();
        entry.setClassIndex(classIndex);
        entry.setClassName(labelTranslator.toExtLabel(classIndex));
        entry.setProb(calibratedClassProbs[classIndex]);
        classRanking.add(entry);
    }
    analysis.setPredictedRanking(classRanking);

    // Candidate label sets come from TopKFinder; which entry point depends on the classifier type.
    List<Pair<MultiLabel, Double>> candidates = classifier instanceof SupportPredictor
            ? TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator,
                    predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit)
            : TopKFinder.topK(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator,
                    predictionFeatureExtractor, labelSetLimit);

    // Re-sort by probability, highest first, and keep at most labelSetLimit entries.
    Comparator<MultiLabelPredictionAnalysis.LabelSetProbInfo> highestProbabilityFirst =
            Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed();
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = candidates.stream()
            .map(candidate -> {
                MultiLabel labelSet = candidate.getFirst();
                double probability = candidate.getSecond();
                return new MultiLabelPredictionAnalysis.LabelSetProbInfo(labelSet, probability, labelTranslator);
            })
            .sorted(highestProbabilityFirst)
            .limit(labelSetLimit)
            .collect(Collectors.toList());
    analysis.setPredictedLabelSetRanking(labelSetRanking);

    return analysis;
}
use of edu.neu.ccs.pyramid.dataset.IdTranslator in project pyramid by cheng-li.
the class BRInspector method analyzePrediction.
/**
 * Explains the prediction made for one data point, using any class-probability estimator.
 *
 * <p>The returned analysis holds the data point's ids, its true and predicted label sets
 * (internal indices and external names), the calibrated set score of each of those two label
 * sets, a score calculation and rank entry for every selected class, and a ranking of candidate
 * label sets.
 *
 * <p>A class is selected when its calibrated probability is at least {@code classProbThreshold},
 * or it is one of the true labels, or it is one of the predicted labels.
 *
 * <p>Score calculations are only produced when {@code classProbEstimator} is an
 * {@code IMLGradientBoosting} or a {@code CBM}; for any other estimator the list of score
 * calculations contains one {@code null} per selected class.
 *
 * @param classProbEstimator supplies the uncalibrated class probabilities
 * @param labelCalibrator calibrates the class probabilities
 * @param setCalibrator calibrates the features extracted for a candidate label set
 * @param dataSet data set that contains the data point
 * @param classifier produces the predicted label set; when it is a {@code SupportPredictor},
 *        its support is handed to {@code TopKFinder.topKinSupport}
 * @param predictionFeatureExtractor builds the features passed to {@code setCalibrator}
 * @param dataPointIndex internal index of the data point inside {@code dataSet}
 * @param ruleLimit forwarded unchanged to {@code decisionProcess}
 * @param labelSetLimit forwarded to {@code TopKFinder} and used as the maximum length of the
 *        label-set ranking
 * @param classProbThreshold minimum calibrated probability for a class to be selected on
 *        probability alone
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(
        MultiLabelClassifier.ClassProbEstimator classProbEstimator,
        LabelCalibrator labelCalibrator,
        VectorCalibrator setCalibrator,
        MultiLabelClfDataSet dataSet,
        MultiLabelClassifier classifier,
        PredictionFeatureExtractor predictionFeatureExtractor,
        int dataPointIndex,
        int ruleLimit,
        int labelSetLimit,
        double classProbThreshold) {
    final MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    final LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    final IdTranslator extIds = dataSet.getIdTranslator();

    // Identity of the data point.
    analysis.setInternalId(dataPointIndex);
    analysis.setId(extIds.toExtId(dataPointIndex));

    // Ground truth, as internal indices and as external names.
    final MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    List<Integer> trueLabelIndices = trueLabels.getMatchedLabelsOrdered();
    analysis.setInternalLabels(trueLabelIndices);
    List<String> trueLabelNames = trueLabelIndices.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    analysis.setLabels(trueLabelNames);

    // Calibrated class probabilities; both candidates below share this array.
    final double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(
            classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex)));

    // Set-level score of the true label set.
    PredictionCandidate truthCandidate = new PredictionCandidate();
    truthCandidate.x = dataSet.getRow(dataPointIndex);
    truthCandidate.multiLabel = trueLabels;
    truthCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForTrueLabels(
            setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(truthCandidate)));

    // Prediction, as internal indices and as external names.
    final MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predictedLabels.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    List<String> predictedNames = predictedIndices.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    analysis.setPrediction(predictedNames);

    // Set-level score of the predicted label set.
    PredictionCandidate guessCandidate = new PredictionCandidate();
    guessCandidate.x = dataSet.getRow(dataPointIndex);
    guessCandidate.multiLabel = predictedLabels;
    guessCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForPredictedLabels(
            setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(guessCandidate)));

    // Select the classes to report: probable enough, or a true label, or a predicted label.
    List<Integer> selectedClasses = new ArrayList<>();
    for (int classIndex = 0; classIndex < classProbEstimator.getNumClasses(); classIndex++) {
        boolean probableEnough = calibratedClassProbs[classIndex] >= classProbThreshold;
        if (!probableEnough
                && !trueLabels.matchClass(classIndex)
                && !predictedLabels.matchClass(classIndex)) {
            continue;
        }
        selectedClasses.add(classIndex);
    }

    // todo (marker carried over from the original code; it gave no further detail)
    List<ClassScoreCalculation> scoreCalculations = new ArrayList<>();
    for (int classIndex : selectedClasses) {
        // Stays null when the estimator is neither of the two supported model types.
        ClassScoreCalculation calculation = null;
        // The two checks are deliberately independent; if both matched, the CBM result would win.
        if (classProbEstimator instanceof IMLGradientBoosting) {
            calculation = decisionProcess((IMLGradientBoosting) classProbEstimator, labelTranslator,
                    calibratedClassProbs[classIndex], dataSet.getRow(dataPointIndex), classIndex, ruleLimit);
        }
        if (classProbEstimator instanceof CBM) {
            calculation = decisionProcess((CBM) classProbEstimator, labelTranslator,
                    calibratedClassProbs[classIndex], dataSet.getRow(dataPointIndex), classIndex, ruleLimit);
        }
        scoreCalculations.add(calculation);
    }
    analysis.setClassScoreCalculations(scoreCalculations);

    // One rank entry per selected class, in ascending class-index order.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> classRanking = new ArrayList<>();
    for (int classIndex : selectedClasses) {
        MultiLabelPredictionAnalysis.ClassRankInfo entry = new MultiLabelPredictionAnalysis.ClassRankInfo();
        entry.setClassIndex(classIndex);
        entry.setClassName(labelTranslator.toExtLabel(classIndex));
        entry.setProb(calibratedClassProbs[classIndex]);
        classRanking.add(entry);
    }
    analysis.setPredictedRanking(classRanking);

    // Candidate label sets come from TopKFinder; which entry point depends on the classifier type.
    List<Pair<MultiLabel, Double>> candidates = classifier instanceof SupportPredictor
            ? TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator,
                    setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(),
                    labelSetLimit)
            : TopKFinder.topK(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator,
                    setCalibrator, predictionFeatureExtractor, labelSetLimit);

    // Re-sort by probability, highest first, and keep at most labelSetLimit entries.
    Comparator<MultiLabelPredictionAnalysis.LabelSetProbInfo> highestProbabilityFirst =
            Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed();
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = candidates.stream()
            .map(candidate -> {
                MultiLabel labelSet = candidate.getFirst();
                double probability = candidate.getSecond();
                return new MultiLabelPredictionAnalysis.LabelSetProbInfo(labelSet, probability, labelTranslator);
            })
            .sorted(highestProbabilityFirst)
            .limit(labelSetLimit)
            .collect(Collectors.toList());
    analysis.setPredictedLabelSetRanking(labelSetRanking);

    return analysis;
}
use of edu.neu.ccs.pyramid.dataset.IdTranslator in project pyramid by cheng-li.
the class LogisticRegressionInspector method analyzePrediction.
/**
 * Explains the prediction a logistic regression model makes for one data point.
 *
 * <p>The returned analysis holds the data point's internal and external id, its true and
 * predicted label (internal index and external name), the probability of every class, and one
 * score calculation per class.
 *
 * @param logisticRegression model being inspected
 * @param dataSet data set that contains the data point
 * @param dataPointIndex internal index of the data point inside {@code dataSet}
 * @param limit forwarded unchanged to {@code decisionProcess}
 * @return the populated analysis
 */
public static PredictionAnalysis analyzePrediction(
        LogisticRegression logisticRegression,
        ClfDataSet dataSet,
        int dataPointIndex,
        int limit) {
    final LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    final IdTranslator extIds = dataSet.getIdTranslator();

    // Identity and ground truth.
    final PredictionAnalysis analysis = new PredictionAnalysis();
    analysis.setInternalId(dataPointIndex)
            .setId(extIds.toExtId(dataPointIndex))
            .setInternalLabel(dataSet.getLabels()[dataPointIndex])
            .setLabel(labelTranslator.toExtLabel(dataSet.getLabels()[dataPointIndex]));

    // Predicted class, as internal index and as external name.
    final int predictedClass = logisticRegression.predict(dataSet.getRow(dataPointIndex));
    analysis.setInternalPrediction(predictedClass);
    analysis.setPrediction(labelTranslator.toExtLabel(predictedClass));

    // Probability of every class; the length of this array defines the number of classes below.
    final double[] classProbs = logisticRegression.predictClassProbs(dataSet.getRow(dataPointIndex));
    final int numClasses = classProbs.length;

    List<ClassProbability> probabilityPerClass = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        probabilityPerClass.add(
                new ClassProbability(classIndex, labelTranslator.toExtLabel(classIndex), classProbs[classIndex]));
    }
    analysis.setClassProbabilities(probabilityPerClass);

    // One score calculation per class.
    List<ClassScoreCalculation> scorePerClass = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        scorePerClass.add(decisionProcess(logisticRegression, labelTranslator,
                dataSet.getRow(dataPointIndex), classIndex, limit));
    }
    analysis.setClassScoreCalculations(scorePerClass);

    return analysis;
}
use of edu.neu.ccs.pyramid.dataset.IdTranslator in project pyramid by cheng-li.
the class LKBInspector method analyzePrediction.
// todo speed up
/**
 * Explains the prediction an {@code LKBoost} model makes for one data point.
 *
 * <p>The returned analysis holds the data point's internal and external id, its true and
 * predicted label (internal index and external name), the probability of every class, and one
 * score calculation per class.
 *
 * @param boosting model being inspected
 * @param dataSet data set that contains the data point
 * @param dataPointIndex internal index of the data point inside {@code dataSet}
 * @param limit forwarded unchanged to {@code decisionProcess}
 * @return the populated analysis
 */
public static PredictionAnalysis analyzePrediction(
        LKBoost boosting,
        ClfDataSet dataSet,
        int dataPointIndex,
        int limit) {
    final LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    final IdTranslator extIds = dataSet.getIdTranslator();

    // Identity and ground truth.
    final PredictionAnalysis analysis = new PredictionAnalysis();
    analysis.setInternalId(dataPointIndex)
            .setId(extIds.toExtId(dataPointIndex))
            .setInternalLabel(dataSet.getLabels()[dataPointIndex])
            .setLabel(labelTranslator.toExtLabel(dataSet.getLabels()[dataPointIndex]));

    // Predicted class, as internal index and as external name.
    final int predictedClass = boosting.predict(dataSet.getRow(dataPointIndex));
    analysis.setInternalPrediction(predictedClass);
    analysis.setPrediction(labelTranslator.toExtLabel(predictedClass));

    // Probability of every class; the length of this array defines the number of classes below.
    final double[] classProbs = boosting.predictClassProbs(dataSet.getRow(dataPointIndex));
    final int numClasses = classProbs.length;

    List<ClassProbability> probabilityPerClass = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        probabilityPerClass.add(
                new ClassProbability(classIndex, labelTranslator.toExtLabel(classIndex), classProbs[classIndex]));
    }
    analysis.setClassProbabilities(probabilityPerClass);

    // One score calculation per class.
    List<ClassScoreCalculation> scorePerClass = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        scorePerClass.add(decisionProcess(boosting, labelTranslator,
                dataSet.getRow(dataPointIndex), classIndex, limit));
    }
    analysis.setClassScoreCalculations(scorePerClass);

    return analysis;
}
use of edu.neu.ccs.pyramid.dataset.IdTranslator in project pyramid by cheng-li.
the class AdaBoostMHInspector method analyzePrediction.
/**
 * Explains the prediction an {@code AdaBoostMH} model makes for one data point.
 *
 * <p>{@code scaling} can be binary scaling or across-class scaling. It supplies the per-class
 * probabilities; when it also implements {@code MultiLabelClassifier.AssignmentProbEstimator}
 * it supplies the probabilities of the true and of the predicted label set as well, otherwise
 * both are reported as {@code Double.NaN}.
 *
 * @param boosting model that produces the predicted label set
 * @param scaling probability estimator used for class (and, if supported, label-set) probabilities
 * @param dataSet data set that contains the data point
 * @param dataPointIndex internal index of the data point inside {@code dataSet}
 * @param classes internal indices of the classes to explain and rank, in the order they
 *        should appear in the result
 * @param limit forwarded unchanged to {@code decisionProcess}
 * @return the populated analysis, including the per-class ranking
 */
public static MultiLabelPredictionAnalysis analyzePrediction(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();

    // Identity of the data point.
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));

    // Ground truth, as internal indices and as external names.
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    predictionAnalysis.setInternalLabels(trueLabels.getMatchedLabelsOrdered());
    List<String> labels = trueLabels.getMatchedLabelsOrdered().stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);

    // Label-set probabilities are only available from an AssignmentProbEstimator; NaN otherwise.
    double probForTrueLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForTrueLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling)
                .predictAssignmentProb(dataSet.getRow(dataPointIndex), trueLabels);
    }
    predictionAnalysis.setProbForTrueLabels(probForTrueLabels);

    // Prediction, as internal indices and as external names.
    MultiLabel predictedLabels = boosting.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);

    double probForPredictedLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForPredictedLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling)
                .predictAssignmentProb(dataSet.getRow(dataPointIndex), predictedLabels);
    }
    predictionAnalysis.setProbForPredictedLabels(probForPredictedLabels);

    // One score calculation per requested class.
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        ClassScoreCalculation classScoreCalculation = decisionProcess(boosting, scaling, labelTranslator, dataSet.getRow(dataPointIndex), k, limit);
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);

    // One rank entry per requested class, carrying the probability reported by scaling.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> ranking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(scaling.predictClassProb(dataSet.getRow(dataPointIndex), label));
        return rankInfo;
    }).collect(Collectors.toList());
    // Fix: the ranking used to be computed and then discarded, leaving the analysis without it.
    predictionAnalysis.setPredictedRanking(ranking);

    return predictionAnalysis;
}
Aggregations