
Example 66 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBM method predictLogAssignmentProbs.

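/**
 * Computes the log probability of each candidate assignment for a single instance.
 * For batch jobs, use this to save computation: the mixture for x is built once
 * and reused for every assignment.
 * @param x feature vector of the instance
 * @param assignments candidate label combinations to score
 * @return log probabilities, in the same order as assignments
 */
public double[] predictLogAssignmentProbs(Vector x, List<MultiLabel> assignments) {
    BMDistribution bmDistribution = computeBM(x);
    // The values stored here are log probabilities, not probabilities.
    double[] logProbs = new double[assignments.size()];
    for (int c = 0; c < assignments.size(); c++) {
        logProbs[c] = bmDistribution.logProbability(assignments.get(c));
    }
    return logProbs;
}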
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
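
The body of BMDistribution.logProbability is not shown above. As a rough sketch of what scoring one assignment under a Bernoulli mixture could look like, the following assumes the array layout visible in Example 68 (logProportions[k], and logClassProbs[k][l][1] for the log probability that label l is on in component k) and further assumes that index 0 holds the log probability that the label is off. The class name, the use of Set<Integer> in place of MultiLabel, and the logSumExp helper are hypothetical stand-ins, not the project's code.

```java
import java.util.Set;

final class MixtureLogProbSketch {

    // Numerically stable log(sum_k exp(a[k])).
    static double logSumExp(double[] a) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : a) {
            max = Math.max(max, v);
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return max; // every term has zero probability
        }
        double sum = 0.0;
        for (double v : a) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum);
    }

    // Assumed layout: logClassProbs[k][l][0] = log P(y_l = 0 | k),
    //                 logClassProbs[k][l][1] = log P(y_l = 1 | k).
    static double logProbability(double[] logProportions,
                                 double[][][] logClassProbs,
                                 Set<Integer> labels) {
        double[] terms = new double[logProportions.length];
        for (int k = 0; k < logProportions.length; k++) {
            double sum = logProportions[k];
            for (int l = 0; l < logClassProbs[k].length; l++) {
                sum += logClassProbs[k][l][labels.contains(l) ? 1 : 0];
            }
            terms[k] = sum; // log of (proportion * product of per-label Bernoullis)
        }
        return logSumExp(terms);
    }

    public static void main(String[] args) {
        // One component, two labels, each on with probability 0.5.
        double h = Math.log(0.5);
        double[][][] logClassProbs = {{{h, h}, {h, h}}};
        System.out.println(logProbability(new double[]{0.0}, logClassProbs, Set.of(0)));
    }
}
```

Working in log space and combining components with a max-shifted log-sum-exp avoids underflow when many labels are multiplied together, which is presumably why the method above returns log probabilities.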

Example 67 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CBM method predictByMarginals.

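/**
 * Predicts sign(E(y|x)): a label is included when its marginal probability exceeds 0.5.
 * @param vector feature vector of the instance
 * @return the predicted label combination
 */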
MultiLabel predictByMarginals(Vector vector) {
    double[] probs = predictClassProbs(vector);
    MultiLabel prediction = new MultiLabel();
    for (int l = 0; l < numLabels; l++) {
        if (probs[l] > 0.5) {
            prediction.addLabel(l);
        }
    }
    return prediction;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
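
To make the thresholding rule concrete without the pyramid classes, here is a minimal stand-alone sketch. It takes the marginals as a plain array (in the method above they come from predictClassProbs) and uses a SortedSet<Integer> as a hypothetical stand-in for MultiLabel.

```java
import java.util.SortedSet;
import java.util.TreeSet;

final class MarginalThresholdSketch {

    // Keeps every label whose marginal probability is strictly above 0.5.
    static SortedSet<Integer> predictByMarginals(double[] marginals) {
        SortedSet<Integer> prediction = new TreeSet<>();
        for (int l = 0; l < marginals.length; l++) {
            if (marginals[l] > 0.5) {
                prediction.add(l);
            }
        }
        return prediction;
    }

    public static void main(String[] args) {
        // Label 1 sits exactly at 0.5 and is therefore left out.
        System.out.println(predictByMarginals(new double[]{0.9, 0.5, 0.2, 0.51}));
    }
}
```

Because each label is decided independently, this rule targets per-label error (Hamming loss) rather than the probability of getting the whole combination right; the two predictions can differ.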

Example 68 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class BMDistribution method sample.

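/**
 * Draws samples from the Bernoulli mixture: each draw first picks a component
 * according to the mixture proportions, then samples every label independently
 * from that component's Bernoulli distributions.
 * @param numSamples number of label combinations to draw
 * @return the sampled label combinations
 */
List<MultiLabel> sample(int numSamples) {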
    List<MultiLabel> list = new ArrayList<>();
    double[] proportions = Arrays.stream(logProportions).map(Math::exp).toArray();
    double[][] classProbs = new double[numComponents][numLabels];
    for (int k = 0; k < numComponents; k++) {
        for (int l = 0; l < numLabels; l++) {
            classProbs[k][l] = Math.exp(logClassProbs[k][l][1]);
        }
    }
    int[] components = IntStream.range(0, numComponents).toArray();
    EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(components, proportions);
    BernoulliDistribution[][] bernoulliDistributions = new BernoulliDistribution[numComponents][numLabels];
    for (int k = 0; k < numComponents; k++) {
        for (int l = 0; l < numLabels; l++) {
            bernoulliDistributions[k][l] = new BernoulliDistribution(classProbs[k][l]);
        }
    }
    for (int num = 0; num < numSamples; num++) {
        MultiLabel multiLabel = new MultiLabel();
        int k = enumeratedIntegerDistribution.sample();
        for (int l = 0; l < numLabels; l++) {
            int v = bernoulliDistributions[k][l].sample();
            if (v == 1) {
                multiLabel.addLabel(l);
            }
        }
        list.add(multiLabel);
    }
    return list;
}
Also used : EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) BernoulliDistribution(edu.neu.ccs.pyramid.util.BernoulliDistribution) ArrayList(java.util.ArrayList)
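
The same two-stage sampling can be sketched with only java.util.Random, which may help when reading the method above without commons-math at hand. This is an illustration under assumptions: the inverse-CDF component draw replaces EnumeratedIntegerDistribution, a uniform comparison replaces BernoulliDistribution, and SortedSet<Integer> stands in for MultiLabel. The inputs are plain probabilities, not the log values stored in the original class.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.SortedSet;
import java.util.TreeSet;

final class MixtureSamplingSketch {

    // proportions[k] = P(component k); classProbs[k][l] = P(y_l = 1 | component k).
    static List<SortedSet<Integer>> sample(double[] proportions,
                                           double[][] classProbs,
                                           int numSamples,
                                           Random random) {
        List<SortedSet<Integer>> samples = new ArrayList<>();
        for (int n = 0; n < numSamples; n++) {
            // Stage 1: pick a component by walking the cumulative proportions.
            double u = random.nextDouble();
            int k = 0;
            double cumulative = proportions[0];
            while (u >= cumulative && k < proportions.length - 1) {
                k++;
                cumulative += proportions[k];
            }
            // Stage 2: sample each label independently within that component.
            SortedSet<Integer> labels = new TreeSet<>();
            for (int l = 0; l < classProbs[k].length; l++) {
                if (random.nextDouble() < classProbs[k][l]) {
                    labels.add(l);
                }
            }
            samples.add(labels);
        }
        return samples;
    }

    public static void main(String[] args) {
        double[] proportions = {0.3, 0.7};
        double[][] classProbs = {{0.9, 0.1}, {0.2, 0.8}};
        System.out.println(sample(proportions, classProbs, 5, new Random(0)));
    }
}
```

Passing the Random in explicitly keeps the sketch reproducible under a fixed seed.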

Example 69 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class Enumerator method enumerate.

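/**
 * Enumerates all 2^numClasses label combinations.
 * numClasses must be small enough for 2^numClasses to fit in an int.
 * @param numClasses number of classes
 * @return every possible label combination
 */
public static List<MultiLabel> enumerate(int numClasses) {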
    List<MultiLabel> list = new ArrayList<>();
    int size = (int) Math.pow(2, numClasses);
    for (int i = 0; i < size; i++) {
        String ib = toBinary(i, numClasses);
        MultiLabel multiLabel = toML(ib);
        list.add(multiLabel);
    }
    return list;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) ArrayList(java.util.ArrayList)
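
The helpers toBinary and toML are not shown in this example. One way to get the same kind of enumeration without going through strings is to treat each integer as a bitmask, as in the sketch below. The bit-to-label mapping (bit l set means label l is present) is an assumption and may order the combinations differently from the original helpers; SortedSet<Integer> is a hypothetical stand-in for MultiLabel.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.SortedSet;
import java.util.TreeSet;

final class EnumerateSketch {

    // Lists all 2^numClasses label subsets; bit l of the mask marks label l.
    static List<SortedSet<Integer>> enumerate(int numClasses) {
        if (numClasses < 0 || numClasses > 30) {
            throw new IllegalArgumentException("numClasses must be between 0 and 30");
        }
        int size = 1 << numClasses;
        List<SortedSet<Integer>> list = new ArrayList<>();
        for (int mask = 0; mask < size; mask++) {
            SortedSet<Integer> labels = new TreeSet<>();
            for (int l = 0; l < numClasses; l++) {
                if ((mask & (1 << l)) != 0) {
                    labels.add(l);
                }
            }
            list.add(labels);
        }
        return list;
    }

    public static void main(String[] args) {
        System.out.println(enumerate(2)); // prints [[], [0], [1], [0, 1]]
    }
}
```

The explicit range check makes the exponential blow-up visible: the list doubles with every extra class, so this is only practical for a handful of labels.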

Example 70 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class LossMatrixGenerator method matrix.

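/**
 * Builds the 2^n by 2^n loss matrix over all label combinations: entry (i, j) is the
 * loss of predicting combination j when the true combination is i.
 * @param n number of labels
 * @param lossName one of hamming, overlap, accuracy, precision, recall, f1 (case-insensitive)
 * @return the loss matrix
 */
public static Matrix matrix(int n, String lossName) {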
    int size = (int) Math.pow(2, n);
    double[][] matrixBuilder = new double[size][size];
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            String ib = toBinary(i, n);
            String jb = toBinary(j, n);
            MultiLabel multiLabel1 = toML(ib);
            MultiLabel multiLabel2 = toML(jb);
            MultiLabel[] trueLabels = { multiLabel1 };
            MultiLabel[] predicted = { multiLabel2 };
            MLConfusionMatrix mlConfusionMatrix = new MLConfusionMatrix(n, trueLabels, predicted);
            InstanceAverage instanceAverage = new InstanceAverage(mlConfusionMatrix);
            double loss;
            switch(lossName.toLowerCase()) {
                case "hamming":
                    loss = instanceAverage.getHammingLoss() * n;
                    break;
                case "overlap":
                    loss = 1 - instanceAverage.getOverlap();
                    break;
                case "accuracy":
                    loss = 1 - instanceAverage.getAccuracy();
                    break;
                case "precision":
                    loss = 1 - instanceAverage.getPrecision();
                    break;
                case "recall":
                    loss = 1 - instanceAverage.getRecall();
                    break;
                case "f1":
                    loss = 1 - instanceAverage.getF1();
                    break;
                default:
                    throw new IllegalArgumentException("unknown loss");
            }
            matrixBuilder[i][j] = loss;
        }
    }
    Matrix matrix = new DenseMatrix(matrixBuilder);
    return matrix;
}
Also used : MLConfusionMatrix(edu.neu.ccs.pyramid.eval.MLConfusionMatrix) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DenseMatrix(org.apache.mahout.math.DenseMatrix) Matrix(org.apache.mahout.math.Matrix) InstanceAverage(edu.neu.ccs.pyramid.eval.InstanceAverage)
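
For the "hamming" case, the matrix above multiplies getHammingLoss() by n. If getHammingLoss() is the fraction of labels that disagree (an assumption, since its body is not shown), each entry is simply the number of labels on which the two combinations differ. Under that assumption, and with bit l of an index standing for label l (also an assumption about toBinary and toML), the entries can be sketched with an XOR and a bit count. The class name and the plain double[][] result are hypothetical; the original returns a Mahout DenseMatrix.

```java
final class HammingLossMatrixSketch {

    // Entry (i, j) counts the labels on which combinations i and j disagree.
    static double[][] hammingMatrix(int n) {
        if (n < 0 || n > 12) {
            // 2^n rows times 2^n columns grows quickly; keep the sketch small.
            throw new IllegalArgumentException("n must be between 0 and 12");
        }
        int size = 1 << n;
        double[][] matrix = new double[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                matrix[i][j] = Integer.bitCount(i ^ j);
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        for (double[] row : hammingMatrix(2)) {
            System.out.println(java.util.Arrays.toString(row));
        }
    }
}
```

The other losses in the switch (overlap, accuracy, precision, recall, f1) depend on the project's InstanceAverage definitions and are not reproduced here.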

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3