Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in the project pyramid by cheng-li.
Class AdaBoostMHInspector, method analyzePrediction.
/**
 * Builds an analysis of the AdaBoost.MH prediction for a single data point.
 *
 * <p>{@code scaling} can be either a binary (per-class) scaling or an across-class scaling.
 * The probabilities of the true label set and of the predicted label set are only computed
 * when {@code scaling} also implements {@link MultiLabelClassifier.AssignmentProbEstimator};
 * otherwise both are reported as {@link Double#NaN}.
 *
 * @param boosting the trained model whose prediction is analyzed
 * @param scaling estimator used to obtain class probabilities for the data point
 * @param dataSet data set that contains the data point
 * @param dataPointIndex internal index of the data point in {@code dataSet}
 * @param classes internal class indices to explain; the per-class score calculations and the
 *     predicted ranking are reported in this order
 * @param limit forwarded to {@code decisionProcess} for every class (presumably caps the number
 *     of contributions listed per class — TODO confirm)
 * @return the populated analysis, including the per-class probability ranking
 */
public static MultiLabelPredictionAnalysis analyzePrediction(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];

    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));

    // Gold labels, reported both as internal indices and as external names.
    List<Integer> internalLabels = trueLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalLabels(internalLabels);
    List<String> labels = internalLabels.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);

    // Only estimators implementing AssignmentProbEstimator can score a whole label set.
    double probForTrueLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForTrueLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), trueLabels);
    }
    predictionAnalysis.setProbForTrueLabels(probForTrueLabels);

    // Model prediction, reported both as internal indices and as external names.
    MultiLabel predictedLabels = boosting.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);

    double probForPredictedLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForPredictedLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), predictedLabels);
    }
    predictionAnalysis.setProbForPredictedLabels(probForPredictedLabels);

    // Per-class explanation of how the score was obtained.
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>(classes.size());
    for (int k : classes) {
        classScoreCalculations.add(decisionProcess(boosting, scaling, labelTranslator, dataSet.getRow(dataPointIndex), k, limit));
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);

    // Per-class probabilities in the order given by `classes`; attached to the result so the
    // predictClassProb work is not thrown away.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> ranking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(scaling.predictClassProb(dataSet.getRow(dataPointIndex), label));
        return rankInfo;
    }).collect(Collectors.toList());
    predictionAnalysis.setPredictedRanking(ranking);
    return predictionAnalysis;
}
Usage of edu.neu.ccs.pyramid.dataset.LabelTranslator in the project pyramid by cheng-li.
Class AdaBoostMHInspector, method topFeatures.
/**
 * Returns the most important features for one class of an AdaBoost.MH model.
 *
 * <p>For every {@link RegressionTree} among the class's regressors, the per-feature importances
 * reported by {@code RegTreeInspector.featureImportance} are summed per feature. Regressors that
 * are not regression trees are ignored. Features are then ordered by descending total
 * importance. The relative order of features with equal totals is unspecified.
 *
 * @param boosting the trained model
 * @param classIndex internal index of the class to inspect
 * @param limit maximum number of features to return
 * @return the top features together with the class index and its external name
 * @throws IllegalArgumentException if {@code limit} is negative (from {@code Stream.limit})
 */
public static TopFeatures topFeatures(AdaBoostMH boosting, int classIndex, int limit) {
    Map<Feature, Double> totalContributions = new HashMap<>();
    List<Regressor> regressors = boosting.getRegressors(classIndex);
    for (Regressor regressor : regressors) {
        // Only regression trees can be inspected for feature importance.
        if (!(regressor instanceof RegressionTree)) {
            continue;
        }
        Map<Feature, Double> contributions = RegTreeInspector.featureImportance((RegressionTree) regressor);
        contributions.forEach((feature, contribution) -> totalContributions.merge(feature, contribution, Double::sum));
    }

    Comparator<Map.Entry<Feature, Double>> byContribution = Map.Entry.comparingByValue();
    List<Feature> list = totalContributions.entrySet().stream()
            .sorted(byContribution.reversed())
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());

    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = boosting.getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
Aggregations