
Example 1 with DynamicProgramming

Use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

From class CBMPredictor, method predictByDynamic.

public MultiLabel predictByDynamic() {
    // initialization
    DynamicProgramming[] DPs = new DynamicProgramming[numClusters];
    double[] minClusterProb = new double[numClusters];
    for (int k = 0; k < numClusters; k++) {
        DPs[k] = new DynamicProgramming(probs[k], logProbs[k]);
        minClusterProb[k] = Double.POSITIVE_INFINITY;
    }
    double maxLogProb = Double.NEGATIVE_INFINITY;
    MultiLabel bestMultiLabel = new MultiLabel();
    int maxIter = 10;
    for (int iter = 0; iter < maxIter; iter++) {
        for (int k = 0; k < numClusters; k++) {
            DynamicProgramming dp = DPs[k];
            double prob = dp.nextHighestProb();
            MultiLabel multiLabel = dp.nextHighestVector();
            minClusterProb[k] = prob;
            double threshold = computeThreshold(minClusterProb);
            boolean isCandidateValid = true;
            if ((multiLabel.getNumMatchedLabels() == 0) && !allowEmpty) {
                isCandidateValid = false;
            }
            if (isCandidateValid) {
                double logProb = logProbYnGivenXnLogisticProb(multiLabel);
                if (logProb >= maxLogProb) {
                    bestMultiLabel = multiLabel;
                    maxLogProb = logProb;
                }
            }
            if (Math.exp(maxLogProb) >= threshold) {
                return bestMultiLabel;
            }
        }
    }
    return bestMultiLabel;
}
Also used: DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)
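In `predictByDynamic`, each cluster's `DynamicProgramming` produces label vectors in decreasing order of probability. The loop returns early once `Math.exp(maxLogProb) >= threshold`, but the body of `computeThreshold` is not shown. One reading is consistent with a mixture model p(y|x) = sum over k of weight[k] * p_k(y|x), which is an assumption here. Under that reading, the threshold bounds the probability of every vector that no cluster has produced yet. Such a vector has p_k(y) <= minClusterProb[k] in every cluster, so its mixture probability is at most the sum of weight[k] * minClusterProb[k]. The minimal sketch below implements that bound. `MixtureThreshold` and the `mixtureWeights` parameter are hypothetical names, not pyramid's code.

```java
/**
 * Hypothetical sketch of an early-stopping bound for best-first search over a
 * mixture of per-cluster distributions; not pyramid's computeThreshold.
 */
final class MixtureThreshold {

    /**
     * Upper bound on the mixture probability of any label vector that no cluster
     * has produced yet: the sum over k of mixtureWeights[k] * minClusterProb[k].
     * A cluster still at +Infinity (not visited yet) makes the bound infinite.
     */
    static double computeThreshold(double[] mixtureWeights, double[] minClusterProb) {
        double bound = 0.0;
        for (int k = 0; k < mixtureWeights.length; k++) {
            if (mixtureWeights[k] == 0.0) {
                continue; // contributes nothing; also avoids 0 * Infinity = NaN
            }
            bound += mixtureWeights[k] * minClusterProb[k];
        }
        return bound;
    }

    /** True when no vector outside the candidates seen so far can beat bestProb. */
    static boolean canStop(double bestProb, double[] mixtureWeights, double[] minClusterProb) {
        return bestProb >= computeThreshold(mixtureWeights, minClusterProb);
    }

    public static void main(String[] args) {
        double[] weights = {0.6, 0.4};
        System.out.println(canStop(0.45, weights, new double[]{0.5, Double.POSITIVE_INFINITY})); // false
        System.out.println(canStop(0.45, weights, new double[]{0.5, 0.25})); // true: bound is about 0.4
    }
}
```

Under this reading, the bound stays infinite until every cluster has produced at least one candidate, because `minClusterProb` starts at `Double.POSITIVE_INFINITY`. The loop therefore cannot exit before each cluster has been consulted once.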

Example 2 with DynamicProgramming

Use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

From class SupportPredictor, method predict.

@Override
public MultiLabel predict(Vector vector) {
    double[] uncali = classifier.predictClassProbs(vector);
    double[] marginals = labelCalibrator.calibratedClassProbs(uncali);
    List<Pair<MultiLabel, Double>> candidates = new ArrayList<>();
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    List<Pair<MultiLabel, Double>> sparseJoint = dynamicProgramming.topK(50);
    for (MultiLabel candidate : support) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.x = vector;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.multiLabel = candidate;
        predictionCandidate.sparseJoint = sparseJoint;
        Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
        double score = setCalibrator.calibrate(feature);
        candidates.add(new Pair<>(candidate, score));
    }
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(Pair::getSecond);
    return candidates.stream().max(comparator).map(Pair::getFirst).get();
}
Also used: DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), Vector (org.apache.mahout.math.Vector), Pair (edu.neu.ccs.pyramid.util.Pair)
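Examples 2 to 5 all call `dynamicProgramming.topK(...)` on calibrated marginals and pass the result on as `sparseJoint`. The method therefore presumably returns the most probable label sets, with their probabilities, when labels are treated as independent; its internals are not shown here. Below is a minimal sketch of one standard way to compute such a list, not necessarily pyramid's:

1. Start from the single most probable set, where label l is included iff p_l > 0.5.
2. Note that flipping label l lowers the log-probability by c_l = |log p_l - log(1 - p_l)|.
3. Enumerate sets of flips in increasing total cost with a priority queue.

`TopKSketch`, `Candidate`, and `topK` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.stream.IntStream;

/**
 * Hypothetical sketch (not pyramid's DynamicProgramming): the k most probable
 * label sets when labels are independent Bernoulli variables with the given marginals.
 */
final class TopKSketch {

    /** A label set (ascending label indices) with its probability under independence. */
    static final class Candidate {
        final int[] labels;
        final double prob;
        Candidate(int[] labels, double prob) { this.labels = labels; this.prob = prob; }
    }

    /** Search node: flipped positions in the cost-sorted order (ascending) and their total cost. */
    private static final class Node {
        final int[] flips;
        final double cost;
        Node(int[] flips, double cost) { this.flips = flips; this.cost = cost; }
    }

    static List<Candidate> topK(double[] marginals, int k) {
        List<Candidate> result = new ArrayList<>();
        if (k <= 0) return result;
        int n = marginals.length;
        boolean[] mode = new boolean[n];
        double[] flipCost = new double[n];
        Integer[] order = new Integer[n];
        double logMode = 0.0;
        for (int l = 0; l < n; l++) {
            double p = Math.min(Math.max(marginals[l], 1e-12), 1 - 1e-12); // keep logs finite
            mode[l] = p > 0.5;                                          // most likely value of label l
            logMode += Math.log(Math.max(p, 1 - p));
            flipCost[l] = Math.abs(Math.log(p) - Math.log(1 - p));     // log-probability lost by flipping l
            order[l] = l;
        }
        Arrays.sort(order, (a, b) -> Double.compare(flipCost[a], flipCost[b])); // cheapest flips first

        result.add(toCandidate(mode, new int[0], order, logMode)); // the single most probable set
        PriorityQueue<Node> heap = new PriorityQueue<Node>((a, b) -> Double.compare(a.cost, b.cost));
        if (n > 0) heap.add(new Node(new int[]{0}, flipCost[order[0]]));
        while (result.size() < k && !heap.isEmpty()) {
            Node node = heap.poll();
            result.add(toCandidate(mode, node.flips, order, logMode - node.cost));
            int last = node.flips[node.flips.length - 1];
            if (last + 1 < n) {
                double next = flipCost[order[last + 1]];
                // Child 1: also flip the next position.
                int[] extended = Arrays.copyOf(node.flips, node.flips.length + 1);
                extended[extended.length - 1] = last + 1;
                heap.add(new Node(extended, node.cost + next));
                // Child 2: move the last flip one position to the right.
                int[] shifted = node.flips.clone();
                shifted[shifted.length - 1] = last + 1;
                heap.add(new Node(shifted, node.cost - flipCost[order[last]] + next));
            }
        }
        return result;
    }

    /** Applies the flips to the mode and returns the resulting label set and probability. */
    private static Candidate toCandidate(boolean[] mode, int[] flips, Integer[] order, double logProb) {
        boolean[] y = mode.clone();
        for (int pos : flips) y[order[pos]] = !y[order[pos]];
        int[] labels = IntStream.range(0, y.length).filter(l -> y[l]).toArray();
        return new Candidate(labels, Math.exp(logProb));
    }

    public static void main(String[] args) {
        // Prints the four most probable sets, starting with [0, 1] at about 0.432.
        for (Candidate c : topK(new double[]{0.9, 0.6, 0.2}, 4)) {
            System.out.println(Arrays.toString(c.labels) + " " + c.prob);
        }
    }
}
```

Each set of flips has exactly one parent in this search: it is reached either by extending a parent's last flip or by shifting it one position right. Neither move lowers the total cost because the costs are sorted, so sets come off the heap in non-increasing probability.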

Example 3 with DynamicProgramming

Use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

From class Reranker, method prob.

public double prob(Vector vector, MultiLabel multiLabel) {
    double[] marginals = labelCalibrator.calibratedClassProbs(classProbEstimator.predictClassProbs(vector));
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    List<Pair<MultiLabel, Double>> topK = dynamicProgramming.topK(numCandidate);
    PredictionCandidate predictionCandidate = new PredictionCandidate();
    predictionCandidate.x = vector;
    predictionCandidate.labelProbs = marginals;
    predictionCandidate.multiLabel = multiLabel;
    predictionCandidate.sparseJoint = topK;
    Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
    double score = regressor.predict(feature);
    if (score > 1) {
        score = 1;
    }
    if (score < 0) {
        score = 0;
    }
    return score;
}
Also used: DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming), Vector (org.apache.mahout.math.Vector), Pair (edu.neu.ccs.pyramid.util.Pair)
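`Reranker.prob` scores one given label set. It builds the same candidate features as the other examples, passes them to a regressor, and clips the output into [0, 1] so it can be read as a probability. For comparison, the marginals alone also assign a probability to that set if labels are treated as independent: the product of p_l over included labels times (1 - p_l) over excluded ones. This baseline is an illustration only; whether pyramid's `PredictionFeatureExtractor` computes it is not shown here. `JointProbSketch` and its methods are hypothetical names, and `Set.of` assumes Java 9 or later.

```java
import java.util.Set;

/** Hypothetical helpers illustrating the clipping step and an independence baseline. */
final class JointProbSketch {

    /** Clips a regression output into [0, 1]; same effect as the two if-statements in prob(). */
    static double clamp01(double score) {
        return Math.max(0.0, Math.min(1.0, score));
    }

    /**
     * Probability of exactly the label set {@code labels} if labels were independent:
     * prod over l in labels of p_l, times prod over l not in labels of (1 - p_l),
     * accumulated in log space to avoid underflow with many labels.
     */
    static double independentJointProb(double[] marginals, Set<Integer> labels) {
        double logProb = 0.0;
        for (int l = 0; l < marginals.length; l++) {
            double p = labels.contains(l) ? marginals[l] : 1.0 - marginals[l];
            logProb += Math.log(p);
        }
        return Math.exp(logProb);
    }

    public static void main(String[] args) {
        double[] marginals = {0.9, 0.6, 0.2};
        System.out.println(independentJointProb(marginals, Set.of(0, 1))); // about 0.432
        System.out.println(clamp01(1.3));  // 1.0
        System.out.println(clamp01(-0.2)); // 0.0
    }
}
```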

Example 4 with DynamicProgramming

Use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

From class TopKFinder, method topK.

public static List<Pair<MultiLabel, Double>> topK(Vector x, MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator vectorCalibrator, PredictionFeatureExtractor predictionFeatureExtractor, int minSetSize, int maxSetSize, int top) {
    double[] uncalibratedMarginals = classProbEstimator.predictClassProbs(x);
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibratedMarginals);
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    // todo better
    List<Pair<MultiLabel, Double>> sparseJoint = dynamicProgramming.topK(50);
    List<MultiLabel> multiLabels = sparseJoint.stream().map(pair -> pair.getFirst()).filter(candidate -> candidate.getNumMatchedLabels() >= minSetSize && candidate.getNumMatchedLabels() <= maxSetSize).collect(Collectors.toList());
    List<Pair<MultiLabel, Double>> candidates = new ArrayList<>();
    if (multiLabels.isEmpty()) {
        int[] sorted = ArgSort.argSortDescending(marginals);
        MultiLabel multiLabel = new MultiLabel();
        for (int i = 0; i < minSetSize; i++) {
            multiLabel.addLabel(sorted[i]);
        }
        multiLabels.add(multiLabel);
    }
    for (MultiLabel candidate : multiLabels) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.x = x;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.multiLabel = candidate;
        predictionCandidate.sparseJoint = sparseJoint;
        Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
        double score = vectorCalibrator.calibrate(feature);
        candidates.add(new Pair<>(candidate, score));
    }
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(pair -> pair.getSecond());
    return candidates.stream().sorted(comparator.reversed()).limit(top).collect(Collectors.toList());
}
Also used: BMDistribution (edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution), java.util (java.util), ArgSort (edu.neu.ccs.pyramid.util.ArgSort), DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming), ArgMax (edu.neu.ccs.pyramid.util.ArgMax), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier), Vector (org.apache.mahout.math.Vector), CBM (edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM), Collectors (java.util.stream.Collectors), Pair (edu.neu.ccs.pyramid.util.Pair)
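`TopKFinder.topK` keeps only the DP candidates whose size lies in [minSetSize, maxSetSize]. If none survive, it falls back to the minSetSize labels with the highest marginals. The minimal sketch below models label sets as sorted `List<Integer>` instead of pyramid's `MultiLabel`, an assumption made so the block is self-contained. It also guards against two edge cases in the listing above:

- `Collectors.toList()` documents no guarantee that the returned list is mutable, even though the code later calls `add` on it.
- `sorted[i]` would throw `ArrayIndexOutOfBoundsException` if minSetSize exceeded the number of labels.

`SizeFilterSketch` and `filterBySize` are hypothetical names, and `List.of` in `main` assumes Java 9 or later.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Hypothetical sketch of the size filter and fallback in TopKFinder.topK, with
 * label sets as sorted List<Integer> standing in for pyramid's MultiLabel.
 */
final class SizeFilterSketch {

    static List<List<Integer>> filterBySize(List<List<Integer>> candidates, double[] marginals,
                                            int minSize, int maxSize) {
        // toCollection(ArrayList::new) guarantees a mutable list; Collectors.toList() does not.
        List<List<Integer>> kept = candidates.stream()
                .filter(c -> c.size() >= minSize && c.size() <= maxSize)
                .collect(Collectors.toCollection(ArrayList::new));
        if (kept.isEmpty()) {
            // Fallback: the minSize labels with the highest marginals (descending argsort),
            // capped at the number of labels so an oversized minSize cannot overrun the array.
            List<Integer> fallback = IntStream.range(0, marginals.length).boxed()
                    .sorted((a, b) -> Double.compare(marginals[b], marginals[a]))
                    .limit(Math.max(0, Math.min(minSize, marginals.length)))
                    .sorted() // report the chosen labels in ascending index order
                    .collect(Collectors.toList());
            kept.add(fallback);
        }
        return kept;
    }

    public static void main(String[] args) {
        double[] marginals = {0.9, 0.6, 0.2};
        List<List<Integer>> candidates = List.of(List.of(0), List.of(0, 1), List.of(0, 1, 2));
        System.out.println(filterBySize(candidates, marginals, 2, 2)); // [[0, 1]]
        // No candidate has 2 or 3 labels, so the top-2 labels by marginal are used instead.
        System.out.println(filterBySize(List.of(List.of(0)), new double[]{0.2, 0.9, 0.6}, 2, 3)); // [[1, 2]]
    }
}
```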

Example 5 with DynamicProgramming

Use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

From class TopKFinder, method topKinSupport.

public static List<Pair<MultiLabel, Double>> topKinSupport(Vector x, MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator vectorCalibrator, PredictionFeatureExtractor predictionFeatureExtractor, List<MultiLabel> support, int top) {
    double[] marginals = labelCalibrator.calibratedClassProbs(classProbEstimator.predictClassProbs(x));
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    // todo better
    List<Pair<MultiLabel, Double>> sparseJoint = dynamicProgramming.topK(50);
    List<Pair<MultiLabel, Double>> list = new ArrayList<>();
    for (MultiLabel candidate : support) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.x = x;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.multiLabel = candidate;
        predictionCandidate.sparseJoint = sparseJoint;
        Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
        double pro = vectorCalibrator.calibrate(feature);
        list.add(new Pair<>(candidate, pro));
    }
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(pair -> pair.getSecond());
    return list.stream().sorted(comparator.reversed()).limit(top).collect(Collectors.toList());
}
Also used: DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), Vector (org.apache.mahout.math.Vector), Pair (edu.neu.ccs.pyramid.util.Pair)
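`topKinSupport` differs from `topK` in where its candidates come from. It scores every label set in the given `support` rather than the DP's top 50, then keeps the best `top` by sorting the whole list and truncating. When the support is large and `top` is small, a bounded min-heap makes the same selection in O(n log top) time instead of O(n log n). This is a swapped-in technique shown as a sketch, not what pyramid does. `BoundedTopK` and `Scored` (a stand-in for `Pair<MultiLabel, Double>`) are hypothetical names, and `List.of` in `main` assumes Java 9 or later.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Hypothetical alternative to sort-then-limit for keeping the `top` best-scoring
 * candidates; Scored stands in for pyramid's Pair<MultiLabel, Double>.
 */
final class BoundedTopK {

    static final class Scored {
        final String id;
        final double score;
        Scored(String id, double score) { this.id = id; this.score = score; }
    }

    /** Returns the `top` highest-scoring items, best first, in O(n log top) time. */
    static List<Scored> topK(List<Scored> items, int top) {
        List<Scored> result = new ArrayList<>();
        if (top <= 0) return result;
        // Min-heap on score: its root is the weakest of the best items seen so far.
        PriorityQueue<Scored> heap = new PriorityQueue<Scored>((a, b) -> Double.compare(a.score, b.score));
        for (Scored item : items) {
            if (heap.size() < top) {
                heap.add(item);
            } else if (item.score > heap.peek().score) {
                heap.poll(); // evict the current weakest
                heap.add(item);
            }
        }
        result.addAll(heap);
        result.sort((a, b) -> Double.compare(b.score, a.score)); // best first
        return result;
    }

    public static void main(String[] args) {
        List<Scored> items = List.of(new Scored("a", 0.2), new Scored("b", 0.9),
                                     new Scored("c", 0.5), new Scored("d", 0.7));
        for (Scored s : topK(items, 2)) {
            System.out.println(s.id + " " + s.score); // b 0.9, then d 0.7
        }
    }
}
```

When several items tie at the cut-off, the heap may keep a different tied item than the stable sort-then-limit in the original would.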

Aggregations

DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming): 12
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 8
Pair (edu.neu.ccs.pyramid.util.Pair): 8
Vector (org.apache.mahout.math.Vector): 7
BMDistribution (edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution): 3
java.util (java.util): 3
Collectors (java.util.stream.Collectors): 3
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 2
CBM (edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM): 2
ArgMax (edu.neu.ccs.pyramid.util.ArgMax): 2
ArgSort (edu.neu.ccs.pyramid.util.ArgSort): 2
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 1
LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator): 1
Feature (edu.neu.ccs.pyramid.feature.Feature): 1
FeatureList (edu.neu.ccs.pyramid.feature.FeatureList): 1
TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures): 1
MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis): 1
PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor): 1
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 1
edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression): 1