Example 51 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class MLLogisticRegression method predictClassProbs.

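/**
 * Expensive operation: sums the probability of every candidate label
 * assignment into each class that the assignment contains.
 * @param vector feature vector of the instance to score
 * @return marginal probability of each class, indexed by class
 */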
public double[] predictClassProbs(Vector vector) {
    double[] assignmentProbs = predictAssignmentProbs(vector);
    double[] classProbs = new double[numClasses];
    for (int a = 0; a < assignments.size(); a++) {
        MultiLabel assignment = assignments.get(a);
        double prob = assignmentProbs[a];
        for (Integer label : assignment.getMatchedLabels()) {
            double oldProb = classProbs[label];
            classProbs[label] = oldProb + prob;
        }
    }
    return classProbs;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)

Example 52 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class MLLogisticRegression method calClassProbs.

double[] calClassProbs(double[] assignmentProbs) {
    double[] classProbs = new double[numClasses];
    int numAssignments = assignments.size();
    for (int a = 0; a < numAssignments; a++) {
        MultiLabel assignment = assignments.get(a);
        double prob = assignmentProbs[a];
        for (Integer label : assignment.getMatchedLabels()) {
            classProbs[label] += prob;
        }
    }
    return classProbs;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
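
Examples 51 and 52 perform the same marginalization: the probability of each candidate label assignment is added to every class that the assignment contains. A minimal standalone sketch of that step follows. It assumes each assignment is a plain int[] of label indices rather than a pyramid MultiLabel; the names ClassProbsSketch and marginalize are hypothetical and not part of pyramid.

```java
import java.util.Arrays;

final class ClassProbsSketch {
    // Adds each assignment's probability to every class listed in that assignment.
    // assignments[a] holds the label indices of assignment a (assumed representation).
    static double[] marginalize(int[][] assignments, double[] assignmentProbs, int numClasses) {
        if (assignments.length != assignmentProbs.length) {
            throw new IllegalArgumentException("one probability per assignment is required");
        }
        double[] classProbs = new double[numClasses];
        for (int a = 0; a < assignments.length; a++) {
            for (int label : assignments[a]) {
                classProbs[label] += assignmentProbs[a];
            }
        }
        return classProbs;
    }

    public static void main(String[] args) {
        // Four candidate assignments over two classes: {}, {0}, {1}, {0, 1}.
        int[][] assignments = {{}, {0}, {1}, {0, 1}};
        double[] probs = {0.1, 0.2, 0.3, 0.4};
        // Class 0 receives 0.2 + 0.4 and class 1 receives 0.3 + 0.4, up to rounding.
        System.out.println(Arrays.toString(marginalize(assignments, probs, 2)));
    }
}
```

The cost grows with the number of candidate assignments times their average size, which is consistent with the "expensive operation" note in Example 51.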

Example 53 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class NoiseOptimizer method updateAlpha.

private void updateAlpha(int classIndex) {
    double numerator = 0;
    double denominator = 0;
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    for (int c = 0; c < combinations.size(); c++) {
        if (combinations.get(c).matchClass(classIndex)) {
            for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
                MultiLabel multiLabel = multiLabels[i];
                denominator += targets[i][c];
                if (multiLabel.matchClass(classIndex)) {
                    numerator += targets[i][c];
                }
            }
        }
    }
    alphas[classIndex] = SafeDivide.divide(numerator, denominator, 1);
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
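
The update in Example 53 is a ratio of two weighted sums: over all combinations that contain the class, the denominator adds every target weight and the numerator adds only the weights of data points whose observed label also contains the class. The sketch below restates that with plain arrays. It assumes that SafeDivide.divide returns its third argument when the denominator is zero, which the listing does not confirm; AlphaSketch, divideOrDefault, and the boolean membership matrices are hypothetical stand-ins.

```java
final class AlphaSketch {
    // Assumed behavior of a "safe" division: fall back when the denominator is zero.
    static double divideOrDefault(double numerator, double denominator, double fallback) {
        return denominator == 0 ? fallback : numerator / denominator;
    }

    // combinationHasClass[c][k]: does combination c contain class k?
    // labelHasClass[i][k]: does the observed label of data point i contain class k?
    // targets[i][c]: weight of combination c for data point i.
    static double alpha(int classIndex, boolean[][] combinationHasClass,
                        boolean[][] labelHasClass, double[][] targets) {
        double numerator = 0;
        double denominator = 0;
        for (int c = 0; c < combinationHasClass.length; c++) {
            if (!combinationHasClass[c][classIndex]) {
                continue; // only combinations containing the class contribute
            }
            for (int i = 0; i < targets.length; i++) {
                denominator += targets[i][c];
                if (labelHasClass[i][classIndex]) {
                    numerator += targets[i][c];
                }
            }
        }
        return divideOrDefault(numerator, denominator, 1);
    }

    public static void main(String[] args) {
        boolean[][] combinations = {{true}, {false}}; // combination 0 = {0}, combination 1 = {}
        boolean[][] labels = {{true}, {false}};       // only data point 0 is labeled with class 0
        double[][] targets = {{0.5, 0.5}, {0.25, 0.75}};
        System.out.println(alpha(0, combinations, labels, targets));
    }
}
```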

Example 54 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class HammingLossTest method main.

/**
     * Test example comes from:
     * http://scikit-learn.org/stable/modules/model_evaluation.html#hamming-loss
     * @param args
     */
public static void main(String[] args) {
    MultiLabel[] labels = new MultiLabel[1];
    MultiLabel[] predictions = new MultiLabel[1];
    labels[0] = new MultiLabel();
    predictions[0] = new MultiLabel();
    labels[0].addLabel(1);
    labels[0].addLabel(2);
    labels[0].addLabel(3);
    labels[0].addLabel(4);
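    // The repeated addLabel(2) mirrors the [2, 2, 3, 4] vector in the scikit-learn example.
    // Assuming MultiLabel stores a set of labels, the second call has no effect.
    predictions[0].addLabel(2);
    predictions[0].addLabel(2);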
    predictions[0].addLabel(3);
    predictions[0].addLabel(4);
    System.out.println("Expected (value=0.25) - Output: " + HammingLoss.hammingLoss(labels, predictions, 4));
    MultiLabel[] labels1 = new MultiLabel[2];
    MultiLabel[] predictions1 = new MultiLabel[2];
    labels1[0] = new MultiLabel();
    labels1[1] = new MultiLabel();
    predictions1[0] = new MultiLabel();
    predictions1[1] = new MultiLabel();
    labels1[0].addLabel(0);
    labels1[0].addLabel(1);
    labels1[1].addLabel(1);
    predictions1[0].addLabel(0);
    predictions1[1].addLabel(0);
    System.out.println("Expected (value=0.75) - Output: " + HammingLoss.hammingLoss(labels1, predictions1, 2));
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
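
The two expected values in Example 54 can be checked by hand: Hamming loss counts the labels on which truth and prediction disagree and divides by the number of data points times the number of labels. The sketch below is a standalone version built on java.util.BitSet. It is not pyramid's HammingLoss implementation, and the names HammingLossSketch and labels are hypothetical.

```java
import java.util.BitSet;

final class HammingLossSketch {
    // Builds a label set from the given label indices.
    static BitSet labels(int... indices) {
        BitSet set = new BitSet();
        for (int index : indices) {
            set.set(index);
        }
        return set;
    }

    // Fraction of (data point, label) pairs on which truth and prediction disagree.
    static double hammingLoss(BitSet[] truth, BitSet[] predicted, int numLabels) {
        if (truth.length != predicted.length) {
            throw new IllegalArgumentException("truth and predictions must have equal length");
        }
        long mismatches = 0;
        for (int i = 0; i < truth.length; i++) {
            BitSet difference = (BitSet) truth[i].clone();
            difference.xor(predicted[i]); // symmetric difference of the two label sets
            mismatches += difference.cardinality();
        }
        return (double) mismatches / ((double) truth.length * numLabels);
    }

    public static void main(String[] args) {
        // {1,2,3,4} vs {2,3,4}: one mismatch out of 1 * 4 pairs.
        System.out.println(hammingLoss(
                new BitSet[]{labels(1, 2, 3, 4)}, new BitSet[]{labels(2, 3, 4)}, 4)); // prints 0.25
        // {0,1},{1} vs {0},{0}: three mismatches out of 2 * 2 pairs.
        System.out.println(hammingLoss(
                new BitSet[]{labels(0, 1), labels(1)}, new BitSet[]{labels(0), labels(0)}, 2)); // prints 0.75
    }
}
```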

Example 55 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class MultiLabelSynthesizer method flipOne.

public static MultiLabelClfDataSet flipOne(int numData, int numFeature, int numClass) {
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(numFeature).numClasses(numClass).numDataPoints(numData).build();
    // generate weights
    Vector[] weights = new Vector[numClass];
    for (int k = 0; k < numClass; k++) {
        Vector vector = new DenseVector(numFeature);
        for (int j = 0; j < numFeature; j++) {
            vector.set(j, Sampling.doubleUniform(-1, 1));
        }
        weights[k] = vector;
    }
    // generate features
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    // assign labels
    for (int i = 0; i < numData; i++) {
        for (int k = 0; k < numClass; k++) {
            double dot = weights[k].dot(dataSet.getRow(i));
            if (dot >= 0) {
                dataSet.addLabel(i, k);
            }
        }
    }
    // flip
    for (int i = 0; i < numData; i++) {
        int toChange = Sampling.intUniform(0, numClass - 1);
        MultiLabel label = dataSet.getMultiLabels()[i];
        if (label.matchClass(toChange)) {
            label.removeLabel(toChange);
        } else {
            label.addLabel(toChange);
        }
    }
    return dataSet;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
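
The final "flip" loop in Example 55 injects label noise by toggling exactly one randomly chosen class per data point. A minimal sketch of that step alone follows, using java.util.BitSet and java.util.Random in place of pyramid's MultiLabel and Sampling. FlipOneSketch is a hypothetical name, and unlike the original, which edits the data set's labels in place, this version returns modified copies.

```java
import java.util.BitSet;
import java.util.Random;

final class FlipOneSketch {
    // Returns copies of the label sets in which one uniformly chosen class
    // in [0, numClasses) has been toggled: removed if present, added if absent.
    static BitSet[] flipOne(BitSet[] labels, int numClasses, Random random) {
        BitSet[] noisy = new BitSet[labels.length];
        for (int i = 0; i < labels.length; i++) {
            BitSet copy = (BitSet) labels[i].clone();
            copy.flip(random.nextInt(numClasses));
            noisy[i] = copy;
        }
        return noisy;
    }

    public static void main(String[] args) {
        BitSet first = new BitSet();
        BitSet second = new BitSet();
        second.set(0);
        second.set(2);
        BitSet[] noisy = flipOne(new BitSet[]{first, second}, 3, new Random(42));
        for (BitSet labelSet : noisy) {
            System.out.println(labelSet);
        }
    }
}
```

Each noisy label set differs from its clean counterpart in exactly one class, so the Hamming distance between the clean and noisy labels is 1 per data point.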

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel) 101
Vector (org.apache.mahout.math.Vector) 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) 21
File (java.io.File) 14
DenseVector (org.apache.mahout.math.DenseVector) 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) 12
Pair (edu.neu.ccs.pyramid.util.Pair) 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS) 7
ArrayList (java.util.ArrayList) 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures) 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor) 5
Collectors (java.util.stream.Collectors) 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper) 4
java.util (java.util) 4
StopWatch (org.apache.commons.lang3.time.StopWatch) 4
Config (edu.neu.ccs.pyramid.configuration.Config) 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil) 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat) 3