
Example 46 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class IMLLogisticLoss, method calEmpricalCount.

private double calEmpricalCount(int parameterIndex) {
    int classIndex = logisticRegression.getWeights().getClassIndex(parameterIndex);
    MultiLabel[] labels = dataSet.getMultiLabels();
    int featureIndex = logisticRegression.getWeights().getFeatureIndex(parameterIndex);
    double count = 0;
    //bias
    if (featureIndex == -1) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            if (labels[i].matchClass(classIndex)) {
                count += 1;
            }
        }
    } else {
        Vector featureColumn = dataSet.getColumn(featureIndex);
        for (Vector.Element element : featureColumn.nonZeroes()) {
            int dataPointIndex = element.index();
            double featureValue = element.get();
            MultiLabel label = labels[dataPointIndex];
            if (label.matchClass(classIndex)) {
                count += featureValue;
            }
        }
    }
    return count;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector)
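calEmpricalCount sums a value over the data points whose label set contains the parameter's class. That value is the parameter's feature, or 1 for the bias (featureIndex == -1). The stand-alone sketch below restates that computation with plain Java collections instead of pyramid's dataset classes and Mahout's Vector. The class and method names, the List<Set<Integer>> label representation and the dense double[][] feature matrix are all assumptions made for illustration.

```java
import java.util.List;
import java.util.Set;

// Hypothetical stand-alone sketch; not part of the pyramid API.
final class EmpiricalCountSketch {

    /**
     * Empirical count for one (class, feature) parameter.
     *
     * @param labels       labels.get(i) is the set of class indices of data point i
     * @param features     features[i][j] is the value of feature j for data point i
     * @param classIndex   class the parameter belongs to
     * @param featureIndex feature the parameter belongs to, or -1 for the bias
     */
    static double empiricalCount(List<Set<Integer>> labels, double[][] features,
                                 int classIndex, int featureIndex) {
        double count = 0;
        for (int i = 0; i < labels.size(); i++) {
            if (!labels.get(i).contains(classIndex)) {
                continue; // only data points labelled with classIndex contribute
            }
            // the bias behaves like a constant feature equal to 1
            count += (featureIndex == -1) ? 1.0 : features[i][featureIndex];
        }
        return count;
    }
}
```

Iterating densely over every data point gives the same result as the original loop over featureColumn.nonZeroes(), because zero feature values add nothing to the count. The sparse loop is simply cheaper on sparse data.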

Example 47 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class IMLLogisticRegression, method predictWithConstraints.

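/**
 * Predicts a label set, considering only the assignments in this.assignments.
 * @param vector feature vector of the data point
 * @return the highest-scoring assignment (null if this.assignments is empty)
 */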
private MultiLabel predictWithConstraints(Vector vector) {
    double maxScore = Double.NEGATIVE_INFINITY;
    MultiLabel prediction = null;
    double[] classScores = predictClassScores(vector);
    for (MultiLabel assignment : this.assignments) {
        double score = this.calAssignmentScore(assignment, classScores);
        if (score > maxScore) {
            maxScore = score;
            prediction = assignment;
        }
    }
    return prediction;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
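predictWithConstraints scores each assignment in this.assignments and returns the best one, so only label combinations in that list can ever be predicted. The sketch below reproduces the search loop with plain Java types. The excerpt does not show how calAssignmentScore combines the per-class scores. The sketch therefore assumes, hypothetically, that an assignment's score is the sum of the scores of the classes it contains. All names are illustrative.

```java
import java.util.List;
import java.util.Set;

// Hypothetical stand-alone sketch of constrained argmax prediction.
final class ConstrainedPredictionSketch {

    // ASSUMPTION: an assignment's score is the sum of the scores of its classes.
    static double assignmentScore(Set<Integer> assignment, double[] classScores) {
        double score = 0;
        for (int k : assignment) {
            score += classScores[k];
        }
        return score;
    }

    // Returns the highest-scoring candidate, or null if candidates is empty.
    static Set<Integer> predict(List<Set<Integer>> candidates, double[] classScores) {
        double maxScore = Double.NEGATIVE_INFINITY;
        Set<Integer> prediction = null;
        for (Set<Integer> candidate : candidates) {
            double score = assignmentScore(candidate, classScores);
            if (score > maxScore) { // strict comparison: ties keep the earlier candidate
                maxScore = score;
                prediction = candidate;
            }
        }
        return prediction;
    }
}
```

As in the original method, the strict comparison keeps the earliest candidate on ties, and an empty candidate list yields null.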

Example 48 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class GeneralF1Predictor, method exhaustiveSearch.

public static MultiLabel exhaustiveSearch(int numClasses, Matrix lossMatrix, List<Double> probabilities) {
    double bestScore = Double.POSITIVE_INFINITY;
    // copy the probabilities into a vector so they can be dotted with loss-matrix columns
    Vector vector = new DenseVector(probabilities.size());
    for (int i = 0; i < vector.size(); i++) {
        vector.set(i, probabilities.get(i));
    }
    // candidate predictions: column j of lossMatrix corresponds to multiLabels.get(j)
    List<MultiLabel> multiLabels = Enumerator.enumerate(numClasses);
    MultiLabel multiLabel = null;
    for (int j = 0; j < lossMatrix.numCols(); j++) {
        // expected loss of predicting multiLabels.get(j)
        Vector column = lossMatrix.viewColumn(j);
        double score = column.dot(vector);
        System.out.println("column " + j + ", expected loss = " + score);
        // keep the candidate with the lowest expected loss
        if (score < bestScore) {
            bestScore = score;
            multiLabel = multiLabels.get(j);
        }
    }
    return multiLabel;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector)
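exhaustiveSearch treats each column of lossMatrix as one candidate prediction, multiLabels.get(j). It uses the dot product of that column with the probability vector as the candidate's expected loss, so row i is paired with probabilities.get(i). The sketch below performs the same minimisation over a plain double[][] indexed as loss[row][column]. The class and method names are hypothetical.

```java
// Hypothetical stand-alone sketch; mirrors the search loop of exhaustiveSearch.
final class ExpectedLossSketch {

    /**
     * Returns the column j that minimises sum_i loss[i][j] * probs[i].
     * Row i of loss is paired with probs[i]; column j is candidate prediction j.
     * Assumes loss has probs.length rows and at least one column.
     */
    static int argminExpectedLoss(double[][] loss, double[] probs) {
        int numCols = loss[0].length;
        int best = -1;
        double bestScore = Double.POSITIVE_INFINITY;
        for (int j = 0; j < numCols; j++) {
            // expected loss of predicting candidate j: dot(column j, probs)
            double score = 0;
            for (int i = 0; i < probs.length; i++) {
                score += loss[i][j] * probs[i];
            }
            if (score < bestScore) { // strict: ties keep the earlier column
                bestScore = score;
                best = j;
            }
        }
        return best;
    }
}
```

With a 0/1 loss matrix (zeros on the diagonal, ones elsewhere) the argmin is the most probable candidate. An asymmetric loss can favour a less probable candidate, which is why expected loss is minimised instead of simply taking the mode. Unlike the original, the sketch does not print each column's expected loss.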

Example 49 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class InstanceF1Predictor, method predict.

@Override
public MultiLabel predict(Vector vector) {
    double[] probs = imlGradientBoosting.predictAllAssignmentProbsWithConstraint(vector);
    List<Double> probList = Arrays.stream(probs).boxed().collect(Collectors.toList());
    GeneralF1Predictor generalF1Predictor = new GeneralF1Predictor();
    return generalF1Predictor.predict(imlGradientBoosting.getNumClasses(), imlGradientBoosting.getAssignments(), probList);
}
Also used : Arrays(java.util.Arrays) List(java.util.List) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) Vector(org.apache.mahout.math.Vector) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) Collectors(java.util.stream.Collectors) GeneralF1Predictor(edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)
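InstanceF1Predictor obtains a probability for each allowed assignment and delegates the final choice to GeneralF1Predictor. The excerpt does not show how GeneralF1Predictor.predict works internally. The sketch below is therefore a swapped-in, brute-force alternative, not that class's algorithm. It evaluates each allowed assignment as a candidate and keeps the one with the highest expected instance-level F1, treating assignment i as the true label set with probability probs[i]. The names are hypothetical, and scoring F1 = 1 when both sets are empty is a convention chosen here.

```java
import java.util.List;
import java.util.Set;

// Hypothetical brute-force expected-F1 predictor; not GeneralF1Predictor's algorithm.
final class ExpectedF1Sketch {

    // Instance-level F1 between a predicted and a true label set; two empty sets score 1.
    static double f1(Set<Integer> predicted, Set<Integer> truth) {
        if (predicted.isEmpty() && truth.isEmpty()) {
            return 1.0;
        }
        int overlap = 0;
        for (int k : predicted) {
            if (truth.contains(k)) {
                overlap++;
            }
        }
        return 2.0 * overlap / (predicted.size() + truth.size());
    }

    // Picks the allowed assignment with the highest expected F1, where
    // assignments.get(i) is the true label set with probability probs[i].
    static Set<Integer> predict(List<Set<Integer>> assignments, double[] probs) {
        Set<Integer> best = null;
        double bestExpected = Double.NEGATIVE_INFINITY;
        for (Set<Integer> candidate : assignments) {
            double expected = 0;
            for (int i = 0; i < assignments.size(); i++) {
                expected += probs[i] * f1(candidate, assignments.get(i));
            }
            if (expected > bestExpected) {
                bestExpected = expected;
                best = candidate;
            }
        }
        return best;
    }
}
```

For example, consider assignments {0}, {1} and {0, 1} with probabilities 0.4, 0.35 and 0.25. The most probable assignment is {0}, but {0, 1} has the higher expected F1 (0.75 versus about 0.567). The double loop costs O(n^2) F1 evaluations for n assignments. Restricting candidates to the allowed assignments may also miss a better label set outside that list.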

Example 50 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

From the class MLLogisticLoss, method calEmpricalCount.

private double calEmpricalCount(int parameterIndex) {
    int classIndex = mlLogisticRegression.getWeights().getClassIndex(parameterIndex);
    MultiLabel[] labels = dataSet.getMultiLabels();
    int featureIndex = mlLogisticRegression.getWeights().getFeatureIndex(parameterIndex);
    double count = 0;
    //bias
    if (featureIndex == -1) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            if (labels[i].matchClass(classIndex)) {
                count += 1;
            }
        }
    } else {
        Vector featureColumn = dataSet.getColumn(featureIndex);
        for (Vector.Element element : featureColumn.nonZeroes()) {
            int dataPointIndex = element.index();
            double featureValue = element.get();
            MultiLabel label = labels[dataPointIndex];
            if (label.matchClass(classIndex)) {
                count += featureValue;
            }
        }
    }
    return count;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector)
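This method computes the same empirical count as the IMLLogisticLoss version, but reads the weight layout from mlLogisticRegression. In log-linear models such counts usually enter the gradient of the negative log-likelihood as "expected count minus empirical count". The sketch below illustrates that relationship under a simplifying assumption: each class is treated as an independent sigmoid model, so the expected label for class k on data point i is sigmoid(classScores[i][k]). The excerpt does not show whether MLLogisticLoss uses exactly this form, and all names are hypothetical.

```java
import java.util.List;
import java.util.Set;

// Hypothetical sketch: gradient = expected count - empirical count,
// ASSUMING independent per-class logistic (sigmoid) models.
final class LogisticGradientSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Partial derivative of the negative log-likelihood with respect to the weight
     * linking classIndex and featureIndex (-1 = bias). classScores[i][k] is the
     * model's linear score for class k on data point i.
     */
    static double gradient(List<Set<Integer>> labels, double[][] features,
                           double[][] classScores, int classIndex, int featureIndex) {
        double expectedCount = 0;
        double empiricalCount = 0;
        for (int i = 0; i < labels.size(); i++) {
            // the bias behaves like a constant feature equal to 1
            double x = (featureIndex == -1) ? 1.0 : features[i][featureIndex];
            expectedCount += sigmoid(classScores[i][classIndex]) * x;
            if (labels.get(i).contains(classIndex)) {
                empiricalCount += x; // same quantity calEmpricalCount returns
            }
        }
        return expectedCount - empiricalCount;
    }
}
```

A negative entry means the model currently under-predicts the class for that feature, so a gradient-descent step increases the corresponding weight.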

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3