Search in sources :

Example 1 with SupportPredictor

use of edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor in project pyramid by cheng-li.

The method analyzePrediction of the class BRLRInspector.

/**
 * Assembles a prediction report for a single data point of {@code dataSet}.
 *
 * <p>The report contains the data point's internal and external ids, its true labels, the labels
 * predicted by {@code classifier}, the calibrated set probabilities of both label sets, a
 * {@link ClassScoreCalculation} and a rank entry for every selected class, and a ranking of the
 * most probable label sets.
 *
 * <p>A class is selected when its calibrated probability is at least {@code classProbThreshold},
 * or when it belongs to the true label set or to the predicted label set.
 *
 * @param cbm model that supplies the raw class probabilities, the per-class score explanations
 *     and the label-set search
 * @param labelCalibrator calibrates the raw per-class probabilities
 * @param setCalibrator calibrates the feature vector extracted for a candidate label set
 * @param dataSet data set holding the feature row, the true labels and the id/label translators
 * @param classifier classifier whose prediction is reported; when it is a
 *     {@link SupportPredictor} the label-set search is restricted to its support
 * @param predictionFeatureExtractor turns a {@link PredictionCandidate} into the features
 *     consumed by {@code setCalibrator}
 * @param dataPointIndex internal index of the data point inside {@code dataSet}
 * @param ruleLimit forwarded unchanged to {@code decisionProcess}
 * @param labelSetLimit forwarded to {@link TopKFinder} and also used as the maximum number of
 *     label sets kept in the final ranking
 * @param classProbThreshold minimum calibrated probability for a class to be reported even when
 *     it is neither a true nor a predicted label
 * @return the populated analysis object
 */
public static MultiLabelPredictionAnalysis analyzePrediction(CBM cbm, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    // Fetch the ground-truth label set once instead of re-indexing the label array for every
    // class in the selection loop below.
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];

    // Identity of the data point and its true labels (internal indices and external names).
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    List<Integer> internalLabels = trueLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalLabels(internalLabels);
    List<String> labels = internalLabels.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);

    // Per-class probabilities, shared by both candidates below.
    double[] classProbs = cbm.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(classProbs);

    // Calibrated probability of the true label set.
    PredictionCandidate trueCandidate = new PredictionCandidate();
    trueCandidate.x = dataSet.getRow(dataPointIndex);
    trueCandidate.multiLabel = trueLabels;
    trueCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(trueCandidate)));

    // Prediction of the classifier under inspection (internal indices and external names).
    MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);

    // Calibrated probability of the predicted label set.
    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.multiLabel = predictedLabels;
    predictedCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(predictedCandidate)));

    // Select the classes worth reporting: confident enough, truly present, or predicted.
    List<Integer> classes = new ArrayList<>();
    for (int k = 0; k < cbm.getNumClasses(); k++) {
        if (calibratedClassProbs[k] >= classProbThreshold || trueLabels.matchClass(k) || predictedLabels.matchClass(k)) {
            classes.add(k);
        }
    }

    // One score explanation per selected class, in ascending class-index order.
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        classScoreCalculations.add(decisionProcess(cbm, labelTranslator, calibratedClassProbs[k], dataSet.getRow(dataPointIndex), k, ruleLimit));
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);

    // Rank entries follow the same ascending class-index order; they are not sorted by probability.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> labelRanking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(calibratedClassProbs[label]);
        return rankInfo;
    }).collect(Collectors.toList());
    predictionAnalysis.setPredictedRanking(labelRanking);

    // Candidate label sets: restricted to the classifier's support when it exposes one.
    List<Pair<MultiLabel, Double>> topK;
    if (classifier instanceof SupportPredictor) {
        topK = TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit);
    } else {
        topK = TopKFinder.topK(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator, predictionFeatureExtractor, labelSetLimit);
    }

    // Highest probability first, truncated to labelSetLimit entries.
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = topK.stream().map(pair -> {
        MultiLabel multiLabel = pair.getFirst();
        double setProb = pair.getSecond();
        return new MultiLabelPredictionAnalysis.LabelSetProbInfo(multiLabel, setProb, labelTranslator);
    }).sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed()).limit(labelSetLimit).collect(Collectors.toList());
    predictionAnalysis.setPredictedLabelSetRanking(labelSetRanking);
    return predictionAnalysis;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) List(java.util.List) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) Vector(org.apache.mahout.math.Vector) Comparator(java.util.Comparator) Pair(edu.neu.ccs.pyramid.util.Pair) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) ArrayList(java.util.ArrayList) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 2 with SupportPredictor

use of edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor in project pyramid by cheng-li.

The method analyzePrediction of the class BRInspector.

/**
 * Assembles a prediction report for a single data point of {@code dataSet}.
 *
 * <p>The report contains the data point's internal and external ids, its true labels, the labels
 * predicted by {@code classifier}, the calibrated set probabilities of both label sets, a
 * {@link ClassScoreCalculation} and a rank entry for every selected class, and a ranking of the
 * most probable label sets.
 *
 * <p>A class is selected when its calibrated probability is at least {@code classProbThreshold},
 * or when it belongs to the true label set or to the predicted label set.
 *
 * <p>Score explanations are only produced when {@code classProbEstimator} is a {@link CBM} or an
 * {@link IMLGradientBoosting}; for any other estimator each selected class contributes a
 * {@code null} entry to the class score calculations.
 *
 * @param classProbEstimator model that supplies the raw class probabilities and the label-set
 *     search
 * @param labelCalibrator calibrates the raw per-class probabilities
 * @param setCalibrator calibrates the feature vector extracted for a candidate label set
 * @param dataSet data set holding the feature row, the true labels and the id/label translators
 * @param classifier classifier whose prediction is reported; when it is a
 *     {@link SupportPredictor} the label-set search is restricted to its support
 * @param predictionFeatureExtractor turns a {@link PredictionCandidate} into the features
 *     consumed by {@code setCalibrator}
 * @param dataPointIndex internal index of the data point inside {@code dataSet}
 * @param ruleLimit forwarded unchanged to {@code decisionProcess}
 * @param labelSetLimit forwarded to {@link TopKFinder} and also used as the maximum number of
 *     label sets kept in the final ranking
 * @param classProbThreshold minimum calibrated probability for a class to be reported even when
 *     it is neither a true nor a predicted label
 * @return the populated analysis object
 */
public static MultiLabelPredictionAnalysis analyzePrediction(MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    // Fetch the ground-truth label set once instead of re-indexing the label array for every
    // class in the selection loop below.
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];

    // Identity of the data point and its true labels (internal indices and external names).
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    List<Integer> internalLabels = trueLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalLabels(internalLabels);
    List<String> labels = internalLabels.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);

    // Per-class probabilities, shared by both candidates below.
    double[] classProbs = classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(classProbs);

    // Calibrated probability of the true label set.
    PredictionCandidate trueCandidate = new PredictionCandidate();
    trueCandidate.x = dataSet.getRow(dataPointIndex);
    trueCandidate.multiLabel = trueLabels;
    trueCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(trueCandidate)));

    // Prediction of the classifier under inspection (internal indices and external names).
    MultiLabel predictedLabels = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);

    // Calibrated probability of the predicted label set.
    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.multiLabel = predictedLabels;
    predictedCandidate.labelProbs = calibratedClassProbs;
    predictionAnalysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(predictedCandidate)));

    // Select the classes worth reporting: confident enough, truly present, or predicted.
    List<Integer> classes = new ArrayList<>();
    for (int k = 0; k < classProbEstimator.getNumClasses(); k++) {
        if (calibratedClassProbs[k] >= classProbThreshold || trueLabels.matchClass(k) || predictedLabels.matchClass(k)) {
            classes.add(k);
        }
    }

    // One score explanation per selected class, in ascending class-index order.
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        // CBM is tested first so that it keeps the precedence it had when the two checks were
        // independent statements and the CBM result overwrote the other one.
        ClassScoreCalculation classScoreCalculation = null;
        if (classProbEstimator instanceof CBM) {
            classScoreCalculation = decisionProcess((CBM) classProbEstimator, labelTranslator, calibratedClassProbs[k], dataSet.getRow(dataPointIndex), k, ruleLimit);
        } else if (classProbEstimator instanceof IMLGradientBoosting) {
            classScoreCalculation = decisionProcess((IMLGradientBoosting) classProbEstimator, labelTranslator, calibratedClassProbs[k], dataSet.getRow(dataPointIndex), k, ruleLimit);
        }
        // NOTE(review): for any other estimator type a null entry is recorded, exactly as before;
        // confirm that consumers of the analysis tolerate null elements.
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);

    // Rank entries follow the same ascending class-index order; they are not sorted by probability.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> labelRanking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(calibratedClassProbs[label]);
        return rankInfo;
    }).collect(Collectors.toList());
    predictionAnalysis.setPredictedRanking(labelRanking);

    // Candidate label sets: restricted to the classifier's support when it exposes one.
    List<Pair<MultiLabel, Double>> topK;
    if (classifier instanceof SupportPredictor) {
        topK = TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit);
    } else {
        topK = TopKFinder.topK(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, labelSetLimit);
    }

    // Highest probability first, truncated to labelSetLimit entries.
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = topK.stream().map(pair -> {
        MultiLabel multiLabel = pair.getFirst();
        double setProb = pair.getSecond();
        return new MultiLabelPredictionAnalysis.LabelSetProbInfo(multiLabel, setProb, labelTranslator);
    }).sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed()).limit(labelSetLimit).collect(Collectors.toList());
    predictionAnalysis.setPredictedLabelSetRanking(labelSetRanking);
    return predictionAnalysis;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) List(java.util.List) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) Vector(org.apache.mahout.math.Vector) PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) Comparator(java.util.Comparator) Pair(edu.neu.ccs.pyramid.util.Pair) ArrayList(java.util.ArrayList) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) Pair(edu.neu.ccs.pyramid.util.Pair) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting)

Aggregations

edu.neu.ccs.pyramid.calibration (edu.neu.ccs.pyramid.calibration)2 LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)2 IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator)2 LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator)2 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)2 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)2 Feature (edu.neu.ccs.pyramid.feature.Feature)2 CBM (edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM)2 SupportPredictor (edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor)2 edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression)2 Pair (edu.neu.ccs.pyramid.util.Pair)2 ArrayList (java.util.ArrayList)2 Comparator (java.util.Comparator)2 List (java.util.List)2 Collectors (java.util.stream.Collectors)2 Vector (org.apache.mahout.math.Vector)2 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)1 IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting)1 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)1 TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)1