Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in project pyramid by cheng-li.
Class `GeneralF1Predictor`, method `bestWithLengthK`:
/**
 * Selects the {@code k} labels with the largest values in {@code deltaVector}.
 *
 * @param deltaVector per-label values, indexed by label
 * @param k number of labels to take from the top of the descending ordering; assumed not to
 *     exceed the number of labels (TODO confirm callers guarantee this)
 * @return the selected labels paired with the sum of their {@code deltaVector} values
 */
private Pair<MultiLabel, Double> bestWithLengthK(double[] deltaVector, int k) {
    // Label indices ordered from the largest delta to the smallest.
    int[] rankedLabels = ArgSort.argSortDescending(deltaVector);
    MultiLabel chosen = new MultiLabel();
    double total = 0;
    int taken = 0;
    while (taken < k) {
        int candidate = rankedLabels[taken];
        chosen.addLabel(candidate);
        total += deltaVector[candidate];
        taken++;
    }
    return new Pair<>(chosen, total);
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in project pyramid by cheng-li.
Class `TunedMarginalClassifier`, method `predict`:
/**
 * Predicts a label set by keeping every class whose estimated probability is strictly greater
 * than that class's entry in {@code thresholds}.
 *
 * @param vector the feature vector to classify
 * @return the classes that cleared their thresholds (possibly empty)
 */
@Override
public MultiLabel predict(Vector vector) {
    MultiLabel result = new MultiLabel();
    int classCount = classProbEstimator.getNumClasses();
    double[] classProbs = classProbEstimator.predictClassProbs(vector);
    for (int c = 0; c < classCount; c++) {
        boolean clearsThreshold = classProbs[c] > thresholds[c];
        if (clearsThreshold) {
            result.addLabel(c);
        }
    }
    return result;
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in project pyramid by cheng-li.
Class `MLLogisticRegression`, method `predict`:
/**
 * Returns the candidate assignment in {@code this.assignments} with the highest score, as
 * computed by {@code calAssignmentScore} from the per-class scores of {@code vector}.
 *
 * <p>Ties keep the earliest candidate. Returns {@code null} when there are no candidates or when
 * no candidate scores strictly above negative infinity.
 *
 * @param vector the feature vector to classify
 * @return the best-scoring assignment, or {@code null} as described above
 */
@Override
public MultiLabel predict(Vector vector) {
    MultiLabel best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    double[] scoresPerClass = predictClassScores(vector);
    for (MultiLabel candidate : this.assignments) {
        double candidateScore = this.calAssignmentScore(candidate, scoresPerClass);
        if (candidateScore > bestScore) {
            best = candidate;
            bestScore = candidateScore;
        }
    }
    return best;
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in project pyramid by cheng-li.
Class `AdaBoostMHInspector`, method `analyzePrediction`:
/**
 * Builds a prediction analysis for one data point of {@code dataSet}. The analysis contains the
 * point's true labels, the labels predicted by {@code boosting}, the assignment probabilities of
 * both label sets (when available), per-class score calculations and a per-class probability
 * ranking.
 *
 * @param boosting model whose prediction is analyzed
 * @param scaling probability estimator used for class probabilities; can be binary scaling or
 *     across-class scaling. Assignment probabilities are only computed when it is also a
 *     {@link MultiLabelClassifier.AssignmentProbEstimator}; otherwise they are reported as NaN
 * @param dataSet data set containing the point
 * @param dataPointIndex internal index of the point in {@code dataSet}
 * @param classes internal class indices to explain and rank, in the order they are reported
 * @param limit forwarded to {@code decisionProcess} for each class (presumably caps how much
 *     detail is reported per class — TODO confirm)
 * @return the populated analysis, including the per-class ranking
 */
public static MultiLabelPredictionAnalysis analyzePrediction(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));

    // True labels, in both internal and external form.
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    List<Integer> internalLabels = trueLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalLabels(internalLabels);
    List<String> labels = internalLabels.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);

    // Only estimators that score whole label sets can provide assignment probabilities.
    MultiLabelClassifier.AssignmentProbEstimator assignmentEstimator = null;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        assignmentEstimator = (MultiLabelClassifier.AssignmentProbEstimator) scaling;
    }

    double probForTrueLabels = Double.NaN;
    if (assignmentEstimator != null) {
        probForTrueLabels = assignmentEstimator.predictAssignmentProb(dataSet.getRow(dataPointIndex), trueLabels);
    }
    predictionAnalysis.setProbForTrueLabels(probForTrueLabels);

    // Predicted labels, in both internal and external form.
    MultiLabel predictedLabels = boosting.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);

    double probForPredictedLabels = Double.NaN;
    if (assignmentEstimator != null) {
        probForPredictedLabels = assignmentEstimator.predictAssignmentProb(dataSet.getRow(dataPointIndex), predictedLabels);
    }
    predictionAnalysis.setProbForPredictedLabels(probForPredictedLabels);

    // Per-class explanation of how each requested class score was obtained.
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>(classes.size());
    for (int k : classes) {
        classScoreCalculations.add(decisionProcess(boosting, scaling, labelTranslator, dataSet.getRow(dataPointIndex), k, limit));
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);

    // Per-class probability info for the requested classes.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> ranking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(scaling.predictClassProb(dataSet.getRow(dataPointIndex), label));
        return rankInfo;
    }).collect(Collectors.toList());
    // The ranking used to be computed and then dropped; attach it so callers actually receive it.
    predictionAnalysis.setPredictedRanking(ranking);
    return predictionAnalysis;
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in project pyramid by cheng-li.
Class `CBM`, method `predictLogAssignmentProbsAsList`:
/**
 * Computes the log-probability of each candidate assignment for input {@code x}.
 *
 * <p>A single distribution is built from {@code computeBM(x)} and evaluated for every
 * assignment. A per-component variant built with {@code new BMDistribution(this, x, assignments)}
 * was considered here to support prediction within each component, but is not used.
 *
 * @param x the feature vector
 * @param assignments candidate label sets to score
 * @return log-probabilities in the same order as {@code assignments}
 */
public List<Double> predictLogAssignmentProbsAsList(Vector x, List<MultiLabel> assignments) {
    BMDistribution distribution = computeBM(x);
    List<Double> logProbs = new ArrayList<>(assignments.size());
    assignments.forEach(assignment -> logProbs.add(distribution.logProbability(assignment)));
    return logProbs;
}
Aggregations