Search in sources :

Example 6 with DynamicProgramming

use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

the class CalibrationDataGenerator method createInstance.

/**
 * Builds a calibration instance for one prediction by wrapping it in a
 * {@link PredictionCandidate} and delegating to the
 * {@code createInstance(groundtruth, candidate, calibrateTarget)} overload.
 *
 * @param x                       feature vector stored on the candidate
 * @param uncalibratedLabelScores raw per-label scores; they are passed through
 *                                {@code labelCalibrator} before use
 * @param prediction              the label set being evaluated
 * @param groundtruth             the reference label set, forwarded to the overload
 * @param calibrateTarget         forwarded unchanged to the overload
 * @return the calibration instance produced by the overload
 */
public CalibrationInstance createInstance(Vector x, double[] uncalibratedLabelScores, MultiLabel prediction, MultiLabel groundtruth, String calibrateTarget) {
    final PredictionCandidate candidate = new PredictionCandidate();
    candidate.x = x;
    candidate.multiLabel = prediction;

    final double[] calibratedProbs = labelCalibrator.calibratedClassProbs(uncalibratedLabelScores);
    candidate.labelProbs = calibratedProbs;

    // The candidate carries the 50-entry result of DynamicProgramming.topK over the
    // calibrated probabilities (presumably the most probable label sets — confirm
    // against DynamicProgramming).
    candidate.sparseJoint = new DynamicProgramming(calibratedProbs).topK(50);

    return createInstance(groundtruth, candidate, calibrateTarget);
}
Also used : DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 7 with DynamicProgramming

use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

the class CalibrationDataGenerator method expand.

/**
 * Expands one data point into several calibration instances, one per candidate
 * label set.
 *
 * <p>Candidates are, in order: up to {@code numSupportCandidates} label sets drawn
 * from a shuffled copy of {@code support}, followed by the label sets returned by
 * {@code DynamicProgramming.topK(numCandidates)} on the calibrated marginals.
 * The shuffle is seeded with {@code queryId}, so the same query always draws the
 * same support candidates. A label set that appears in both sources is added twice.
 *
 * @param x                       feature vector attached to every candidate
 * @param groundTruth             reference label set passed to {@code createInstance}
 * @param uncalibratedLabelScores raw per-label scores, calibrated via {@code labelCalibrator}
 * @param queryId                 stored as {@code queryIndex} on every instance and used as shuffle seed
 * @param numCandidates           number of candidates requested from the dynamic program
 * @param calibrateTarget         forwarded unchanged to {@code createInstance}
 * @param support                 pool of label sets to sample extra candidates from; not modified
 * @param numSupportCandidates    number of candidates to take from {@code support}; if the pool
 *                                is smaller, all of it is used
 * @return one calibration instance per candidate, each with weight 1
 */
private List<CalibrationInstance> expand(Vector x, MultiLabel groundTruth, double[] uncalibratedLabelScores, int queryId, int numCandidates, String calibrateTarget, List<MultiLabel> support, int numSupportCandidates) {
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibratedLabelScores);
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    List<Pair<MultiLabel, Double>> topK = dynamicProgramming.topK(numCandidates);

    // add a few random candidates from support (shuffle a copy so the caller's list is untouched)
    List<MultiLabel> supportList = new ArrayList<>(support);
    Collections.shuffle(supportList, new Random(queryId));
    // Clamp to the pool size: asking for more candidates than the support holds
    // previously failed with IndexOutOfBoundsException.
    int numFromSupport = Math.min(numSupportCandidates, supportList.size());
    List<MultiLabel> candidates = new ArrayList<>();
    for (int i = 0; i < numFromSupport; i++) {
        candidates.add(supportList.get(i));
    }
    // add top K from DP
    for (Pair<MultiLabel, Double> pair : topK) {
        candidates.add(pair.getFirst());
    }

    List<CalibrationInstance> calibrationInstances = new ArrayList<>(candidates.size());
    for (MultiLabel multiLabel : candidates) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.multiLabel = multiLabel;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.x = x;
        predictionCandidate.sparseJoint = topK;
        CalibrationInstance calibrationInstance = createInstance(groundTruth, predictionCandidate, calibrateTarget);
        calibrationInstance.weight = 1;
        calibrationInstance.queryIndex = queryId;
        calibrationInstances.add(calibrationInstance);
    }
    return calibrationInstances;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 8 with DynamicProgramming

use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

the class IMLGBInspector method analyzePrediction.

/**
 * Assembles a {@link MultiLabelPredictionAnalysis} for a single data point.
 *
 * <p>The analysis records the data point's ids, its true and predicted labels
 * (internal indices and external names), the probability assigned to each of
 * those two label sets, a per-class score breakdown, a per-class probability
 * ranking and a ranking of label sets.
 *
 * <p>Which probability routines are used depends on the predictor type:
 * {@code SubsetAccPredictor}/{@code InstanceF1Predictor} use the
 * {@code ...WithConstraint} methods of {@code boosting}, while
 * {@code HammingPredictor}/{@code MacroF1Predictor} use the
 * {@code ...WithoutConstraint} methods together with {@link DynamicProgramming}.
 * For any other predictor type the set probabilities are not filled in and the
 * label-set ranking is set to {@code null}.
 *
 * @param boosting           model used for all probability queries
 * @param pluginPredictor    predictor that produces the predicted label set
 * @param dataSet            data set holding the row, its labels and the translators
 * @param dataPointIndex     internal index of the data point to analyze
 * @param ruleLimit          forwarded to {@code decisionProcess} for each analyzed class
 * @param labelSetLimit      maximum number of label sets in the label-set ranking
 * @param classProbThreshold classes with probability at or above this value are analyzed,
 *                           in addition to all true and predicted classes
 * @return the populated analysis object
 */
public static MultiLabelPredictionAnalysis analyzePrediction(IMLGradientBoosting boosting, PluginPredictor<IMLGradientBoosting> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, int ruleLimit, int labelSetLimit, double classProbThreshold) {
    // Fetch the row and its true labels once instead of on every use.
    // NOTE(review): assumes getRow/getMultiLabels are side-effect-free getters — confirm.
    // todo: speed up further (per-class probabilities are still recomputed for the ranking)
    Vector row = dataSet.getRow(dataPointIndex);
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];

    // The predictor type never changes during the call, so classify it once.
    boolean usesConstrainedProbs = pluginPredictor instanceof SubsetAccPredictor || pluginPredictor instanceof InstanceF1Predictor;
    boolean usesUnconstrainedProbs = pluginPredictor instanceof HammingPredictor || pluginPredictor instanceof MacroF1Predictor;

    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));

    // True labels: internal indices and their external names.
    predictionAnalysis.setInternalLabels(trueLabels.getMatchedLabelsOrdered());
    List<String> labels = trueLabels.getMatchedLabelsOrdered().stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);
    if (usesConstrainedProbs) {
        predictionAnalysis.setProbForTrueLabels(boosting.predictAssignmentProbWithConstraint(row, trueLabels));
    }
    if (usesUnconstrainedProbs) {
        predictionAnalysis.setProbForTrueLabels(boosting.predictAssignmentProbWithoutConstraint(row, trueLabels));
    }

    // Predicted labels: internal indices and their external names.
    MultiLabel predictedLabels = pluginPredictor.predict(row);
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);
    if (usesConstrainedProbs) {
        predictionAnalysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithConstraint(row, predictedLabels));
    }
    if (usesUnconstrainedProbs) {
        predictionAnalysis.setProbForPredictedLabels(boosting.predictAssignmentProbWithoutConstraint(row, predictedLabels));
    }

    // Analyze a class if it is probable enough, or if it is a true or predicted label.
    double[] classProbs = boosting.predictClassProbs(row);
    List<Integer> classes = new ArrayList<>();
    for (int k = 0; k < boosting.getNumClasses(); k++) {
        if (classProbs[k] >= classProbThreshold || trueLabels.matchClass(k) || predictedLabels.matchClass(k)) {
            classes.add(k);
        }
    }

    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        ClassScoreCalculation classScoreCalculation = decisionProcess(boosting, labelTranslator, row, k, ruleLimit);
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);

    List<MultiLabelPredictionAnalysis.ClassRankInfo> labelRanking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(boosting.predictClassProb(row, label));
        return rankInfo;
    }).collect(Collectors.toList());
    predictionAnalysis.setPredictedRanking(labelRanking);

    // Stays null when the predictor belongs to neither family.
    List<MultiLabelPredictionAnalysis.LabelSetProbInfo> labelSetRanking = null;
    if (usesConstrainedProbs) {
        // Score every assignment known to the model and keep the labelSetLimit most probable.
        double[] labelSetProbs = boosting.predictAllAssignmentProbsWithConstraint(row);
        labelSetRanking = IntStream.range(0, boosting.getAssignments().size()).mapToObj(i -> {
            MultiLabel multiLabel = boosting.getAssignments().get(i);
            double setProb = labelSetProbs[i];
            MultiLabelPredictionAnalysis.LabelSetProbInfo labelSetProbInfo = new MultiLabelPredictionAnalysis.LabelSetProbInfo(multiLabel, setProb, labelTranslator);
            return labelSetProbInfo;
        }).sorted(Comparator.comparing(MultiLabelPredictionAnalysis.LabelSetProbInfo::getProbability).reversed()).limit(labelSetLimit).collect(Collectors.toList());
    }
    if (usesUnconstrainedProbs) {
        // Pull labelSetLimit candidates, one at a time, from the dynamic program over the class probabilities.
        // NOTE(review): nextHighest() is called labelSetLimit times regardless of how many
        // label sets exist — confirm it behaves sensibly once candidates are exhausted.
        labelSetRanking = new ArrayList<>();
        DynamicProgramming dp = new DynamicProgramming(classProbs);
        for (int c = 0; c < labelSetLimit; c++) {
            DynamicProgramming.Candidate candidate = dp.nextHighest();
            MultiLabel multiLabel = candidate.getMultiLabel();
            double setProb = candidate.getProbability();
            MultiLabelPredictionAnalysis.LabelSetProbInfo labelSetProbInfo = new MultiLabelPredictionAnalysis.LabelSetProbInfo(multiLabel, setProb, labelTranslator);
            labelSetRanking.add(labelSetProbInfo);
        }
    }
    predictionAnalysis.setPredictedLabelSetRanking(labelSetRanking);
    return predictionAnalysis;
}
Also used : MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) edu.neu.ccs.pyramid.regression(edu.neu.ccs.pyramid.regression) IntStream(java.util.stream.IntStream) java.util(java.util) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) RegTreeInspector(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) Collectors(java.util.stream.Collectors) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) Feature(edu.neu.ccs.pyramid.feature.Feature) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Pair(edu.neu.ccs.pyramid.util.Pair) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming)

Example 9 with DynamicProgramming

use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

the class Reranker method predict.

/**
 * Picks the best label set for an instance by re-scoring candidate label sets.
 *
 * <p>Candidates come from {@code DynamicProgramming.topK(numCandidate)} on the
 * calibrated marginals, restricted to those whose size lies in
 * {@code [minPredictionSize, maxPredictionSize]}. If none survives the size
 * filter, a single fallback candidate is built from the first
 * {@code minPredictionSize} indices returned by
 * {@code ArgSort.argSortDescending(marginals)}. Each candidate is turned into a
 * feature vector, scored with {@code regressor}, and the highest-scoring one is
 * returned.
 *
 * @param vector                  feature vector of the instance
 * @param uncalibratedLabelScores raw per-label scores, calibrated via {@code labelCalibrator}
 * @return the candidate label set with the highest regressor score
 */
public MultiLabel predict(Vector vector, double[] uncalibratedLabelScores) {
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibratedLabelScores);
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    List<Pair<MultiLabel, Double>> sparseJoint = dynamicProgramming.topK(numCandidate);

    // Collect into an explicit ArrayList: the fallback below appends to this list,
    // and Collectors.toList() makes no guarantee that its result is mutable.
    List<MultiLabel> multiLabels = sparseJoint.stream()
            .map(pair -> pair.getFirst())
            .filter(candidate -> candidate.getNumMatchedLabels() >= minPredictionSize && candidate.getNumMatchedLabels() <= maxPredictionSize)
            .collect(Collectors.toCollection(ArrayList::new));

    if (multiLabels.isEmpty()) {
        // No candidate has an acceptable size: fall back to the leading
        // minPredictionSize labels of the descending arg-sort of the marginals.
        int[] sorted = ArgSort.argSortDescending(marginals);
        MultiLabel multiLabel = new MultiLabel();
        for (int i = 0; i < minPredictionSize; i++) {
            multiLabel.addLabel(sorted[i]);
        }
        multiLabels.add(multiLabel);
    }

    List<Pair<MultiLabel, Double>> candidates = new ArrayList<>();
    for (MultiLabel candidate : multiLabels) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.x = vector;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.multiLabel = candidate;
        predictionCandidate.sparseJoint = sparseJoint;
        Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
        double score = regressor.predict(feature);
        candidates.add(new Pair<>(candidate, score));
    }

    // candidates is never empty here: multiLabels holds at least the fallback entry.
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(pair -> pair.getSecond());
    return candidates.stream().max(comparator).map(Pair::getFirst).get();
}
Also used : BMDistribution(edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution) java.util(java.util) ArgSort(edu.neu.ccs.pyramid.util.ArgSort) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) ArgMax(edu.neu.ccs.pyramid.util.ArgMax) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Collectors(java.util.stream.Collectors) LabelTranslator(edu.neu.ccs.pyramid.dataset.LabelTranslator) Regressor(edu.neu.ccs.pyramid.regression.Regressor) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair) GeneralF1Predictor(edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) Vector(org.apache.mahout.math.Vector) Pair(edu.neu.ccs.pyramid.util.Pair)

Example 10 with DynamicProgramming

use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

the class BRPrediction method simplePredictionAnalysisCalibrated.

/**
 * Renders a tab-separated text report for one data point.
 *
 * <p>The report has one {@code single} row per label that is either true or
 * predicted, ordered by decreasing calibrated probability:
 * {@code id, label, "single", probability, match, NA, NA, NA, NA}. Labels whose
 * index is not below {@code classProbEstimator.getNumClasses()} are reported
 * with probability 0.0. It ends with one {@code set} row for the predicted
 * label set:
 * {@code id, predicted labels, "set", calibrated set probability, exact match,
 * true labels, precision, recall, f1}. Label lists are comma-separated and every
 * row ends with a newline.
 *
 * @param classProbEstimator         supplies the uncalibrated per-label probabilities
 * @param labelCalibrator            calibrates the per-label probabilities
 * @param setCalibrator              calibrates the probability of the predicted set
 * @param dataSet                    data set holding the row, labels and translators
 * @param dataPointIndex             internal index of the data point to report on
 * @param classifier                 produces the predicted label set
 * @param predictionFeatureExtractor builds the feature vector fed to {@code setCalibrator}
 * @return the report text
 */
public static String simplePredictionAnalysisCalibrated(MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator setCalibrator, MultiLabelClfDataSet dataSet, int dataPointIndex, MultiLabelClassifier classifier, PredictionFeatureExtractor predictionFeatureExtractor) {
    final String tab = "\t";
    StringBuilder report = new StringBuilder();
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    String id = dataSet.getIdTranslator().toExtId(dataPointIndex);
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    double[] classProbs = classProbEstimator.predictClassProbs(dataSet.getRow(dataPointIndex));
    double[] calibratedClassProbs = labelCalibrator.calibratedClassProbs(classProbs);
    MultiLabel predicted = classifier.predict(dataSet.getRow(dataPointIndex));

    // ---- per-label rows: every label that is true or predicted ----
    List<Integer> relevantLabels = new ArrayList<Integer>();
    for (int k = 0; k < dataSet.getNumClasses(); k++) {
        boolean isTrue = dataSet.getMultiLabels()[dataPointIndex].matchClass(k);
        if (isTrue || predicted.matchClass(k)) {
            relevantLabels.add(k);
        }
    }

    List<Pair<Integer, Double>> scoredLabels = new ArrayList<>();
    for (Integer l : relevantLabels) {
        // Labels outside the estimator's range have no probability; report 0.0.
        if (l < classProbEstimator.getNumClasses()) {
            scoredLabels.add(new Pair<Integer, Double>(l, calibratedClassProbs[l]));
        } else {
            scoredLabels.add(new Pair<Integer, Double>(l, 0.0));
        }
    }
    Comparator<Pair<Integer, Double>> byProbability = Comparator.comparing(Pair::getSecond);
    scoredLabels.sort(byProbability.reversed());

    for (Pair<Integer, Double> scored : scoredLabels) {
        int label = scored.getFirst();
        double prob = scored.getSecond();
        int match = trueLabels.matchClass(label) ? 1 : 0;
        report.append(id).append(tab)
                .append(labelTranslator.toExtLabel(label)).append(tab)
                .append("single").append(tab)
                .append(prob).append(tab)
                .append(match);
        // The four set-level columns do not apply to single-label rows.
        for (int column = 0; column < 4; column++) {
            report.append(tab).append("NA");
        }
        report.append("\n");
    }

    // ---- set row: calibrated probability of the predicted label set ----
    PredictionCandidate predictedCandidate = new PredictionCandidate();
    predictedCandidate.multiLabel = predicted;
    predictedCandidate.labelProbs = calibratedClassProbs;
    predictedCandidate.x = dataSet.getRow(dataPointIndex);
    predictedCandidate.sparseJoint = new DynamicProgramming(calibratedClassProbs).topK(50);
    Vector feature = predictionFeatureExtractor.extractFeatures(predictedCandidate);
    double probability = setCalibrator.calibrate(feature);

    List<Integer> predictedList = predicted.getMatchedLabelsOrdered();
    MultiLabel prediction = new MultiLabel();
    report.append(id).append(tab);
    String separator = "";
    for (Integer label : predictedList) {
        prediction.addLabel(label);
        report.append(separator).append(labelTranslator.toExtLabel(label));
        separator = ",";
    }
    report.append(tab);

    int setMatch = predicted.equals(trueLabels) ? 1 : 0;

    String truthColumn = trueLabels.getMatchedLabels().stream()
            .sorted()
            .map(labelTranslator::toExtLabel)
            .collect(Collectors.joining(","));

    double precision = Precision.precision(trueLabels, prediction);
    double recall = Recall.recall(trueLabels, prediction);
    double f1 = FMeasure.f1(precision, recall);

    report.append("set").append(tab)
            .append(probability).append(tab)
            .append(setMatch).append(tab)
            .append(truthColumn).append(tab)
            .append(precision).append(tab)
            .append(recall).append(tab)
            .append(f1).append("\n");
    return report.toString();
}
Also used : ArrayList(java.util.ArrayList) DynamicProgramming(edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming) Vector(org.apache.mahout.math.Vector)

Aggregations

DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming)12 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)8 Pair (edu.neu.ccs.pyramid.util.Pair)8 Vector (org.apache.mahout.math.Vector)7 BMDistribution (edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution)3 java.util (java.util)3 Collectors (java.util.stream.Collectors)3 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)2 CBM (edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM)2 ArgMax (edu.neu.ccs.pyramid.util.ArgMax)2 ArgSort (edu.neu.ccs.pyramid.util.ArgSort)2 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)1 LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator)1 Feature (edu.neu.ccs.pyramid.feature.Feature)1 FeatureList (edu.neu.ccs.pyramid.feature.FeatureList)1 TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures)1 MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis)1 PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor)1 GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)1 edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression)1