Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in the pyramid project by cheng-li.
Class CRFInspector, method pairRelations.
/**
 * Renders the pair weights reported by a CRF as text, one pair per line.
 *
 * <p>Lines are ordered by the absolute value of the weight, largest first. Each line has the
 * form {@code <label1>:<hasLabel1>, <label2>:<hasLabel2>=><weight>} followed by a newline, where
 * the label names come from the CRF's label translator.
 *
 * @param crf the CRF whose pair weights and label translator are read
 * @return the concatenation of all formatted lines; empty if there are no pair weights
 */
public static String pairRelations(CMLCRF crf) {
    List<CRFInspector.PairWeight> pairWeights = crf.getWeights().printPairWeights();
    LabelTranslator translator = crf.getLabelTranslator();
    // Rank by magnitude so strongly negative weights sort next to strongly positive ones.
    Comparator<PairWeight> byMagnitude = Comparator.comparing(pw -> Math.abs(pw.weight));
    return pairWeights.stream()
            .sorted(byMagnitude.reversed())
            .map(pw -> translator.toExtLabel(pw.label1) + ":" + pw.hasLabel1 + ", "
                    + translator.toExtLabel(pw.label2) + ":" + pw.hasLabel2
                    + "=>" + pw.weight + "\n")
            .collect(Collectors.joining());
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in the pyramid project by cheng-li.
Class BRLRInspector, method analyzePrediction.
/**
 * Assembles a {@link MultiLabelPredictionAnalysis} for the data point at {@code dataPointIndex}.
 *
 * <p>The analysis records: the internal and external id of the data point; its true labels and
 * the classifier's predicted labels (as internal indices and as external names); the calibrated
 * set probability of the true and of the predicted label set; a {@code ClassScoreCalculation}
 * and a rank entry for every selected class; and a ranking of candidate label sets.
 *
 * <p>A class is selected when its calibrated probability is at least {@code classProbThreshold},
 * or it is one of the true labels, or it is one of the predicted labels.
 *
 * @param cbm model used for class probabilities, the number of classes and {@code decisionProcess}
 * @param labelCalibrator calibrates the class probabilities produced by {@code cbm}
 * @param setCalibrator calibrates the feature vector extracted for a candidate label set
 * @param dataSet supplies the feature row, the true labels and the id/label translators
 * @param classifier produces the predicted label set; when it is a {@code SupportPredictor} its
 *                   support is handed to the top-K search
 * @param predictionFeatureExtractor extracts the features passed to {@code setCalibrator}
 * @param dataPointIndex internal index of the data point within {@code dataSet}
 * @param ruleLimit forwarded unchanged to {@code decisionProcess}
 * @param labelSetLimit forwarded to {@code TopKFinder} and also used to cap the label-set ranking
 * @param classProbThreshold minimum calibrated probability for a class to be selected on
 *                           probability alone
 * @return the populated analysis object
 */
public static MultiLabelPredictionAnalysis analyzePrediction(CBM cbm, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    LabelTranslator translator = dataSet.getLabelTranslator();
    IdTranslator ids = dataSet.getIdTranslator();

    // --- identity of the data point ---
    analysis.setInternalId(dataPointIndex);
    analysis.setId(ids.toExtId(dataPointIndex));

    // --- true labels, as indices and as external names ---
    analysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    List<String> trueLabelNames = new ArrayList<>();
    for (Integer trueLabel : dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered()) {
        trueLabelNames.add(translator.toExtLabel(trueLabel));
    }
    analysis.setLabels(trueLabelNames);

    // --- calibrated per-class probabilities, shared by both candidates below ---
    double[] rawClassProbs = cbm.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(rawClassProbs);

    // --- set probability of the true label set ---
    PredictionCandidate truthCandidate = new PredictionCandidate();
    truthCandidate.x = dataSet.getRow(dataPointIndex);
    truthCandidate.multiLabel = dataSet.getMultiLabels()[dataPointIndex];
    truthCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(truthCandidate)));

    // --- predicted labels, as indices and as external names ---
    MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predictedLabels.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    List<String> predictedNames = new ArrayList<>();
    for (Integer predictedLabel : predictedIndices) {
        predictedNames.add(translator.toExtLabel(predictedLabel));
    }
    analysis.setPrediction(predictedNames);

    // --- set probability of the predicted label set ---
    PredictionCandidate guessCandidate = new PredictionCandidate();
    guessCandidate.x = dataSet.getRow(dataPointIndex);
    guessCandidate.multiLabel = predictedLabels;
    guessCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(guessCandidate)));

    // --- select the classes worth reporting ---
    List<Integer> selectedClasses = new ArrayList<Integer>();
    for (int c = 0; c < cbm.getNumClasses(); c++) {
        boolean selected = calibratedClassProbs[c] >= classProbThreshold
                || dataSet.getMultiLabels()[dataPointIndex].matchClass(c)
                || predictedLabels.matchClass(c);
        if (!selected) {
            continue;
        }
        selectedClasses.add(c);
    }

    // todo
    // --- per-class score calculation for each selected class ---
    List<ClassScoreCalculation> scoreCalculations = selectedClasses.stream()
            .map(c -> decisionProcess(cbm, translator, calibratedClassProbs[c], dataSet.getRow(dataPointIndex), c, ruleLimit))
            .collect(Collectors.toList());
    analysis.setClassScoreCalculations(scoreCalculations);

    // --- rank entries, in the same order as the selected classes ---
    List<MultiLabelPredictionAnalysis.ClassRankInfo> classRanking = new ArrayList<>();
    for (Integer c : selectedClasses) {
        MultiLabelPredictionAnalysis.ClassRankInfo info = new MultiLabelPredictionAnalysis.ClassRankInfo();
        info.setClassIndex(c);
        info.setClassName(translator.toExtLabel(c));
        info.setProb(calibratedClassProbs[c]);
        classRanking.add(info);
    }
    analysis.setPredictedRanking(classRanking);

    // --- candidate label sets: restricted to the classifier's support when it exposes one ---
    List<Pair<MultiLabel, Double>> candidateSets;
    if (!(classifier instanceof SupportPredictor)) {
        candidateSets = TopKFinder.topK(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator, predictionFeatureExtractor, labelSetLimit);
    } else {
        candidateSets = TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit);
    }

    // Highest probability first, keeping at most labelSetLimit entries.
    Comparator<MultiLabelPredictionAnalysis.LabelSetProbInfo> byProbabilityDescending =
            Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed();
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = candidateSets.stream()
            .map(candidate -> {
                MultiLabel labelSet = candidate.getFirst();
                double probability = candidate.getSecond();
                return new MultiLabelPredictionAnalysis.LabelSetProbInfo(labelSet, probability, translator);
            })
            .sorted(byProbabilityDescending)
            .limit(labelSetLimit)
            .collect(Collectors.toList());
    analysis.setPredictedLabelSetRanking(labelSetRanking);

    return analysis;
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in the pyramid project by cheng-li.
Class BRInspector, method analyzePrediction.
/**
 * Assembles a {@link MultiLabelPredictionAnalysis} for the data point at {@code dataPointIndex},
 * using any {@code MultiLabelClassifier.ClassProbEstimator} as the source of class probabilities.
 *
 * <p>The analysis records: the internal and external id of the data point; its true labels and
 * the classifier's predicted labels (as internal indices and as external names); the calibrated
 * set probability of the true and of the predicted label set; a score-calculation entry and a
 * rank entry for every selected class; and a ranking of candidate label sets.
 *
 * <p>A class is selected when its calibrated probability is at least {@code classProbThreshold},
 * or it is one of the true labels, or it is one of the predicted labels.
 *
 * <p>A {@code ClassScoreCalculation} is only computed when {@code classProbEstimator} is an
 * {@code IMLGradientBoosting} or a {@code CBM}; for any other estimator the entry for each
 * selected class is {@code null}.
 *
 * @param classProbEstimator source of class probabilities and of the number of classes
 * @param labelCalibrator calibrates the class probabilities produced by the estimator
 * @param setCalibrator calibrates the feature vector extracted for a candidate label set
 * @param dataSet supplies the feature row, the true labels and the id/label translators
 * @param classifier produces the predicted label set; when it is a {@code SupportPredictor} its
 *                   support is handed to the top-K search
 * @param predictionFeatureExtractor extracts the features passed to {@code setCalibrator}
 * @param dataPointIndex internal index of the data point within {@code dataSet}
 * @param ruleLimit forwarded unchanged to {@code decisionProcess}
 * @param labelSetLimit forwarded to {@code TopKFinder} and also used to cap the label-set ranking
 * @param classProbThreshold minimum calibrated probability for a class to be selected on
 *                           probability alone
 * @return the populated analysis object
 */
public static MultiLabelPredictionAnalysis analyzePrediction(MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    LabelTranslator translator = dataSet.getLabelTranslator();
    IdTranslator ids = dataSet.getIdTranslator();

    // --- identity of the data point ---
    analysis.setInternalId(dataPointIndex);
    analysis.setId(ids.toExtId(dataPointIndex));

    // --- true labels, as indices and as external names ---
    analysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    List<String> trueLabelNames = new ArrayList<>();
    for (Integer trueLabel : dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered()) {
        trueLabelNames.add(translator.toExtLabel(trueLabel));
    }
    analysis.setLabels(trueLabelNames);

    // --- calibrated per-class probabilities, shared by both candidates below ---
    double[] rawClassProbs = classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(rawClassProbs);

    // --- set probability of the true label set ---
    PredictionCandidate truthCandidate = new PredictionCandidate();
    truthCandidate.x = dataSet.getRow(dataPointIndex);
    truthCandidate.multiLabel = dataSet.getMultiLabels()[dataPointIndex];
    truthCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(truthCandidate)));

    // --- predicted labels, as indices and as external names ---
    MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predictedLabels.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    List<String> predictedNames = new ArrayList<>();
    for (Integer predictedLabel : predictedIndices) {
        predictedNames.add(translator.toExtLabel(predictedLabel));
    }
    analysis.setPrediction(predictedNames);

    // --- set probability of the predicted label set ---
    PredictionCandidate guessCandidate = new PredictionCandidate();
    guessCandidate.x = dataSet.getRow(dataPointIndex);
    guessCandidate.multiLabel = predictedLabels;
    guessCandidate.labelProbs = calibratedClassProbs;
    analysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(guessCandidate)));

    // --- select the classes worth reporting ---
    List<Integer> selectedClasses = new ArrayList<Integer>();
    for (int c = 0; c < classProbEstimator.getNumClasses(); c++) {
        boolean selected = calibratedClassProbs[c] >= classProbThreshold
                || dataSet.getMultiLabels()[dataPointIndex].matchClass(c)
                || predictedLabels.matchClass(c);
        if (!selected) {
            continue;
        }
        selectedClasses.add(c);
    }

    // todo
    // --- per-class score calculation for each selected class ---
    List<ClassScoreCalculation> scoreCalculations = new ArrayList<>();
    for (int c : selectedClasses) {
        // NOTE(review): stays null (and is still appended) when the estimator is neither an
        // IMLGradientBoosting nor a CBM — confirm that consumers tolerate null entries.
        ClassScoreCalculation calculation = null;
        if (classProbEstimator instanceof IMLGradientBoosting) {
            IMLGradientBoosting boosting = (IMLGradientBoosting) classProbEstimator;
            calculation = decisionProcess(boosting, translator, calibratedClassProbs[c], dataSet.getRow(dataPointIndex), c, ruleLimit);
        }
        if (classProbEstimator instanceof CBM) {
            CBM cbm = (CBM) classProbEstimator;
            calculation = decisionProcess(cbm, translator, calibratedClassProbs[c], dataSet.getRow(dataPointIndex), c, ruleLimit);
        }
        scoreCalculations.add(calculation);
    }
    analysis.setClassScoreCalculations(scoreCalculations);

    // --- rank entries, in the same order as the selected classes ---
    List<MultiLabelPredictionAnalysis.ClassRankInfo> classRanking = new ArrayList<>();
    for (Integer c : selectedClasses) {
        MultiLabelPredictionAnalysis.ClassRankInfo info = new MultiLabelPredictionAnalysis.ClassRankInfo();
        info.setClassIndex(c);
        info.setClassName(translator.toExtLabel(c));
        info.setProb(calibratedClassProbs[c]);
        classRanking.add(info);
    }
    analysis.setPredictedRanking(classRanking);

    // --- candidate label sets: restricted to the classifier's support when it exposes one ---
    List<Pair<MultiLabel, Double>> candidateSets;
    if (!(classifier instanceof SupportPredictor)) {
        candidateSets = TopKFinder.topK(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, labelSetLimit);
    } else {
        candidateSets = TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit);
    }

    // Highest probability first, keeping at most labelSetLimit entries.
    Comparator<MultiLabelPredictionAnalysis.LabelSetProbInfo> byProbabilityDescending =
            Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed();
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = candidateSets.stream()
            .map(candidate -> {
                MultiLabel labelSet = candidate.getFirst();
                double probability = candidate.getSecond();
                return new MultiLabelPredictionAnalysis.LabelSetProbInfo(labelSet, probability, translator);
            })
            .sorted(byProbabilityDescending)
            .limit(labelSetLimit)
            .collect(Collectors.toList());
    analysis.setPredictedLabelSetRanking(labelSetRanking);

    return analysis;
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in the pyramid project by cheng-li.
Class LKBInspector, method topFeatures.
/**
 * Ranks features by their summed importance across the regression trees that {@code boosting}
 * holds for one class, and keeps the leading {@code limit} of them.
 *
 * <p>Only regressors that are {@code RegressionTree} instances contribute; per-tree importances
 * come from {@code RegTreeInspector.featureImportance} and are added up per feature.
 *
 * @param boosting the model whose ensemble for {@code classIndex} is inspected
 * @param classIndex index of the class; also translated to its external label name
 * @param limit maximum number of features to return
 * @return a {@code TopFeatures} holding the ranked features (highest total first), the class
 *         index and the class name
 */
public static TopFeatures topFeatures(LKBoost boosting, int classIndex, int limit) {
    Map<Feature, Double> importanceByFeature = new HashMap<>();
    for (Regressor regressor : boosting.getEnsemble(classIndex).getRegressors()) {
        if (!(regressor instanceof RegressionTree)) {
            continue; // non-tree regressors carry no feature importance here
        }
        Map<Feature, Double> treeImportance = RegTreeInspector.featureImportance((RegressionTree) regressor);
        for (Map.Entry<Feature, Double> entry : treeImportance.entrySet()) {
            Feature feature = entry.getKey();
            double runningTotal = importanceByFeature.getOrDefault(feature, 0.0);
            importanceByFeature.put(feature, runningTotal + entry.getValue());
        }
    }

    Comparator<Map.Entry<Feature, Double>> byImportanceDescending =
            Map.Entry.<Feature, Double>comparingByValue().reversed();
    List<Feature> ranked = importanceByFeature.entrySet().stream()
            .sorted(byImportanceDescending)
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(ranked);
    result.setClassIndex(classIndex);
    result.setClassName(boosting.getLabelTranslator().toExtLabel(classIndex));
    return result;
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in the pyramid project by cheng-li.
Class LKBInspector, method topFeatures.
/**
 * Ranks all features by their summed importance across the regression trees that several
 * {@code LKBoost} models hold for one class.
 *
 * <p>Only regressors that are {@code RegressionTree} instances contribute; per-tree importances
 * come from {@code RegTreeInspector.featureImportance} and are added up per feature over every
 * model in {@code lkBoosts}. Unlike the single-model overload, the result is not truncated.
 *
 * @param lkBoosts the models to aggregate over; the label translator is taken from the first
 *                 element, so the list must not be empty ({@code lkBoosts.get(0)} fails otherwise)
 * @param classIndex index of the class; also translated to its external label name
 * @return a {@code TopFeatures} holding every contributing feature (highest total first), the
 *         class index and the class name
 */
public static TopFeatures topFeatures(List<LKBoost> lkBoosts, int classIndex) {
    Map<Feature, Double> importanceByFeature = new HashMap<>();
    for (LKBoost model : lkBoosts) {
        for (Regressor regressor : model.getEnsemble(classIndex).getRegressors()) {
            if (!(regressor instanceof RegressionTree)) {
                continue; // non-tree regressors carry no feature importance here
            }
            Map<Feature, Double> treeImportance = RegTreeInspector.featureImportance((RegressionTree) regressor);
            for (Map.Entry<Feature, Double> entry : treeImportance.entrySet()) {
                Feature feature = entry.getKey();
                double runningTotal = importanceByFeature.getOrDefault(feature, 0.0);
                importanceByFeature.put(feature, runningTotal + entry.getValue());
            }
        }
    }

    Comparator<Map.Entry<Feature, Double>> byImportanceDescending =
            Map.Entry.<Feature, Double>comparingByValue().reversed();
    List<Feature> ranked = importanceByFeature.entrySet().stream()
            .sorted(byImportanceDescending)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(ranked);
    result.setClassIndex(classIndex);
    result.setClassName(lkBoosts.get(0).getLabelTranslator().toExtLabel(classIndex));
    return result;
}
Aggregations