Example 11 with DynamicProgramming

Use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

Class CBMPredictor, method predictByDynamic2.

public MultiLabel predictByDynamic2() {
    // initialization
    Map<Integer, DynamicProgramming> DPs = new HashMap<>();
    double[] maxClusterProb = new double[numClusters];
    for (int k = 0; k < numClusters; k++) {
        DPs.put(k, new DynamicProgramming(probs[k], logProbs[k]));
        maxClusterProb[k] = DPs.get(k).nextHighestProb();
    }
    // speed up:
    // 1) for pi^k (D^k - q) >= 1 - pi^k
    double[] cond1 = new double[numClusters];
    // 2) save condition for sum_{r!=k} (pi^r * D^r)
    double[] sumPiD = new double[numClusters];
    for (int k = 0; k < numClusters; k++) {
        cond1[k] = maxClusterProb[k] - 1.0 / pi[k] + 1;
        double sum = 0.0;
        for (int r = 0; r < numClusters; r++) {
            if (r == k) {
                continue;
            }
            sum += pi[r] * maxClusterProb[r];
        }
        sumPiD[k] = sum;
    }
    double maxLogProb = Double.NEGATIVE_INFINITY;
    MultiLabel bestMultiLabel = new MultiLabel();
    int iter = 0;
    int maxIter = 10;
    while (DPs.size() > 0) {
        List<Integer> removeList = new LinkedList<>();
        for (Map.Entry<Integer, DynamicProgramming> entry : DPs.entrySet()) {
            int k = entry.getKey();
            DynamicProgramming dp = entry.getValue();
            double prob = dp.nextHighestProb();
            MultiLabel multiLabel = dp.nextHighestVector();
            // skip the empty prediction unless allowEmpty is set
            if ((multiLabel.getNumMatchedLabels() == 0) && !allowEmpty) {
                if (dp.getQueue().size() == 0) {
                    removeList.add(k);
                }
                continue;
            }
            double logProb = logProbYnGivenXnLogisticProb(multiLabel);
            if (logProb >= maxLogProb) {
                bestMultiLabel = multiLabel;
                maxLogProb = logProb;
                // maxIter = iter;
            }
            // check whether cluster k should be removed from the candidates
            if (checkStop(prob, cond1[k], maxLogProb, sumPiD[k], k) || dp.getQueue().size() == 0) {
                removeList.add(k);
            }
        }
        for (int k : removeList) {
            DPs.remove(k);
        }
        iter++;
        if (iter >= maxIter) {
            break;
        }
    }
    return bestMultiLabel;
}
Also used: DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)
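
The "speed up" comments in predictByDynamic2 state condition 1) only in its unsolved form. As a reading aid, the rearrangement below shows how that inequality leads to the expression stored in cond1[k]. It assumes pi[k] > 0, and it assumes that q stands for the probability of the current candidate in cluster k; checkStop is not shown here, so how it actually uses cond1[k] is not confirmed by this listing.

```latex
\begin{aligned}
\pi^k \,(D^k - q) &\ge 1 - \pi^k \\
D^k - q &\ge \frac{1}{\pi^k} - 1 \qquad (\pi^k > 0) \\
q &\le D^k - \frac{1}{\pi^k} + 1
\end{aligned}
```

With D^k = maxClusterProb[k] and \pi^k = pi[k], the right-hand side of the last line is exactly maxClusterProb[k] - 1.0 / pi[k] + 1, the value assigned to cond1[k].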

Example 12 with DynamicProgramming

Use of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in project pyramid by cheng-li.

Class CalibrationTest, method pred.

private static String pred(CBM cbm, MultiLabelClfDataSet dataSet, int top) {
    StringBuilder stringBuilder = new StringBuilder();
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        double[] marginals = cbm.predictClassProbs(dataSet.getRow(i));
        DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
        BMDistribution bmDistribution = cbm.computeBM(dataSet.getRow(i), 0.001);
        for (int k = 0; k < top; k++) {
            MultiLabel multiLabel = dynamicProgramming.nextHighestVector();
            double score = bmDistribution.logProbability(multiLabel);
            stringBuilder.append(i).append(": ").append(multiLabel.toSimpleString()).append(" (").append(score).append(")").append("\n");
        }
    }
    return stringBuilder.toString();
}
Also used: DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), BMDistribution (edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution)
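
The pred method relies on DynamicProgramming to return label vectors one at a time, most probable first, from a vector of marginals. The DynamicProgramming source is not shown here, so the following is only a minimal sketch of one common way to do such an enumeration for independent labels (a best-first search over "flips" away from the most likely vector). The class TopKLabelVectors and its method topK are hypothetical names, not part of pyramid, and the sketch assumes every marginal lies strictly between 0 and 1.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical helper, not pyramid's DynamicProgramming.
final class TopKLabelVectors {

    /** A candidate: which cost-sorted positions are flipped, and its log-probability. */
    private static final class State {
        final boolean[] flipped; // indexed by position in the cost-sorted order
        final int last;          // largest flipped position, or -1 if none
        final double logProb;

        State(boolean[] flipped, int last, double logProb) {
            this.flipped = flipped;
            this.last = last;
            this.logProb = logProb;
        }
    }

    /** Returns up to k label vectors in order of decreasing joint probability. */
    static List<boolean[]> topK(double[] marginals, int k) {
        final int n = marginals.length;
        final double[] flipCost = new double[n]; // log-probability lost by flipping label i
        final boolean[] best = new boolean[n];   // most likely value of each label
        double bestLogProb = 0.0;
        for (int i = 0; i < n; i++) {
            double p = marginals[i];             // assumed to satisfy 0 < p < 1
            best[i] = p > 0.5;
            double hi = Math.max(p, 1 - p);
            double lo = Math.min(p, 1 - p);
            bestLogProb += Math.log(hi);
            flipCost[i] = Math.log(hi) - Math.log(lo);
        }
        // Sort label indices by flip cost, cheapest first.
        final Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(flipCost[a], flipCost[b]));

        // Max-heap on log-probability.
        PriorityQueue<State> queue =
                new PriorityQueue<>((a, b) -> Double.compare(b.logProb, a.logProb));
        queue.add(new State(new boolean[n], -1, bestLogProb));

        List<boolean[]> result = new ArrayList<>();
        while (result.size() < k && !queue.isEmpty()) {
            State s = queue.poll();
            // Decode the flip set into an actual label vector.
            boolean[] labels = best.clone();
            for (int pos = 0; pos < n; pos++) {
                if (s.flipped[pos]) {
                    labels[order[pos]] = !labels[order[pos]];
                }
            }
            result.add(labels);

            int next = s.last + 1;
            if (next < n) {
                // Successor 1: additionally flip the next position.
                boolean[] added = s.flipped.clone();
                added[next] = true;
                queue.add(new State(added, next, s.logProb - flipCost[order[next]]));
                // Successor 2: move the last flip one position to the right.
                if (s.last >= 0) {
                    boolean[] moved = s.flipped.clone();
                    moved[s.last] = false;
                    moved[next] = true;
                    queue.add(new State(moved, next,
                            s.logProb + flipCost[order[s.last]] - flipCost[order[next]]));
                }
            }
        }
        return result;
    }
}
```

Because the positions are sorted by flip cost, both successors are never more probable than their parent, so the heap yields vectors in non-increasing probability order and each flip set is generated exactly once. For marginals {0.9, 0.2, 0.6}, topK(marginals, 3) would return the vectors with joint probabilities 0.432, 0.288 and 0.108.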

Aggregations

DynamicProgramming (edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming): 12
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 8
Pair (edu.neu.ccs.pyramid.util.Pair): 8
Vector (org.apache.mahout.math.Vector): 7
BMDistribution (edu.neu.ccs.pyramid.multilabel_classification.cbm.BMDistribution): 3
java.util (java.util): 3
Collectors (java.util.stream.Collectors): 3
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 2
CBM (edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM): 2
ArgMax (edu.neu.ccs.pyramid.util.ArgMax): 2
ArgSort (edu.neu.ccs.pyramid.util.ArgSort): 2
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 1
LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator): 1
Feature (edu.neu.ccs.pyramid.feature.Feature): 1
FeatureList (edu.neu.ccs.pyramid.feature.FeatureList): 1
TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures): 1
MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis): 1
PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor): 1
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 1
edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression): 1