Usage of `edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming` in the project pyramid (by cheng-li).
Class `CBMPredictor`, method `predictByDynamic2`:
/**
 * Searches for the highest-scoring label set by pulling candidate label vectors from one
 * {@link DynamicProgramming} enumerator per cluster and scoring each candidate with
 * {@code logProbYnGivenXnLogisticProb}. Runs at most 10 rounds.
 *
 * @return the best-scoring {@link MultiLabel} seen; a fresh empty {@code MultiLabel} if no
 *     candidate was ever scored
 */
public MultiLabel predictByDynamic2() {
    return predictByDynamic2(10);
}

/**
 * Same search as {@link #predictByDynamic2()}, with a configurable cap on the number of rounds.
 * Each round takes one candidate from every cluster enumerator that is still active.
 * An enumerator is dropped once its queue is empty or once {@code checkStop} says the cluster
 * can no longer beat the current best score.
 *
 * @param maxIter maximum number of rounds to run; must be at least 1
 * @return the best-scoring {@link MultiLabel} seen; a fresh empty {@code MultiLabel} if no
 *     candidate was ever scored
 * @throws IllegalArgumentException if {@code maxIter < 1}
 */
public MultiLabel predictByDynamic2(int maxIter) {
    if (maxIter < 1) {
        throw new IllegalArgumentException("maxIter must be at least 1, got " + maxIter);
    }

    // initialization: one enumerator per cluster, plus the probability it reports first
    Map<Integer, DynamicProgramming> dps = new HashMap<>();
    double[] maxClusterProb = new double[numClusters];
    for (int k = 0; k < numClusters; k++) {
        DynamicProgramming dp = new DynamicProgramming(probs[k], logProbs[k]);
        dps.put(k, dp);
        maxClusterProb[k] = dp.nextHighestProb();
    }

    // speed up (values precomputed once and handed to checkStop):
    // 1) for pi^k (D^k - q) >= 1 - pi^k
    double[] cond1 = new double[numClusters];
    // 2) save condition for sum_{r!=k} (pi^r * D^r)
    double[] sumPiD = new double[numClusters];
    for (int k = 0; k < numClusters; k++) {
        cond1[k] = maxClusterProb[k] - 1.0 / pi[k] + 1;
        double sum = 0.0;
        for (int r = 0; r < numClusters; r++) {
            if (r == k) {
                continue;
            }
            sum += pi[r] * maxClusterProb[r];
        }
        sumPiD[k] = sum;
    }

    double maxLogProb = Double.NEGATIVE_INFINITY;
    MultiLabel bestMultiLabel = new MultiLabel();
    for (int iter = 0; iter < maxIter && !dps.isEmpty(); iter++) {
        // removals are deferred so the map is not modified while its entries are iterated
        List<Integer> removeList = new LinkedList<>();
        for (Map.Entry<Integer, DynamicProgramming> entry : dps.entrySet()) {
            int k = entry.getKey();
            DynamicProgramming dp = entry.getValue();
            // read the probability before taking the vector
            double prob = dp.nextHighestProb();
            MultiLabel multiLabel = dp.nextHighestVector();
            // whether consider empty prediction
            if ((multiLabel.getNumMatchedLabels() == 0) && !allowEmpty) {
                if (dp.getQueue().size() == 0) {
                    removeList.add(k);
                }
                continue;
            }
            double logProb = logProbYnGivenXnLogisticProb(multiLabel);
            // '>=' means a later candidate with an equal score replaces the earlier one
            if (logProb >= maxLogProb) {
                bestMultiLabel = multiLabel;
                maxLogProb = logProb;
            }
            // check if need to remove cluster k from the candidates
            if (checkStop(prob, cond1[k], maxLogProb, sumPiD[k], k) || dp.getQueue().size() == 0) {
                removeList.add(k);
            }
        }
        for (int k : removeList) {
            dps.remove(k);
        }
    }
    return bestMultiLabel;
}
Usage of `edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming` in the project pyramid (by cheng-li).
Class `CalibrationTest`, method `pred`:
/**
 * Builds a plain-text report of the first {@code top} label vectors for each data point.
 * The vectors come from a {@link DynamicProgramming} built over the model's
 * {@code predictClassProbs} output for that row. Each vector is annotated with its
 * log-probability under the {@link BMDistribution} that {@code cbm.computeBM} returns for the
 * same row.
 *
 * @param cbm     model providing the per-row probabilities and the BM distribution
 * @param dataSet data points to report on
 * @param top     number of label vectors to list per data point
 * @return one line per (data point, vector) pair, formatted as {@code "i: labels (score)"}
 */
private static String pred(CBM cbm, MultiLabelClfDataSet dataSet, int top) {
    StringBuilder report = new StringBuilder();
    for (int row = 0; row < dataSet.getNumDataPoints(); row++) {
        double[] marginals = cbm.predictClassProbs(dataSet.getRow(row));
        DynamicProgramming enumerator = new DynamicProgramming(marginals);
        // NOTE(review): meaning of 0.001 is defined by CBM.computeBM; not visible here
        BMDistribution distribution = cbm.computeBM(dataSet.getRow(row), 0.001);
        int emitted = 0;
        while (emitted < top) {
            MultiLabel labels = enumerator.nextHighestVector();
            double logProb = distribution.logProbability(labels);
            report.append(row)
                    .append(": ")
                    .append(labels.toSimpleString())
                    .append(" (")
                    .append(logProb)
                    .append(")")
                    .append("\n");
            emitted++;
        }
    }
    return report.toString();
}
Aggregations