Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in the project pyramid by cheng-li.
Class CBMPredictor, method predictByDynamic:
/**
 * Predicts a label set by enumerating, for each cluster, its label sets in the order produced by
 * that cluster's {@link DynamicProgramming} enumerator, and scoring every admissible candidate with
 * {@code logProbYnGivenXnLogisticProb}.
 *
 * <p>After each candidate is drawn, the latest probability reported by every cluster is passed to
 * {@code computeThreshold}. The search stops early once {@code exp(best log-probability)} reaches
 * that threshold. NOTE(review): the threshold is presumably an upper bound on the probability of
 * any label set not yet enumerated — confirm in {@code computeThreshold}. Otherwise the search
 * gives up after a fixed number of rounds (10 draws per cluster).
 *
 * <p>Empty label sets are skipped unless {@code allowEmpty} is set.
 *
 * @return the best-scoring candidate seen; an empty {@link MultiLabel} if no admissible candidate
 *     was scored
 */
public MultiLabel predictByDynamic() {
    DynamicProgramming[] enumerators = new DynamicProgramming[numClusters];
    // Latest probability drawn from each cluster; +infinity until that cluster has produced one.
    double[] frontierProbs = new double[numClusters];
    for (int c = 0; c < numClusters; c++) {
        enumerators[c] = new DynamicProgramming(probs[c], logProbs[c]);
        frontierProbs[c] = Double.POSITIVE_INFINITY;
    }

    final int maxRounds = 10;
    MultiLabel best = new MultiLabel();
    double bestLogProb = Double.NEGATIVE_INFINITY;

    for (int round = 0; round < maxRounds; round++) {
        for (int c = 0; c < numClusters; c++) {
            DynamicProgramming enumerator = enumerators[c];
            // Draw the probability first, then the vector, matching the enumerator's protocol.
            double drawnProb = enumerator.nextHighestProb();
            MultiLabel drawn = enumerator.nextHighestVector();
            frontierProbs[c] = drawnProb;
            double stopThreshold = computeThreshold(frontierProbs);

            boolean skipCandidate = drawn.getNumMatchedLabels() == 0 && !allowEmpty;
            if (!skipCandidate) {
                double drawnLogProb = logProbYnGivenXnLogisticProb(drawn);
                // ">=" lets a later candidate with an equal score replace the current best.
                if (drawnLogProb >= bestLogProb) {
                    best = drawn;
                    bestLogProb = drawnLogProb;
                }
            }

            if (Math.exp(bestLogProb) >= stopThreshold) {
                return best;
            }
        }
    }
    return best;
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in the project pyramid by cheng-li.
Class SupportPredictor, method predict:
/**
 * Predicts the label set for {@code vector} by scoring every label set in {@code support} and
 * returning the highest-scoring one.
 *
 * <p>The classifier's class probabilities are calibrated by {@code labelCalibrator}. The top 50
 * label sets from {@link DynamicProgramming#topK} over those marginals serve as a sparse joint
 * context for feature extraction. Each support candidate is then scored by {@code setCalibrator}.
 *
 * @param vector the input features
 * @return the support label set with the highest calibrated score; on ties, the first one in
 *     {@code support} order
 * @throws java.util.NoSuchElementException if {@code support} is empty
 */
@Override
public MultiLabel predict(Vector vector) {
    double[] uncali = classifier.predictClassProbs(vector);
    double[] marginals = labelCalibrator.calibratedClassProbs(uncali);
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    List<Pair<MultiLabel, Double>> sparseJoint = dynamicProgramming.topK(50);
    List<Pair<MultiLabel, Double>> candidates = new ArrayList<>();
    for (MultiLabel candidate : support) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.x = vector;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.multiLabel = candidate;
        predictionCandidate.sparseJoint = sparseJoint;
        Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
        double score = setCalibrator.calibrate(feature);
        candidates.add(new Pair<>(candidate, score));
    }
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(Pair::getSecond);
    // An empty support leaves nothing to choose from. Fail with the same exception type that
    // Optional.get() would throw, but with a message that explains the cause.
    return candidates.stream()
            .max(comparator)
            .map(Pair::getFirst)
            .orElseThrow(() -> new java.util.NoSuchElementException(
                    "cannot predict: the support set is empty"));
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in the project pyramid by cheng-li.
Class Reranker, method prob:
/**
 * Scores how likely {@code multiLabel} is for {@code vector}, using the reranking regressor.
 *
 * <p>Features are extracted from the calibrated marginals plus the {@code numCandidate} label sets
 * returned by {@link DynamicProgramming#topK}. The regressor output is clamped to [0, 1]. A NaN
 * output is returned unchanged.
 *
 * @param vector the input features
 * @param multiLabel the label set to score
 * @return the regressor's score clamped to the interval [0, 1]
 */
public double prob(Vector vector, MultiLabel multiLabel) {
    double[] uncalibrated = classProbEstimator.predictClassProbs(vector);
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibrated);
    List<Pair<MultiLabel, Double>> topK = new DynamicProgramming(marginals).topK(numCandidate);

    PredictionCandidate candidate = new PredictionCandidate();
    candidate.x = vector;
    candidate.labelProbs = marginals;
    candidate.multiLabel = multiLabel;
    candidate.sparseJoint = topK;

    double raw = regressor.predict(predictionFeatureExtractor.extractFeatures(candidate));
    if (raw > 1) {
        return 1.0;
    }
    if (raw < 0) {
        return 0.0;
    }
    return raw;
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in the project pyramid by cheng-li.
Class TopKFinder, method topK:
/**
 * Returns up to {@code top} highest-scoring label sets for {@code x} whose size lies in
 * [{@code minSetSize}, {@code maxSetSize}].
 *
 * <p>Candidates are the label sets returned by {@code DynamicProgramming.topK(50)} over the
 * calibrated marginals, filtered by size. Each surviving candidate is re-scored by
 * {@code vectorCalibrator} on features from {@code predictionFeatureExtractor}.
 *
 * <p>If no enumerated set satisfies the size constraint, a single fallback candidate is built.
 * It contains the {@code minSetSize} labels with the highest calibrated marginals, capped at the
 * number of available labels.
 *
 * @return at most {@code top} (label set, calibrated score) pairs, sorted by descending score
 */
public static List<Pair<MultiLabel, Double>> topK(Vector x, MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator vectorCalibrator, PredictionFeatureExtractor predictionFeatureExtractor, int minSetSize, int maxSetSize, int top) {
    double[] uncalibratedMarginals = classProbEstimator.predictClassProbs(x);
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibratedMarginals);
    DynamicProgramming dynamicProgramming = new DynamicProgramming(marginals);
    // todo better: the number of enumerated label sets is hard-coded
    List<Pair<MultiLabel, Double>> sparseJoint = dynamicProgramming.topK(50);
    // Collect into an explicitly mutable list. The fallback below may append to it, and
    // Collectors.toList() makes no guarantee about the mutability of its result.
    List<MultiLabel> multiLabels = sparseJoint.stream()
            .map(Pair::getFirst)
            .filter(candidate -> candidate.getNumMatchedLabels() >= minSetSize
                    && candidate.getNumMatchedLabels() <= maxSetSize)
            .collect(Collectors.toCollection(ArrayList::new));
    if (multiLabels.isEmpty()) {
        int[] sorted = ArgSort.argSortDescending(marginals);
        // Cap at the number of sorted indices so an oversized minSetSize cannot index past the end.
        int fallbackSize = Math.min(minSetSize, sorted.length);
        MultiLabel multiLabel = new MultiLabel();
        for (int i = 0; i < fallbackSize; i++) {
            multiLabel.addLabel(sorted[i]);
        }
        multiLabels.add(multiLabel);
    }
    List<Pair<MultiLabel, Double>> candidates = new ArrayList<>();
    for (MultiLabel candidate : multiLabels) {
        PredictionCandidate predictionCandidate = new PredictionCandidate();
        predictionCandidate.x = x;
        predictionCandidate.labelProbs = marginals;
        predictionCandidate.multiLabel = candidate;
        predictionCandidate.sparseJoint = sparseJoint;
        Vector feature = predictionFeatureExtractor.extractFeatures(predictionCandidate);
        double score = vectorCalibrator.calibrate(feature);
        candidates.add(new Pair<>(candidate, score));
    }
    Comparator<Pair<MultiLabel, Double>> comparator = Comparator.comparing(pair -> pair.getSecond());
    return candidates.stream().sorted(comparator.reversed()).limit(top).collect(Collectors.toList());
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.DynamicProgramming in the project pyramid by cheng-li.
Class TopKFinder, method topKinSupport:
/**
 * Scores every label set in {@code support} for {@code x} and returns the {@code top} best.
 *
 * <p>Features for each candidate are built from the calibrated marginals plus the 50 label sets
 * returned by {@link DynamicProgramming#topK}. Candidates are scored by {@code vectorCalibrator}.
 *
 * @return at most {@code top} (label set, calibrated score) pairs, sorted by descending score;
 *     candidates with equal scores keep their {@code support} order
 */
public static List<Pair<MultiLabel, Double>> topKinSupport(Vector x, MultiLabelClassifier.ClassProbEstimator classProbEstimator, LabelCalibrator labelCalibrator, VectorCalibrator vectorCalibrator, PredictionFeatureExtractor predictionFeatureExtractor, List<MultiLabel> support, int top) {
    double[] uncalibrated = classProbEstimator.predictClassProbs(x);
    double[] marginals = labelCalibrator.calibratedClassProbs(uncalibrated);
    // todo better: the size of the sparse joint context is hard-coded
    List<Pair<MultiLabel, Double>> sparseJoint = new DynamicProgramming(marginals).topK(50);

    List<Pair<MultiLabel, Double>> scored = new ArrayList<>();
    for (MultiLabel labelSet : support) {
        PredictionCandidate pc = new PredictionCandidate();
        pc.x = x;
        pc.labelProbs = marginals;
        pc.multiLabel = labelSet;
        pc.sparseJoint = sparseJoint;
        double calibratedScore = vectorCalibrator.calibrate(predictionFeatureExtractor.extractFeatures(pc));
        scored.add(new Pair<>(labelSet, calibratedScore));
    }

    Comparator<Pair<MultiLabel, Double>> byScore = Comparator.comparing(Pair::getSecond);
    // List.sort is stable, so ties keep support order, as with a sequential sorted() stream.
    scored.sort(byScore.reversed());
    return scored.stream().limit(top).collect(Collectors.toList());
}
Aggregations