
Example 6 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

Class Recall, method recall.

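/**
 * From "Mining Multi-label Data" by Grigorios Tsoumakas.
 * @param multiLabels the ground-truth label set of each instance
 * @param predictions the predicted label set of each instance, aligned with multiLabels
 * @return the recall averaged over instances; an instance with no true labels counts as 1.0
 */
@Deprecated
public static double recall(MultiLabel[] multiLabels, MultiLabel[] predictions) {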
    double r = 0.0;
    for (int i = 0; i < multiLabels.length; i++) {
        MultiLabel label = multiLabels[i];
        MultiLabel prediction = predictions[i];
        if (label.getMatchedLabels().size() == 0) {
            r += 1.0;
        } else {
            r += MultiLabel.intersection(label, prediction).size() * 1.0 / label.getMatchedLabels().size();
        }
    }
    return r / multiLabels.length;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
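The recall method above averages, over instances, the fraction of each instance's true labels that were also predicted, and scores an instance with no true labels as 1.0. A minimal stand-alone sketch of the same computation follows. It is not pyramid code: it assumes label sets are modeled as `Set<Integer>` instead of `MultiLabel`, uses the hypothetical class name `RecallSketch`, and needs Java 9+ for `List.of`/`Set.of`.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical sketch, not pyramid's API: label sets are plain Set<Integer>
// instead of edu.neu.ccs.pyramid.dataset.MultiLabel.
class RecallSketch {

    // Per-instance recall = (number of true labels that were predicted) / (number of true labels),
    // averaged over instances. An empty truth set contributes 1.0, as in recall() above.
    static double recall(List<Set<Integer>> truths, List<Set<Integer>> predictions) {
        if (truths.size() != predictions.size()) {
            throw new IllegalArgumentException("truths and predictions differ in length");
        }
        double sum = 0.0;
        for (int i = 0; i < truths.size(); i++) {
            Set<Integer> truth = truths.get(i);
            if (truth.isEmpty()) {
                sum += 1.0;
                continue;
            }
            Set<Integer> overlap = new HashSet<>(truth);
            overlap.retainAll(predictions.get(i)); // keep only true labels that were predicted
            sum += (double) overlap.size() / truth.size();
        }
        return sum / truths.size();
    }

    public static void main(String[] args) {
        List<Set<Integer>> truths = List.of(Set.of(0, 1), Set.of(2), Set.of());
        List<Set<Integer>> predictions = List.of(Set.of(1), Set.of(2, 3), Set.of(4));
        // Per-instance recalls are 0.5, 1.0 and 1.0 (empty truth), so the mean is 2.5 / 3.
        System.out.println(recall(truths, predictions));
    }
}
```

The explicit length check is an addition: the original loop throws an ArrayIndexOutOfBoundsException only when predictions is shorter than multiLabels, and silently ignores any extra predictions.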

Example 7 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

Class MarginalPredictor, method predict.

@Override
public MultiLabel predict(Vector vector) {
    BMDistribution bmDistribution = new BMDistribution(cbm, vector, piThreshold);
    double[] probs = bmDistribution.marginals();
    MultiLabel prediction = new MultiLabel();
    for (int l = 0; l < cbm.getNumClasses(); l++) {
        if (probs[l] > 0.5) {
            prediction.addLabel(l);
        }
    }
    return prediction;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
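MarginalPredictor.predict keeps every label whose marginal probability under the BMDistribution exceeds 0.5. Because expected Hamming loss decomposes into one term per label, thresholding each marginal at 0.5 is the Bayes-optimal rule for that loss, which makes marginals a natural basis for prediction. The sketch below isolates just that step. It assumes the marginals are already available as a `double[]`, and the class `MarginalThresholdSketch` with its configurable threshold (in place of the hard-coded 0.5) is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: the marginals are assumed to be precomputed, e.g. what
// BMDistribution.marginals() returns in the predict method above.
class MarginalThresholdSketch {

    // Returns the indices of all labels whose marginal probability is strictly above threshold.
    static List<Integer> predict(double[] marginals, double threshold) {
        List<Integer> labels = new ArrayList<>();
        for (int l = 0; l < marginals.length; l++) {
            if (marginals[l] > threshold) { // strict ">" matches probs[l] > 0.5
                labels.add(l);
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        double[] marginals = {0.9, 0.5, 0.2, 0.51};
        System.out.println(predict(marginals, 0.5)); // [0, 3]
    }
}
```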

Example 8 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

Class SparkCBMOptimizer, method updateGamma.

private void updateGamma(int n) {
    Vector x = dataSet.getRow(n);
    MultiLabel y = dataSet.getMultiLabels()[n];
    double[] posterior = cbm.posteriorMembership(x, y);
    for (int k = 0; k < cbm.numComponents; k++) {
        gammas[n][k] = posterior[k];
        gammasT[k][n] = posterior[k];
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector)
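updateGamma appears to be the E-step of EM for the conditional Bernoulli mixture: it asks the model for the posterior membership of data point n in each component, given the point's features x and labels y, and stores those responsibilities twice, row-wise in gammas[n][k] and column-wise in gammasT[k][n], presumably so that per-component code can read every point's weight for component k from a single row. The internals of posteriorMembership are not shown here. The sketch below assumes the usual mixture rule, posterior_k proportional to prior_k(x) times p(y | x, k), takes those two quantities as hypothetical log-space inputs, and normalizes them with the log-sum-exp trick; the class `GammaUpdateSketch` is hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch of a mixture-model E-step for one data point; not pyramid's CBM code.
class GammaUpdateSketch {

    // posterior[k] = exp(logPrior[k] + logLikelihood[k]) / sum_j exp(logPrior[j] + logLikelihood[j]),
    // computed stably by subtracting the maximum log-joint before exponentiating.
    static double[] posterior(double[] logPrior, double[] logLikelihood) {
        int numComponents = logPrior.length;
        double[] logJoint = new double[numComponents];
        double max = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < numComponents; k++) {
            logJoint[k] = logPrior[k] + logLikelihood[k];
            max = Math.max(max, logJoint[k]);
        }
        double sum = 0.0;
        for (int k = 0; k < numComponents; k++) {
            sum += Math.exp(logJoint[k] - max);
        }
        double[] post = new double[numComponents];
        for (int k = 0; k < numComponents; k++) {
            post[k] = Math.exp(logJoint[k] - max) / sum;
        }
        return post;
    }

    // Mirrors updateGamma: write the posterior of point n into both the table and its transpose.
    static void updateGamma(int n, double[] post, double[][] gammas, double[][] gammasT) {
        for (int k = 0; k < post.length; k++) {
            gammas[n][k] = post[k];
            gammasT[k][n] = post[k];
        }
    }

    public static void main(String[] args) {
        double[][] gammas = new double[2][3];   // [numDataPoints][numComponents]
        double[][] gammasT = new double[3][2];  // [numComponents][numDataPoints]
        double[] post = posterior(
                new double[]{Math.log(0.5), Math.log(0.25), Math.log(0.25)},
                new double[]{Math.log(0.1), Math.log(0.2), Math.log(0.4)});
        updateGamma(1, post, gammas, gammasT);
        // Joint weights 0.05, 0.05, 0.1 normalize to roughly 0.25, 0.25, 0.5.
        System.out.println(Arrays.toString(gammas[1]));
    }
}
```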

Example 9 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

Class BlockwiseCD, method calEmpiricalCountForLabelPair.

private double calEmpiricalCountForLabelPair(int parameterIndex) {
    double empiricalCount = 0.0;
    int start = parameterIndex - numWeightsForFeatures;
    int l1 = parameterToL1[start];
    int l2 = parameterToL2[start];
    int featureCase = start % 4;
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        MultiLabel label = dataSet.getMultiLabels()[i];
        switch(featureCase) {
            // both l1, l2 equal 0;
            case 0:
                if (!label.matchClass(l1) && !label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 1; l2 = 0;
            case 1:
                if (label.matchClass(l1) && !label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 0; l2 = 1;
            case 2:
                if (!label.matchClass(l1) && label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 1; l2 = 1;
            case 3:
                if (label.matchClass(l1) && label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            default:
                throw new RuntimeException("feature case :" + featureCase + " failed.");
        }
    }
    return empiricalCount;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
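calEmpiricalCountForLabelPair counts the training instances that match one of four on/off patterns for labels l1 and l2, selected by start % 4: case 0 means both absent, 1 means only l1 present, 2 means only l2 present, and 3 means both present. Since the method rescans the whole data set for each parameter, one alternative (not the author's code) is to compute all four counts for a label pair in a single pass, indexing the result with the same encoding, (l1 present ? 1 : 0) + (l2 present ? 2 : 0). The sketch assumes label sets are `Set<Integer>`; the class `LabelPairCountSketch` is hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

// Hypothetical sketch: label sets are Set<Integer> instead of pyramid's MultiLabel.
class LabelPairCountSketch {

    // counts[c] = number of instances with on/off pattern c for (l1, l2), where
    // c = (l1 present ? 1 : 0) + (l2 present ? 2 : 0):
    // 0 = both absent, 1 = only l1, 2 = only l2, 3 = both present (same cases as the switch above).
    static double[] pairCounts(List<Set<Integer>> labels, int l1, int l2) {
        double[] counts = new double[4];
        for (Set<Integer> y : labels) {
            int c = (y.contains(l1) ? 1 : 0) + (y.contains(l2) ? 2 : 0);
            counts[c] += 1.0;
        }
        return counts;
    }

    public static void main(String[] args) {
        List<Set<Integer>> labels =
                List.of(Set.of(0, 1), Set.of(0), Set.of(1), Set.of(), Set.of(0, 1, 2));
        System.out.println(Arrays.toString(pairCounts(labels, 0, 1))); // [1.0, 1.0, 1.0, 2.0]
    }
}
```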

Example 10 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

Class CMLCRF, method predictByArgmax.

public MultiLabel predictByArgmax(Vector vector) {
    double[] scores = predictClassScores(vector);
    MultiLabel label = new MultiLabel();
    for (int l = 0; l < scores.length; l++) {
        if (scores[l] > 0) {
            label.addLabel(l);
        }
    }
    return label;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
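predictByArgmax turns on every label whose score from predictClassScores is positive. If, as an assumption not confirmed by the code shown, each score is the log-odds that the label is present, then score > 0 is the same decision as sigmoid(score) > 0.5, that is, an independent argmax per label. The hypothetical `ScoreArgmaxSketch` below shows both forms side by side.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: assumes each score is the log-odds that the corresponding label is present.
class ScoreArgmaxSketch {

    // Same rule as predictByArgmax: keep label l when its score is positive.
    static List<Integer> byPositiveScore(double[] scores) {
        List<Integer> labels = new ArrayList<>();
        for (int l = 0; l < scores.length; l++) {
            if (scores[l] > 0) {
                labels.add(l);
            }
        }
        return labels;
    }

    // Probability-scale version: keep label l when sigmoid(score) > 0.5.
    // For tiny positive scores the sigmoid rounds to exactly 0.5, so this can drop a label
    // that byPositiveScore keeps.
    static List<Integer> bySigmoid(double[] scores) {
        List<Integer> labels = new ArrayList<>();
        for (int l = 0; l < scores.length; l++) {
            double p = 1.0 / (1.0 + Math.exp(-scores[l]));
            if (p > 0.5) {
                labels.add(l);
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        double[] scores = {2.0, -1.0, 0.0, 0.3};
        System.out.println(byPositiveScore(scores)); // [0, 3]
        System.out.println(bySigmoid(scores));       // [0, 3]
    }
}
```

Comparing the raw score with 0, as predictByArgmax does, avoids that rounding issue: with a score of 1e-20, 1 / (1 + exp(-s)) evaluates to exactly 0.5 in double precision.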

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3