
Example 61 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class GeneralF1Predictor, method bestWithLengthK.

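/**
 * Selects the k labels with the largest entries in deltaVector and returns
 * them together with the sum of their deltas.
 */
private Pair<MultiLabel, Double> bestWithLengthK(double[] deltaVector, int k) {
    // Label indices ordered from largest to smallest delta.
    int[] sortedIndices = ArgSort.argSortDescending(deltaVector);
    MultiLabel multiLabel = new MultiLabel();
    double score = 0;
    // Take the top k labels and accumulate their deltas.
    for (int i = 0; i < k; i++) {
        int label = sortedIndices[i];
        multiLabel.addLabel(label);
        score += deltaVector[label];
    }
    return new Pair<>(multiLabel, score);
}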
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), Pair (edu.neu.ccs.pyramid.util.Pair)
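The top-k selection above can be reproduced without pyramid. The sketch below is a minimal plain-Java stand-in, assuming a label set can be represented as a sorted set of indices; `TopKSelector` and `Selection` are hypothetical names that replace pyramid's `ArgSort`, `MultiLabel` and `Pair`, and the bounds check on `k` is an addition that the original method does not have.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.SortedSet;
import java.util.TreeSet;

// Hypothetical plain-Java stand-in for ArgSort + MultiLabel + Pair.
final class TopKSelector {

    /** Result of a top-k selection: the chosen label indices and their summed deltas. */
    static final class Selection {
        final SortedSet<Integer> labels;
        final double score;

        Selection(SortedSet<Integer> labels, double score) {
            this.labels = labels;
            this.score = score;
        }
    }

    /** Returns the indices of values ordered from largest to smallest value. */
    static int[] argSortDescending(double[] values) {
        Integer[] idx = new Integer[values.length];
        for (int i = 0; i < values.length; i++) {
            idx[i] = i;
        }
        // Object sort is stable, so tied values keep ascending index order.
        Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> values[i]).reversed());
        return Arrays.stream(idx).mapToInt(Integer::intValue).toArray();
    }

    /** Picks the k labels with the largest delta and sums their deltas. */
    static Selection bestWithLengthK(double[] deltaVector, int k) {
        if (k < 0 || k > deltaVector.length) {
            throw new IllegalArgumentException("k must be in [0, " + deltaVector.length + "]");
        }
        int[] sortedIndices = argSortDescending(deltaVector);
        SortedSet<Integer> labels = new TreeSet<>();
        double score = 0;
        for (int i = 0; i < k; i++) {
            int label = sortedIndices[i];
            labels.add(label);
            score += deltaVector[label];
        }
        return new Selection(labels, score);
    }
}
```

Whether pyramid's `ArgSort` breaks ties the same way as this stable sort is not shown in the excerpt.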

Example 62 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class TunedMarginalClassifier, method predict.

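/**
 * Predicts every label whose estimated probability strictly exceeds
 * that label's tuned threshold.
 */
@Override
public MultiLabel predict(Vector vector) {
    MultiLabel multiLabel = new MultiLabel();
    int numClasses = classProbEstimator.getNumClasses();
    // Per-label probabilities from the underlying estimator.
    double[] probs = classProbEstimator.predictClassProbs(vector);
    for (int l = 0; l < numClasses; l++) {
        if (probs[l] > thresholds[l]) {
            multiLabel.addLabel(l);
        }
    }
    return multiLabel;
}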
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)
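The decision rule in `predict` is a per-label threshold test on probabilities. A minimal sketch, assuming the probabilities have already been computed (the role `predictClassProbs` plays above) and representing the predicted label set as a list of indices; `ThresholdDecoder` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the thresholding step of TunedMarginalClassifier.predict.
final class ThresholdDecoder {

    /** Returns the indices l with probs[l] > thresholds[l], in ascending order. */
    static List<Integer> predict(double[] probs, double[] thresholds) {
        if (probs.length != thresholds.length) {
            throw new IllegalArgumentException("probs and thresholds must have the same length");
        }
        List<Integer> labels = new ArrayList<>();
        for (int l = 0; l < probs.length; l++) {
            // Strict inequality, as in the excerpt: a probability equal to its threshold is not selected.
            if (probs[l] > thresholds[l]) {
                labels.add(l);
            }
        }
        return labels;
    }
}
```

Setting every threshold to 0.5 gives the usual untuned rule; how pyramid tunes the thresholds is not shown in this excerpt.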

Example 63 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class MLLogisticRegression, method predict.

/**
 * Returns the candidate assignment with the highest score
 * (null when this.assignments is empty).
 */
@Override
public MultiLabel predict(Vector vector) {
    double maxScore = Double.NEGATIVE_INFINITY;
    MultiLabel prediction = null;
    double[] classScores = predictClassScores(vector);
    // Exhaustive search over the candidate assignments known to the model.
    for (MultiLabel assignment : this.assignments) {
        double score = this.calAssignmentScore(assignment, classScores);
        if (score > maxScore) {
            maxScore = score;
            prediction = assignment;
        }
    }
    return prediction;
}
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)
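`predict` here is an argmax over a fixed list of candidate assignments. The sketch below shows that search in plain Java. The additive scoring function is an assumption for illustration only, since the excerpt does not show what `calAssignmentScore` computes; `AssignmentArgmax` is a hypothetical name.

```java
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for MLLogisticRegression.predict.
final class AssignmentArgmax {

    /** Assumed score: the sum of the class scores of the labels in the assignment. */
    static double assignmentScore(Set<Integer> assignment, double[] classScores) {
        double score = 0;
        for (int label : assignment) {
            score += classScores[label];
        }
        return score;
    }

    /** Returns the highest-scoring candidate, or null if there are no candidates. */
    static Set<Integer> predict(List<Set<Integer>> assignments, double[] classScores) {
        double maxScore = Double.NEGATIVE_INFINITY;
        Set<Integer> prediction = null;
        for (Set<Integer> assignment : assignments) {
            double score = assignmentScore(assignment, classScores);
            // Strict comparison: the first candidate wins a tie.
            if (score > maxScore) {
                maxScore = score;
                prediction = assignment;
            }
        }
        return prediction;
    }
}
```

Because the comparison is strict, the first candidate wins a tie, and an empty candidate list yields `null`, matching the excerpt's behavior.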

Example 64 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class AdaBoostMHInspector, method analyzePrediction.

/**
 * Builds a prediction analysis for one data point.
 * The scaling estimator can be binary scaling or across-class scaling.
 *
 * @param boosting       the trained AdaBoostMH model
 * @param scaling        class probability estimator used to attach probabilities
 * @param dataSet        the multi-label data set containing the data point
 * @param dataPointIndex internal index of the data point in dataSet
 * @param classes        internal indices of the classes to analyze
 * @param limit          passed through to decisionProcess for each class
 * @return the populated MultiLabelPredictionAnalysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    // Look up the feature vector and the true labels once instead of on every use.
    Vector row = dataSet.getRow(dataPointIndex);
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    // True labels, as internal indices and as external label names.
    List<Integer> internalLabels = trueLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalLabels(internalLabels);
    List<String> labels = internalLabels.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);
    // Assignment probabilities are only available from an AssignmentProbEstimator; otherwise they stay NaN.
    boolean hasAssignmentProbs = scaling instanceof MultiLabelClassifier.AssignmentProbEstimator;
    double probForTrueLabels = Double.NaN;
    if (hasAssignmentProbs) {
        probForTrueLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(row, trueLabels);
    }
    predictionAnalysis.setProbForTrueLabels(probForTrueLabels);
    // Predicted labels, as internal indices and as external label names.
    MultiLabel predictedLabels = boosting.predict(row);
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);
    double probForPredictedLabels = Double.NaN;
    if (hasAssignmentProbs) {
        probForPredictedLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(row, predictedLabels);
    }
    predictionAnalysis.setProbForPredictedLabels(probForPredictedLabels);
    // Per-class score calculations from decisionProcess.
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        ClassScoreCalculation classScoreCalculation = decisionProcess(boosting, scaling, labelTranslator, row, k, limit);
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
    // Note: ranking follows the order of classes, is not sorted, and is never stored in predictionAnalysis here.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> ranking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(scaling.predictClassProb(row, label));
        return rankInfo;
    }).collect(Collectors.toList());
    return predictionAnalysis;
}
Also used: MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis), edu.neu.ccs.pyramid.regression (edu.neu.ccs.pyramid.regression), java.util (java.util), IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator), MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier), MLPlattScaling (edu.neu.ccs.pyramid.multilabel_classification.MLPlattScaling), RegTreeInspector (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeInspector), Collectors (java.util.stream.Collectors), IMLGradientBoosting (edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting), Classifier (edu.neu.ccs.pyramid.classification.Classifier), MLACPlattScaling (edu.neu.ccs.pyramid.multilabel_classification.MLACPlattScaling), RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree), PlattScaling (edu.neu.ccs.pyramid.classification.PlattScaling), MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet), TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule), Feature (edu.neu.ccs.pyramid.feature.Feature), LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator), MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), Vector (org.apache.mahout.math.Vector), TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures), Pair (edu.neu.ccs.pyramid.util.Pair)
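In the excerpt above, `ranking` is built in the order of `classes` and is not sorted by probability. If a ranking by predicted probability is wanted, a minimal sketch could look like this. Sorting in descending order is an assumption, `classNames.get(k)` stands in for `labelTranslator.toExtLabel(k)`, `classProbs[k]` stands in for `scaling.predictClassProb(row, k)`, and `ClassRanking`/`RankInfo` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch: rank the requested classes by predicted probability.
final class ClassRanking {

    /** One ranked class: internal index, external name, predicted probability. */
    static final class RankInfo {
        final int classIndex;
        final String className;
        final double prob;

        RankInfo(int classIndex, String className, double prob) {
            this.classIndex = classIndex;
            this.className = className;
            this.prob = prob;
        }
    }

    /** Builds one RankInfo per requested class and sorts them by probability, highest first. */
    static List<RankInfo> rank(List<Integer> classes, List<String> classNames, double[] classProbs) {
        List<RankInfo> ranking = new ArrayList<>();
        for (int k : classes) {
            // classNames.get(k) plays the role of labelTranslator.toExtLabel(k).
            ranking.add(new RankInfo(k, classNames.get(k), classProbs[k]));
        }
        ranking.sort(Comparator.comparingDouble((RankInfo info) -> info.prob).reversed());
        return ranking;
    }
}
```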

Example 65 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class CBM, method predictLogAssignmentProbsAsList.

public List<Double> predictLogAssignmentProbsAsList(Vector x, List<MultiLabel> assignments) {
    BMDistribution bmDistribution = computeBM(x);
    // Disabled alternative that would support prediction within each component:
    // BMDistribution bmDistribution = new BMDistribution(this, x, assignments);
    List<Double> logProbs = new ArrayList<>(assignments.size());
    // Log probability of each candidate assignment, in the order given.
    for (MultiLabel multiLabel : assignments) {
        logProbs.add(bmDistribution.logProbability(multiLabel));
    }
    return logProbs;
}
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)
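Assuming, as the class name suggests, that CBM is a conditional Bernoulli mixture, `logProbability` would evaluate log p(y | x) = log Σ_k π_k Π_l p_kl^{y_l} (1 − p_kl)^{1 − y_l}. The sketch below computes that quantity with log-sum-exp for numerical stability, assuming the mixture weights π_k and per-component label probabilities p_kl for x have already been computed (the role `computeBM(x)` plays above). `BernoulliMixture` and its methods are hypothetical; this is not pyramid's `BMDistribution`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical sketch of a Bernoulli-mixture log probability, not pyramid's BMDistribution.
final class BernoulliMixture {

    /**
     * log p(y) = log sum_k w[k] * prod_l p[k][l]^{y_l} * (1 - p[k][l])^{1 - y_l},
     * where y is given as the set of positive label indices.
     */
    static double logProbability(double[] w, double[][] p, Set<Integer> y) {
        double[] logTerms = new double[w.length];
        for (int k = 0; k < w.length; k++) {
            double t = Math.log(w[k]);
            for (int l = 0; l < p[k].length; l++) {
                // log(p) for a positive label, log(1 - p) for a negative one.
                t += y.contains(l) ? Math.log(p[k][l]) : Math.log1p(-p[k][l]);
            }
            logTerms[k] = t;
        }
        return logSumExp(logTerms);
    }

    /** Mirrors predictLogAssignmentProbsAsList: one log probability per assignment. */
    static List<Double> logProbabilities(double[] w, double[][] p, List<Set<Integer>> assignments) {
        List<Double> out = new ArrayList<>(assignments.size());
        for (Set<Integer> y : assignments) {
            out.add(logProbability(w, p, y));
        }
        return out;
    }

    /** Numerically stable log(sum(exp(a))). */
    static double logSumExp(double[] a) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : a) {
            max = Math.max(max, v);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return max; // every term is log(0)
        }
        double sum = 0;
        for (double v : a) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum);
    }
}
```

Summing `Math.exp` of the results over all 2^L label subsets should give 1, which is a quick sanity check on the weights and probabilities.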

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3