Search in sources:

Example 1 with LabelTranslator

use of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid by cheng-li.

From class CRFInspector, method pairRelations.

/**
 * Renders every pairwise label weight of the given CRF as text, one pair per line, ordered from
 * the largest absolute weight to the smallest.
 *
 * <p>Each line has the form {@code label1:hasLabel1, label2:hasLabel2=>weight}, with both labels
 * translated to their external names by the model's label translator.
 *
 * @param crf the model whose pair weights are reported
 * @return the formatted report; an empty string when the model reports no pair weights
 */
public static String pairRelations(CMLCRF crf) {
    List<CRFInspector.PairWeight> pairWeights = crf.getWeights().printPairWeights();
    // Strongest interactions first, regardless of sign.
    Comparator<CRFInspector.PairWeight> byMagnitudeDesc =
            Comparator.comparingDouble((CRFInspector.PairWeight pw) -> Math.abs(pw.weight)).reversed();
    LabelTranslator translator = crf.getLabelTranslator();
    return pairWeights.stream()
            .sorted(byMagnitudeDesc)
            .map(pw -> translator.toExtLabel(pw.label1) + ":" + pw.hasLabel1 + ", "
                    + translator.toExtLabel(pw.label2) + ":" + pw.hasLabel2
                    + "=>" + pw.weight + "\n")
            .collect(Collectors.joining());
}
Also used : LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator)

Example 2 with LabelTranslator

use of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid by cheng-li.

From class BRLRInspector, method analyzePrediction.

/**
 * Explains the prediction made for a single data point of {@code dataSet}.
 *
 * <p>The returned analysis records the point's internal and external ids, and its true and
 * predicted label sets, both as internal indices and as external names. It also records the
 * set-calibrated score of each of those two label sets and a per-class decision breakdown. It
 * lists the calibrated probabilities of the inspected classes and the highest-scoring candidate
 * label sets.
 *
 * @param cbm model whose class probabilities and per-class decision process are inspected
 * @param labelCalibrator turns the model's raw class probabilities into calibrated ones
 * @param setCalibrator scores a label set from the features built by {@code predictionFeatureExtractor}
 * @param dataSet data set holding the point, its labels and the id/label translators
 * @param classifier produces the predicted label set; when it is a {@link SupportPredictor},
 *     candidate label sets are searched only within its support
 * @param predictionFeatureExtractor builds features for a (point, label set) candidate
 * @param dataPointIndex internal index of the point to analyze
 * @param ruleLimit forwarded to {@code decisionProcess} for every inspected class
 * @param labelSetLimit maximum number of candidate label sets kept in the label-set ranking
 * @param classProbThreshold a class is inspected when its calibrated probability is at least this
 *     value, or when it belongs to the true or the predicted label set
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(CBM cbm, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    LabelTranslator translator = dataSet.getLabelTranslator();
    IdTranslator ids = dataSet.getIdTranslator();
    MultiLabel groundTruth = dataSet.getMultiLabels()[dataPointIndex];

    analysis.setInternalId(dataPointIndex);
    analysis.setId(ids.toExtId(dataPointIndex));

    // Ground truth, as internal indices and as external label names.
    List<Integer> trueInternal = groundTruth.getMatchedLabelsOrdered();
    analysis.setInternalLabels(trueInternal);
    analysis.setLabels(trueInternal.stream().map(translator::toExtLabel).collect(Collectors.toList()));

    // Calibrated per-class probabilities, shared by the true and the predicted candidate.
    double[] calibrated = labelCalibrator.calibratedClassProbs(cbm.predictClassProbs(dataSet.getRow(dataPointIndex)));

    PredictionCandidate trueCandidate = new PredictionCandidate();
    trueCandidate.x = dataSet.getRow(dataPointIndex);
    trueCandidate.multiLabel = groundTruth;
    trueCandidate.labelProbs = calibrated;
    analysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(trueCandidate)));

    // The classifier's own prediction.
    MultiLabel predicted = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedInternal = predicted.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedInternal);
    analysis.setPrediction(predictedInternal.stream().map(translator::toExtLabel).collect(Collectors.toList()));

    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.multiLabel = predicted;
    predictedCandidate.labelProbs = calibrated;
    analysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(predictedCandidate)));

    // Inspect a class if it is likely enough, truly present, or predicted.
    List<Integer> inspected = new ArrayList<>();
    int numClasses = cbm.getNumClasses();
    for (int c = 0; c < numClasses; c++) {
        boolean relevant = calibrated[c] >= classProbThreshold
                || groundTruth.matchClass(c)
                || predicted.matchClass(c);
        if (relevant) {
            inspected.add(c);
        }
    }

    // TODO: carried over from the original code; its intent is not recorded here.
    List<ClassScoreCalculation> calculations = new ArrayList<>(inspected.size());
    for (int c : inspected) {
        calculations.add(decisionProcess(cbm, translator, calibrated[c], dataSet.getRow(dataPointIndex), c, ruleLimit));
    }
    analysis.setClassScoreCalculations(calculations);

    // Inspected classes with their calibrated probabilities, kept in increasing class-index order.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> classRanking = new ArrayList<>(inspected.size());
    for (int c : inspected) {
        MultiLabelPredictionAnalysis.ClassRankInfo info = new MultiLabelPredictionAnalysis.ClassRankInfo();
        info.setClassIndex(c);
        info.setClassName(translator.toExtLabel(c));
        info.setProb(calibrated[c]);
        classRanking.add(info);
    }
    analysis.setPredictedRanking(classRanking);

    // Candidate label sets: restricted to the classifier's support when it exposes one.
    List<Pair<MultiLabel, Double>> topK = (classifier instanceof SupportPredictor)
            ? TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit)
            : TopKFinder.topK(dataSet.getRow(dataPointIndex), cbm, labelCalibrator, setCalibrator, predictionFeatureExtractor, labelSetLimit);

    Comparator<MultiLabelPredictionAnalysis.LabelSetProbInfo> byProbabilityDesc =
            Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed();
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> setRanking = topK.stream()
            .map(candidate -> new MultiLabelPredictionAnalysis.LabelSetProbInfo(candidate.getFirst(), candidate.getSecond().doubleValue(), translator))
            .sorted(byProbabilityDesc)
            .limit(labelSetLimit)
            .collect(Collectors.toList());
    analysis.setPredictedLabelSetRanking(setRanking);
    return analysis;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) List(java.util.List) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) Vector(org.apache.mahout.math.Vector) Comparator(java.util.Comparator) Pair(edu.neu.ccs.pyramid.util.Pair) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) ArrayList(java.util.ArrayList) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 3 with LabelTranslator

use of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid by cheng-li.

From class BRInspector, method analyzePrediction.

/**
 * Explains the prediction made for a single data point of {@code dataSet}, for any class
 * probability estimator.
 *
 * <p>The returned analysis records the point's internal and external ids, and its true and
 * predicted label sets, both as internal indices and as external names. It also records the
 * set-calibrated score of each of those two label sets and a per-class decision breakdown. It
 * lists the calibrated probabilities of the inspected classes and the highest-scoring candidate
 * label sets.
 *
 * @param classProbEstimator estimator whose class probabilities are inspected; a per-class decision
 *     breakdown is produced only for {@code IMLGradientBoosting} and {@code CBM} estimators, and a
 *     {@code null} entry is recorded for any other type
 * @param labelCalibrator turns the estimator's raw class probabilities into calibrated ones
 * @param setCalibrator scores a label set from the features built by {@code predictionFeatureExtractor}
 * @param dataSet data set holding the point, its labels and the id/label translators
 * @param classifier produces the predicted label set; when it is a {@link SupportPredictor},
 *     candidate label sets are searched only within its support
 * @param predictionFeatureExtractor builds features for a (point, label set) candidate
 * @param dataPointIndex internal index of the point to analyze
 * @param ruleLimit forwarded to {@code decisionProcess} for every inspected class
 * @param labelSetLimit maximum number of candidate label sets kept in the label-set ranking
 * @param classProbThreshold a class is inspected when its calibrated probability is at least this
 *     value, or when it belongs to the true or the predicted label set
 * @return the populated analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    MultiLabelPredictionAnalysis analysis = new MultiLabelPredictionAnalysis();
    LabelTranslator translator = dataSet.getLabelTranslator();
    IdTranslator ids = dataSet.getIdTranslator();
    MultiLabel groundTruth = dataSet.getMultiLabels()[dataPointIndex];

    analysis.setInternalId(dataPointIndex);
    analysis.setId(ids.toExtId(dataPointIndex));

    // Ground truth, as internal indices and as external label names.
    List<Integer> trueInternal = groundTruth.getMatchedLabelsOrdered();
    analysis.setInternalLabels(trueInternal);
    analysis.setLabels(trueInternal.stream().map(translator::toExtLabel).collect(Collectors.toList()));

    // Calibrated per-class probabilities, shared by the true and the predicted candidate.
    double[] calibrated = labelCalibrator.calibratedClassProbs(classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex)));

    PredictionCandidate trueCandidate = new PredictionCandidate();
    trueCandidate.x = dataSet.getRow(dataPointIndex);
    trueCandidate.multiLabel = groundTruth;
    trueCandidate.labelProbs = calibrated;
    analysis.setProbForTrueLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(trueCandidate)));

    // The classifier's own prediction.
    MultiLabel predicted = classifier.predict(dataSet.getRow(dataPointIndex));
    List<Integer> predictedInternal = predicted.getMatchedLabelsOrdered();
    analysis.setInternalPrediction(predictedInternal);
    analysis.setPrediction(predictedInternal.stream().map(translator::toExtLabel).collect(Collectors.toList()));

    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.multiLabel = predicted;
    predictedCandidate.labelProbs = calibrated;
    analysis.setProbForPredictedLabels(setCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(predictedCandidate)));

    // Inspect a class if it is likely enough, truly present, or predicted.
    List<Integer> inspected = new ArrayList<>();
    int numClasses = classProbEstimator.getNumClasses();
    for (int c = 0; c < numClasses; c++) {
        boolean relevant = calibrated[c] >= classProbThreshold
                || groundTruth.matchClass(c)
                || predicted.matchClass(c);
        if (relevant) {
            inspected.add(c);
        }
    }

    // TODO: carried over from the original code; its intent is not recorded here.
    boolean isBoosting = classProbEstimator instanceof IMLGradientBoosting;
    boolean isCbm = classProbEstimator instanceof CBM;
    List<ClassScoreCalculation> calculations = new ArrayList<>(inspected.size());
    for (int c : inspected) {
        ClassScoreCalculation calculation = null;
        // Both checks are evaluated independently, as in the original code.
        if (isBoosting) {
            calculation = decisionProcess((IMLGradientBoosting) classProbEstimator, translator, calibrated[c], dataSet.getRow(dataPointIndex), c, ruleLimit);
        }
        if (isCbm) {
            calculation = decisionProcess((CBM) classProbEstimator, translator, calibrated[c], dataSet.getRow(dataPointIndex), c, ruleLimit);
        }
        // NOTE(review): any other estimator type contributes a null entry here.
        calculations.add(calculation);
    }
    analysis.setClassScoreCalculations(calculations);

    // Inspected classes with their calibrated probabilities, kept in increasing class-index order.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> classRanking = new ArrayList<>(inspected.size());
    for (int c : inspected) {
        MultiLabelPredictionAnalysis.ClassRankInfo info = new MultiLabelPredictionAnalysis.ClassRankInfo();
        info.setClassIndex(c);
        info.setClassName(translator.toExtLabel(c));
        info.setProb(calibrated[c]);
        classRanking.add(info);
    }
    analysis.setPredictedRanking(classRanking);

    // Candidate label sets: restricted to the classifier's support when it exposes one.
    List<Pair<MultiLabel, Double>> topK = (classifier instanceof SupportPredictor)
            ? TopKFinder.topKinSupport(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, ((SupportPredictor) classifier).getSupport(), labelSetLimit)
            : TopKFinder.topK(dataSet.getRow(dataPointIndex), classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, labelSetLimit);

    Comparator<MultiLabelPredictionAnalysis.LabelSetProbInfo> byProbabilityDesc =
            Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed();
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> setRanking = topK.stream()
            .map(candidate -> new MultiLabelPredictionAnalysis.LabelSetProbInfo(candidate.getFirst(), candidate.getSecond().doubleValue(), translator))
            .sorted(byProbabilityDesc)
            .limit(labelSetLimit)
            .collect(Collectors.toList());
    analysis.setPredictedLabelSetRanking(setRanking);
    return analysis;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) List(java.util.List) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) Vector(org.apache.mahout.math.Vector) PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) Comparator(java.util.Comparator) Pair(edu.neu.ccs.pyramid.util.Pair) ArrayList(java.util.ArrayList) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) Pair(edu.neu.ccs.pyramid.util.Pair) SupportPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting)

Example 4 with LabelTranslator

use of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid by cheng-li.

From class LKBInspector, method topFeatures.

/**
 * Ranks features by their total importance for one class of a single LKBoost model.
 *
 * <p>Importance is summed, per feature, over every {@code RegressionTree} in the ensemble returned
 * by {@code boosting.getEnsemble(classIndex)}. Each tree's contributions come from
 * {@code RegTreeInspector.featureImportance}, and regressors of other types are ignored.
 *
 * @param boosting the model to inspect
 * @param classIndex internal index of the class whose ensemble is inspected
 * @param limit maximum number of features returned, highest total importance first
 * @return the ranked features together with the class index and its external name
 */
public static TopFeatures topFeatures(LKBoost boosting, int classIndex, int limit) {
    Map<Feature, Double> importanceByFeature = new HashMap<>();
    for (Regressor regressor : boosting.getEnsemble(classIndex).getRegressors()) {
        if (!(regressor instanceof RegressionTree)) {
            continue;
        }
        RegTreeInspector.featureImportance((RegressionTree) regressor).forEach((feature, contribution) ->
                importanceByFeature.put(feature, importanceByFeature.getOrDefault(feature, 0.0) + contribution));
    }

    List<Feature> ranked = importanceByFeature.entrySet().stream()
            .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder()))
            .limit(limit)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(ranked);
    result.setClassIndex(classIndex);
    result.setClassName(boosting.getLabelTranslator().toExtLabel(classIndex));
    return result;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) ClassProbability(edu.neu.ccs.pyramid.classification.ClassProbability) java.util(java.util) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) Collectors(java.util.stream.Collectors) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) PredictionAnalysis(edu.neu.ccs.pyramid.classification.PredictionAnalysis) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator)

Example 5 with LabelTranslator

use of edu.neu.ccs.pyramid.dataset.LabelTranslator in project pyramid by cheng-li.

From class LKBInspector, method topFeatures (the overload that aggregates a list of models).

/**
 * Ranks features by their total importance for one class, aggregated over several LKBoost models.
 *
 * <p>Importance is summed, per feature, over every {@code RegressionTree} in the ensemble returned
 * by {@code getEnsemble(classIndex)} of each model in {@code lkBoosts}. Each tree's contributions
 * come from {@code RegTreeInspector.featureImportance}, and regressors of other types are ignored.
 * Every feature with a recorded importance is returned (no limit), highest total first.
 *
 * <p>The class name is resolved with the label translator of the first model. Presumably all
 * models share the same label space; this is not checked here.
 *
 * @param lkBoosts the LKBoost models to aggregate; must not be empty
 * @param classIndex internal index of the class whose ensembles are inspected
 * @return the ranked features together with the class index and its external name
 * @throws IllegalArgumentException if {@code lkBoosts} is empty
 */
public static TopFeatures topFeatures(List<LKBoost> lkBoosts, int classIndex) {
    if (lkBoosts.isEmpty()) {
        // The class name is taken from the first model's label translator, so at least one model is
        // required; fail before doing any aggregation work instead of with an IndexOutOfBoundsException.
        throw new IllegalArgumentException("lkBoosts must contain at least one LKBoost model");
    }
    Map<Feature, Double> totalContributions = new HashMap<>();
    for (LKBoost lkBoost : lkBoosts) {
        for (Regressor regressor : lkBoost.getEnsemble(classIndex).getRegressors()) {
            if (!(regressor instanceof RegressionTree)) {
                continue;
            }
            Map<Feature, Double> contributions = RegTreeInspector.featureImportance((RegressionTree) regressor);
            for (Map.Entry<Feature, Double> entry : contributions.entrySet()) {
                Feature feature = entry.getKey();
                double oldValue = totalContributions.getOrDefault(feature, 0.0);
                totalContributions.put(feature, oldValue + entry.getValue());
            }
        }
    }
    Comparator<Map.Entry<Feature, Double>> comparator = Comparator.comparing(Map.Entry::getValue);
    List<Feature> list = totalContributions.entrySet().stream().sorted(comparator.reversed()).map(Map.Entry::getKey).collect(Collectors.toList());
    TopFeatures topFeatures = new TopFeatures();
    topFeatures.setTopFeatures(list);
    topFeatures.setClassIndex(classIndex);
    LabelTranslator labelTranslator = lkBoosts.get(0).getLabelTranslator();
    topFeatures.setClassName(labelTranslator.toExtLabel(classIndex));
    return topFeatures;
}
Also used : edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) ClassProbability(edu.neu.ccs.pyramid.classification.ClassProbability) java.util(java.util) IdTranslator(edu.neu.ccs.pyramid.dataset.IdTranslator) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) Collectors(java.util.stream.Collectors) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) PredictionAnalysis(edu.neu.ccs.pyramid.classification.PredictionAnalysis) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Feature(edu.neu.ccs.pyramid.feature.Feature) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator)

Aggregations

LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator)12 IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator)10 edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression)8 Collectors (java.util.stream.Collectors)8 Vector (org.apache.mahout.math.Vector)8 ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability)6 PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis)6 Feature (edu.neu.ccs.pyramid.feature.Feature)6 java.util (java.util)6 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)5 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)5 TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule)5 Pair (edu.neu.ccs.pyramid.util.Pair)5 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)4 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)4 TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures)4 RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector)4 IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting)3 ArrayList (java.util.ArrayList)3 edu.neu.ccs.pyramid.calibration (edu.neu.ccs.pyramid.calibration)2