Search in sources :

Example 76 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class FMeasure method f1.

/**
 * Example-based (instance-averaged) F1 measure.
 * From "Mining Multi-label Data by Grigorios Tsoumakas".
 *
 * For each instance the score is 2 * |label ∩ prediction| / (|label| + |prediction|);
 * the returned value is the mean of these scores over all instances.
 * An instance whose true label set and predicted label set are both empty is scored 1.0
 * (the prediction agrees perfectly with the truth) instead of the NaN that 0/0 would yield.
 * For an empty input array the result is NaN (0.0 / 0).
 *
 * This method is annotated as deprecated; its replacement is not visible from this snippet.
 *
 * @param multiLabels ground-truth label sets, one per instance
 * @param predictions predicted label sets, aligned by index with {@code multiLabels}
 * @return the mean per-instance F1 score
 */
@Deprecated
public static double f1(MultiLabel[] multiLabels, MultiLabel[] predictions) {
    double f = 0.0;
    for (int i = 0; i < multiLabels.length; i++) {
        MultiLabel label = multiLabels[i];
        MultiLabel prediction = predictions[i];
        int denominator = label.getMatchedLabels().size() + prediction.getMatchedLabels().size();
        if (denominator == 0) {
            // both sets are empty: avoid 0/0 = NaN and count the instance as a perfect match
            f += 1.0;
        } else {
            f += MultiLabel.intersection(label, prediction).size() * 2.0 / denominator;
        }
    }
    return f / multiLabels.length;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)

Example 77 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class KLDivergence method kl_conditional.

/**
 * Empirical conditional KL divergence between the observed distribution of label sets and the
 * probabilities estimated by the given classifier; also prints diagnostics to standard output.
 *
 * For every data point, z is the MultiLabel built from the data row and y is the row's label set.
 * The returned value is sum_z q(z) * sum_y q(y|z) * (log q(y|z) - log p(y|z)), where q are
 * empirical frequencies counted from the data set and log p(y|z) comes from
 * {@code predictLogAssignmentProb}.
 *
 * NOTE(review): the diagnostic sections cast {@code multiLabelClassifier} to CBM and its binary
 * classifiers to LogisticRegression unconditionally, so any other implementation will fail with a
 * ClassCastException before the result is returned.
 *
 * @param multiLabelClassifier estimator queried for (log) assignment probabilities; see the cast note above
 * @param dataSet data set supplying the rows (used as z), the label sets (used as y) and the label translator
 * @return the empirical conditional KL divergence
 */
public static double kl_conditional(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier, MultiLabelClfDataSet dataSet) {
    // q_z: how many data points share each distinct z
    Map<MultiLabel, Integer> q_z = new HashMap<MultiLabel, Integer>();
    // q_yz: for each z, how many times each label set y was observed together with it
    Map<MultiLabel, HashMap<MultiLabel, Integer>> q_yz = new HashMap<MultiLabel, HashMap<MultiLabel, Integer>>();
    // get overall empirical distribution
    for (int i = 0; i < dataSet.getNumDataPoints(); ++i) {
        // the data row itself is wrapped as a MultiLabel and used as the conditioning variable z
        MultiLabel z = new MultiLabel(dataSet.getRow(i));
        MultiLabel y = dataSet.getMultiLabels()[i];
        if (q_z.containsKey(z)) {
            q_z.put(z, q_z.get(z) + 1);
        } else {
            q_z.put(z, 1);
        }
        if (!q_yz.containsKey(z)) {
            q_yz.put(z, new HashMap<MultiLabel, Integer>());
        }
        if (q_yz.get(z).containsKey(y)) {
            q_yz.get(z).put(y, q_yz.get(z).get(y) + 1);
        } else {
            q_yz.get(z).put(y, 1);
        }
    }
    // compute kl divergence
    double kl = 0.0;
    for (Map.Entry<MultiLabel, Integer> e1 : q_z.entrySet()) {
        // KL contribution of the conditional distribution over y for this particular z
        double kl_y = 0.0;
        for (Map.Entry<MultiLabel, Integer> e2 : q_yz.get(e1.getKey()).entrySet()) {
            // q(y|z) = count(y, z) / count(z)
            double empirical_prob_yz = (double) e2.getValue() / (double) e1.getValue();
            double log_estimated_prob_yz = multiLabelClassifier.predictLogAssignmentProb(e1.getKey().toVector(dataSet.getNumFeatures()), e2.getKey());
            kl_y += empirical_prob_yz * (Math.log(empirical_prob_yz) - log_estimated_prob_yz);
        }
        // weight by q(z) = count(z) / number of data points
        double empirical_prob_z = (double) e1.getValue() / (double) dataSet.getNumDataPoints();
        kl += empirical_prob_z * kl_y;
    }
    // Printing information if needed
    // only z values observed at least occur_threshold times are reported
    int occur_threshold = 10;
    // only marginals strictly above marginal_threshold are printed
    double marginal_threshold = 0.01;
    for (Map.Entry<MultiLabel, Integer> e1 : q_z.entrySet()) {
        // per-index counts of y containing that index, accumulated over all y seen with this z
        // NOTE(review): sized and indexed by getNumFeatures() while tested with matchClass(i);
        // presumably the label space and the feature space coincide in this setup - confirm
        double[] marginals1 = new double[dataSet.getNumFeatures()];
        for (Map.Entry<MultiLabel, Integer> e2 : q_yz.get(e1.getKey()).entrySet()) {
            double estimated_prob_yz = multiLabelClassifier.predictAssignmentProb(e1.getKey().toVector(dataSet.getNumFeatures()), e2.getKey());
            double empirical_prob_yz = (double) e2.getValue() / (double) e1.getValue();
            if (e1.getValue() >= occur_threshold) {
                System.out.println("#z:" + e1.getValue() + ",z=" + e1.getKey().toStringWithExtLabels(dataSet.getLabelTranslator()) + "->{" + e2.getKey().toStringWithExtLabels(dataSet.getLabelTranslator()) + "},#y:" + e2.getValue() + ",p_y|z_empirical:" + empirical_prob_yz + ",p_y|z_estimated:" + estimated_prob_yz);
            }
            for (int i = 0; i < dataSet.getNumFeatures(); i++) {
                if (e2.getKey().matchClass(i)) {
                    marginals1[i] += e2.getValue();
                }
            }
        }
        if (e1.getValue() >= occur_threshold) {
            // estimated probability of predicting y equal to z itself
            double estimated_prob_zz = multiLabelClassifier.predictAssignmentProb(e1.getKey().toVector(dataSet.getNumFeatures()), e1.getKey());
            System.out.println("p(y=z|z)=" + estimated_prob_zz);
            // unchecked cast: throws ClassCastException when the classifier is not a CBM
            CBM cbm = (CBM) multiLabelClassifier;
            //                List<MultiLabel> sampled = cbm.samples(e1.getKey().toVector(dataSet.getNumFeatures()), 10);
            //                for (int i = 0; i < sampled.size(); ++i) {
            //                    double prob = multiLabelClassifier.predictAssignmentProb(e1.getKey().toVector(dataSet.getNumFeatures()), sampled.get(i));
            //                    System.out.println(sampled.get(i).toStringWithExtLabels(dataSet.getLabelTranslator()) + ":" + prob);
            //                }
            System.out.println("p_y|z_estimated marginals are: ");
            double[] marginals = cbm.predictClassProbs(e1.getKey().toVector(dataSet.getNumFeatures()));
            // print estimated marginals from largest to smallest
            int[] order = ArgSort.argSortDescending(marginals);
            for (int i = 0; i < order.length; ++i) {
                if (marginals[order[i]] > marginal_threshold) {
                    System.out.println(dataSet.getLabelTranslator().toExtLabel(order[i]) + ":" + marginals[order[i]]);
                }
            }
            System.out.println("p_y|z_empirical marginals are: ");
            // turn the accumulated counts into frequencies by dividing by count(z)
            for (int i = 0; i < dataSet.getNumFeatures(); i++) {
                marginals1[i] /= (double) e1.getValue();
            }
            int[] order1 = ArgSort.argSortDescending(marginals1);
            for (int i = 0; i < order1.length; ++i) {
                if (marginals1[order1[i]] > marginal_threshold) {
                    System.out.println(dataSet.getLabelTranslator().toExtLabel(order1[i]) + ":" + marginals1[order1[i]]);
                }
            }
        }
    }
    System.out.println("LRs for each label:");
    // unchecked cast, executed on every call: a non-CBM classifier fails here with ClassCastException
    CBM cbm = (CBM) multiLabelClassifier;
    // only the first row of binary classifiers is inspected
    // NOTE(review): presumably the classifiers of the first mixture component - confirm
    Classifier.ProbabilityEstimator[] estimators = cbm.getBinaryClassifiers()[0];
    for (int i = 0; i < estimators.length; i++) {
        System.out.println("LR:" + dataSet.getLabelTranslator().toExtLabel(i));
        LogisticRegression lr = (LogisticRegression) estimators[i];
        Vector weight_vec = lr.getWeights().getWeightsWithoutBiasForClass(1);
        // copy the weight vector into a plain array so it can be arg-sorted
        double[] weights = new double[weight_vec.size()];
        for (int j = 0; j < weight_vec.size(); j++) {
            weights[j] = weight_vec.get(j);
        }
        System.out.println("bias:" + lr.getWeights().getBiasForClass(1));
        int[] order2 = ArgSort.argSortDescending(weights);
        for (int j = 0; j < order2.length; ++j) {
            // NOTE(review): weight (feature) indices are named through the label translator;
            // presumably feature j corresponds to label j here - confirm
            System.out.println(dataSet.getLabelTranslator().toExtLabel(order2[j]) + ":" + weights[order2[j]]);
        }
    }
    System.out.println("---");
    return kl;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) HashMap(java.util.HashMap) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) HashMap(java.util.HashMap) Map(java.util.Map) Vector(org.apache.mahout.math.Vector)

Example 78 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class DynamicProgramming method flipLabels.

/**
 * Expands a candidate by toggling, one at a time, every label listed in
 * {@code uncertainLabels}. Each toggled copy gets its log probability derived
 * from the candidate's own log probability; copies not yet present in the
 * cache are added to the queue and then recorded in the cache.
 * @param data the candidate whose multi-label and log probability seed the expansion
 */
private void flipLabels(Candidate data) {
    final double baseLogProb = data.logProbability;
    final MultiLabel source = data.multiLabel;
    // labels outside uncertainLabels are never toggled
    for (int labelIndex : uncertainLabels) {
        MultiLabel neighbor = source.copy();
        neighbor.flipLabel(labelIndex);
        // swap this label's term: index 1 is used when the label is matched after the flip,
        // index 0 when it is not
        double neighborLogProb = neighbor.matchClass(labelIndex)
                ? baseLogProb - this.logProbs[labelIndex][0] + this.logProbs[labelIndex][1]
                : baseLogProb - this.logProbs[labelIndex][1] + this.logProbs[labelIndex][0];
        if (cache.contains(neighbor)) {
            // already generated earlier; do not enqueue it twice
            continue;
        }
        queue.add(new Candidate(neighbor, neighborLogProb));
        cache.add(neighbor);
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)

Example 79 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class DynamicProgramming method nextHighest.

/**
 * Returns the next candidate from the queue. While the queue is non-empty, the
 * candidate at its head is first expanded via {@code flipLabels} and then the
 * head of the queue is removed and returned. When the queue is empty, a
 * placeholder candidate with an empty multi-label and a log probability of
 * negative infinity is returned instead.
 * @return the polled candidate, or the placeholder when nothing is queued
 */
public Candidate nextHighest() {
    if (queue.size() <= 0) {
        // nothing left to explore
        return new Candidate(new MultiLabel(), Double.NEGATIVE_INFINITY);
    }
    // enqueue the head's neighbors before removing it
    flipLabels(queue.peek());
    return queue.poll();
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)

Example 80 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class Accuracy method partialAccuracy.

/**
 * Instance-averaged overlap between true and predicted label sets:
 * for each instance, |label ∩ prediction| / |label ∪ prediction|, averaged over all instances.
 * An instance whose true label set and predicted label set are both empty (empty union) is
 * scored 1.0, since the prediction agrees perfectly with the truth, instead of the NaN that
 * 0/0 would yield. For an empty input array the result is NaN (0.0 / 0).
 * @param multiLabels ground-truth label sets, one per instance
 * @param predictions predicted label sets, aligned by index with multiLabels
 * @return the mean per-instance overlap ratio
 */
public static double partialAccuracy(MultiLabel[] multiLabels, MultiLabel[] predictions) {
    double a = 0.0;
    for (int i = 0; i < multiLabels.length; i++) {
        MultiLabel label = multiLabels[i];
        MultiLabel prediction = predictions[i];
        int unionSize = MultiLabel.union(label, prediction).size();
        if (unionSize == 0) {
            // both sets are empty: avoid 0/0 = NaN and count the instance as a perfect match
            a += 1.0;
        } else {
            a += MultiLabel.intersection(label, prediction).size() * 1.0 / unionSize;
        }
    }
    return a / multiLabels.length;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)101 Vector (org.apache.mahout.math.Vector)22 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)21 File (java.io.File)14 DenseVector (org.apache.mahout.math.DenseVector)13 CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF)12 Pair (edu.neu.ccs.pyramid.util.Pair)8 LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS)7 ArrayList (java.util.ArrayList)7 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)6 CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss)6 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)5 GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)5 Collectors (java.util.stream.Collectors)5 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)4 java.util (java.util)4 StopWatch (org.apache.commons.lang3.time.StopWatch)4 Config (edu.neu.ccs.pyramid.configuration.Config)3 DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil)3 TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat)3