Search in sources :

Example 11 with LabelTranslator

Use of edu.neu.ccs.pyramid.dataset.LabelTranslator in the project pyramid by cheng-li.

In the class AdaBoostMHInspector, method analyzePrediction.

/**
 * Builds a per-data-point prediction report for an AdaBoostMH model.
 * The probability estimator can be binary scaling or across-class scaling.
 * <p>
 * The report records the internal and external ids, the true labels and the
 * predicted labels (both as internal indices and as external names), and one
 * score calculation per requested class. The assignment probabilities of the
 * true and of the predicted label set are filled in only when {@code scaling}
 * is a {@code MultiLabelClassifier.AssignmentProbEstimator}; otherwise both
 * stay {@code Double.NaN}.
 *
 * @param boosting the model used to predict labels and to explain class scores
 * @param scaling estimator that supplies the probabilities
 * @param dataSet data set providing rows, true labels and the id/label translators
 * @param dataPointIndex internal index of the data point to analyze
 * @param classes class indices for which score calculations are produced
 * @param limit forwarded unchanged to {@code decisionProcess} for every class
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();

    // Identity of the data point: internal index plus its external id.
    analysis.setInternalId(dataPointIndex);
    analysis.setId(idTranslator.toExtId(dataPointIndex));

    // Ground-truth labels, first as internal indices, then as external names.
    analysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    List<String> trueLabelNames = dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered()
            .stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    analysis.setLabels(trueLabelNames);

    // Only assignment-probability estimators can score a whole label set; NaN otherwise.
    double trueSetProb = scaling instanceof MultiLabelClassifier.AssignmentProbEstimator
            ? ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), dataSet.getMultiLabels()[dataPointIndex])
            : Double.NaN;
    analysis.setProbForTrueLabels(trueSetProb);

    // Model prediction, again as internal indices and as external names.
    MultiLabel predicted = boosting.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedIndices = predicted.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedIndices);
    List<String> predictedNames = predictedIndices
            .stream()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.toList());
    analysis.setPrediction(predictedNames);

    double predictedSetProb = scaling instanceof MultiLabelClassifier.AssignmentProbEstimator
            ? ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), predicted)
            : Double.NaN;
    analysis.setProbForPredictedLabels(predictedSetProb);

    // One score explanation per requested class, in the order given by the caller.
    List<ClassScoreCalculation> scoreCalculations = classes.stream()
            .map(k -> decisionProcess(boosting, scaling, labelTranslator, dataSet.getRow(dataPointIndex), k, limit))
            .collect(Collectors.toCollection(ArrayList::new));
    analysis.setClassScoreCalculations(scoreCalculations);

    // NOTE(review): this per-class ranking is built but never stored on the analysis
    // or returned -- confirm whether it was meant to be attached to the result.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> ranking = new ArrayList<>();
    for (Integer classIndex : classes) {
        MultiLabelPredictionAnalysis.ClassRankInfo info = new MultiLabelPredictionAnalysis.ClassRankInfo();
        info.setClassIndex(classIndex);
        info.setClassName(labelTranslator.toExtLabel(classIndex));
        info.setProb(scaling.predictClassProb(dataSet.getRow(dataPointIndex), classIndex));
        ranking.add(info);
    }

    return analysis;
}
Also used : MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) java.util(java.util) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) MLPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLPlattScaling) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) Collectors(java.util.stream.Collectors) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) Classifier(edu.neu.ccs.pyramid.classification.Classifier) MLACPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLACPlattScaling) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) PlattScaling(edu.neu.ccs.pyramid.classification.PlattScaling) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Pair(edu.neu.ccs.pyramid.util.Pair) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator)

Example 12 with LabelTranslator

Use of edu.neu.ccs.pyramid.dataset.LabelTranslator in the project pyramid by cheng-li.

In the class AdaBoostMHInspector, method topFeatures.

/**
 * Ranks features by their summed importance over the regression trees that
 * {@code boosting} holds for one class.
 * <p>
 * Regressors that are not {@code RegressionTree} instances are skipped. The
 * per-tree importances returned by {@code RegTreeInspector.featureImportance}
 * are added up per feature, and at most {@code limit} features with the
 * largest totals are kept, highest total first.
 *
 * @param boosting model whose regressors and label translator are consulted
 * @param classIndex index of the class to inspect
 * @param limit maximum number of features to keep
 * @return the top features together with the class index and its external name
 */
public static TopFeatures topFeatures(AdaBoostMH boosting, int classIndex, int limit) {
    Map<Feature, Double> importanceByFeature = new HashMap<>();
    for (Regressor regressor : boosting.getRegressors(classIndex)) {
        if (!(regressor instanceof RegressionTree)) {
            continue;
        }
        Map<Feature, Double> perTree = RegTreeInspector.featureImportance((RegressionTree) regressor);
        for (Map.Entry<Feature, Double> importance : perTree.entrySet()) {
            double runningTotal = importanceByFeature.getOrDefault(importance.getKey(), 0.0);
            importanceByFeature.put(importance.getKey(), runningTotal + importance.getValue());
        }
    }

    // Largest accumulated importance first; keep only the leading `limit` features.
    Comparator<Map.Entry<Feature, Double>> byImportanceDescending =
            Map.Entry.<Feature, Double>comparingByValue().reversed();
    List<Feature> ranked = importanceByFeature.entrySet()
            .stream()
            .sorted(byImportanceDescending)
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(ranked);
    result.setClassIndex(classIndex);
    String className = boosting.getLabelTranslator().toExtLabel(classIndex);
    result.setClassName(className);
    return result;
}
Also used : MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) java.util(java.util) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) MLPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLPlattScaling) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) Collectors(java.util.stream.Collectors) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) Classifier(edu.neu.ccs.pyramid.classification.Classifier) MLACPlattScaling(edu.neu.ccs.pyramid.multilabel_classification.MLACPlattScaling) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) PlattScaling(edu.neu.ccs.pyramid.classification.PlattScaling) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Pair(edu.neu.ccs.pyramid.util.Pair) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator)

Aggregations

LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator)12 IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator)10 edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression)8 Collectors (java.util.stream.Collectors)8 Vector (org.apache.mahout.math.Vector)8 ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability)6 PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis)6 Feature (edu.neu.ccs.pyramid.feature.Feature)6 java.util (java.util)6 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)5 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)5 TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)5 Pair (edu.neu.ccs.pyramid.util.Pair)5 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)4 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)4 TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures)4 RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector)4 IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting)3 ArrayList (java.util.ArrayList)3 edu.neu.ccs.pyramid.calibration (edu.neu.ccs.pyramid.calibration)2