
Example 86 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class CRFF1Loss, method calEmpiricalCountForFeature:

private double calEmpiricalCountForFeature(int parameterIndex) {
    double empiricalCount = 0.0;
    int classIndex = parameterToClass[parameterIndex];
    int featureIndex = parameterToFeature[parameterIndex];
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    if (featureIndex == -1) {
        // bias parameter: count the data points whose label set contains classIndex
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            if (multiLabels[i].matchClass(classIndex)) {
                empiricalCount += 1;
            }
        }
    } else {
        // feature parameter: sum the feature's values over data points labeled with classIndex;
        // zero entries contribute nothing, so only the column's non-zeroes are visited
        Vector column = dataSet.getColumn(featureIndex);
        for (Vector.Element element : column.nonZeroes()) {
            int dataIndex = element.index();
            double featureValue = element.get();
            if (multiLabels[dataIndex].matchClass(classIndex)) {
                empiricalCount += featureValue;
            }
        }
    }
    return empiricalCount;
}

Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)
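The method above computes the empirical (observed) count of one CRF parameter: for the bias parameter (featureIndex == -1) it counts the data points labeled with the class, and otherwise it sums that feature's values over those data points. A minimal sketch of the same computation, assuming a dense double[][] feature matrix and Set<Integer> label sets in place of Mahout's Vector and pyramid's MultiLabel; EmpiricalCountSketch and empiricalCount are hypothetical names:

```java
import java.util.List;
import java.util.Set;

// Hypothetical stand-in: features as a dense double[][] matrix, labels as Set<Integer>.
class EmpiricalCountSketch {

    /**
     * Sums the value of feature featureIndex over all data points whose label
     * set contains classIndex. A featureIndex of -1 denotes the bias term,
     * which contributes 1 per matching data point.
     */
    static double empiricalCount(double[][] features, List<Set<Integer>> labels,
                                 int classIndex, int featureIndex) {
        double count = 0.0;
        for (int i = 0; i < features.length; i++) {
            if (!labels.get(i).contains(classIndex)) {
                continue; // only data points labeled with classIndex contribute
            }
            count += (featureIndex == -1) ? 1.0 : features[i][featureIndex];
        }
        return count;
    }

    public static void main(String[] args) {
        double[][] x = { {0.5, 0.0}, {2.0, 1.0}, {0.0, 3.0} };
        List<Set<Integer>> y = List.of(Set.of(0), Set.of(0, 1), Set.of(1));
        System.out.println(empiricalCount(x, y, 0, -1)); // bias count for class 0
        System.out.println(empiricalCount(x, y, 0, 0));  // feature 0, class 0
        System.out.println(empiricalCount(x, y, 1, 1));  // feature 1, class 1
    }
}
```

Running main prints 2.0, 2.5 and 4.0. Iterating over a dense row gives the same sum as visiting only column.nonZeroes(), since zero entries add nothing; the sparse loop in the original simply skips them.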

Example 87 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class CRFInspector, method simplePredictionAnalysis:

public static String simplePredictionAnalysis(CMLCRF crf, PluginPredictor<CMLCRF> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, double classProbThreshold) {
    StringBuilder sb = new StringBuilder();
    MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    String id = dataSet.getIdTranslator().toExtId(dataPointIndex);
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    double[] combProbs = crf.predictCombinationProbs(dataSet.getRow(dataPointIndex));
    double[] classProbs = crf.calClassProbs(combProbs);
    MultiLabel predicted = pluginPredictor.predict(dataSet.getRow(dataPointIndex));
    List<Integer> classes = new ArrayList<>();
    for (int k = 0; k < crf.getNumClasses(); k++) {
        // report a class if it passes the threshold, is a true label, or is predicted
        if (classProbs[k] >= classProbThreshold || trueLabels.matchClass(k) || predicted.matchClass(k)) {
            classes.add(k);
        }
    }
    Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(pair -> pair.getSecond());
    // sort the selected classes by descending probability
    List<Pair<Integer, Double>> list = classes.stream()
            .map(l -> new Pair<Integer, Double>(l, classProbs[l]))
            .sorted(comparator.reversed())
            .collect(Collectors.toList());
    for (Pair<Integer, Double> pair : list) {
        int label = pair.getFirst();
        double prob = pair.getSecond();
        int match = trueLabels.matchClass(label) ? 1 : 0;
        sb.append(id).append("\t").append(labelTranslator.toExtLabel(label)).append("\t")
                .append("single").append("\t").append(prob).append("\t").append(match).append("\n");
    }
    // joint probability of the predicted label set, looked up among the support combinations
    double probability = 0;
    List<MultiLabel> support = crf.getSupportCombinations();
    for (int i = 0; i < support.size(); i++) {
        MultiLabel candidate = support.get(i);
        if (candidate.equals(predicted)) {
            probability = combProbs[i];
            break;
        }
    }
    List<Integer> predictedList = predicted.getMatchedLabelsOrdered();
    sb.append(id).append("\t");
    for (int i = 0; i < predictedList.size(); i++) {
        sb.append(labelTranslator.toExtLabel(predictedList.get(i)));
        if (i != predictedList.size() - 1) {
            sb.append(",");
        }
    }
    sb.append("\t");
    int setMatch = predicted.equals(trueLabels) ? 1 : 0;
    sb.append("set").append("\t").append(probability).append("\t").append(setMatch).append("\n");
    return sb.toString();
}

Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), ArrayList (java.util.ArrayList), LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator), Pair (edu.neu.ccs.pyramid.util.Pair)
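simplePredictionAnalysis reports every class that passes the probability threshold, belongs to the true label set, or is predicted, sorted by descending probability, and writes one tab-separated "single" row per class. A sketch of that selection and formatting step, assuming plain double[] probabilities, Set<Integer> label sets and a String[] of label names in place of CMLCRF, MultiLabel and LabelTranslator; PredictionReportSketch and its methods are hypothetical:

```java
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class PredictionReportSketch {

    /** Classes worth reporting (above threshold, true, or predicted), sorted by descending probability. */
    static List<Integer> reportedClasses(double[] classProbs, Set<Integer> trueLabels,
                                         Set<Integer> predicted, double threshold) {
        return IntStream.range(0, classProbs.length)
                .filter(k -> classProbs[k] >= threshold || trueLabels.contains(k) || predicted.contains(k))
                .boxed()
                .sorted(Comparator.comparingDouble((Integer k) -> classProbs[k]).reversed())
                .collect(Collectors.toList());
    }

    /** One tab-separated row per reported class: id, label name, "single", probability, match flag. */
    static String singleRows(String id, String[] labelNames, double[] classProbs,
                             Set<Integer> trueLabels, List<Integer> classes) {
        StringBuilder sb = new StringBuilder();
        for (int k : classes) {
            int match = trueLabels.contains(k) ? 1 : 0;
            sb.append(id).append('\t').append(labelNames[k]).append('\t').append("single")
              .append('\t').append(classProbs[k]).append('\t').append(match).append('\n');
        }
        return sb.toString();
    }
}
```

For probabilities {0.1, 0.7, 0.3, 0.05}, true labels {3}, prediction {1} and threshold 0.2, reportedClasses returns [1, 2, 3]: class 0 is dropped, and class 3 is kept only because it is a true label.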

Example 88 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class NoiseOptimizer, method updateTransformProb:

private void updateTransformProb(int dataPoint, int comIndex) {
    MultiLabel labels = dataSet.getMultiLabels()[dataPoint];
    MultiLabel candidate = combinations.get(comIndex);
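    // nonzero only if the data point's labels are a subset of the candidate combination
    if (labels.isSubsetOf(candidate)) {
        // product over candidate labels: alphas[l] if l is present, 1 - alphas[l] if it is absent
        double prod = 1;
        for (int l : candidate.getMatchedLabels()) {
            if (labels.matchClass(l)) {
                prod *= alphas[l];
            } else {
                prod *= (1 - alphas[l]);
            }
        }
        transformProbs[dataPoint][comIndex] = prod;
    } else {
        // a label outside the candidate makes this transform impossible
        transformProbs[dataPoint][comIndex] = 0;
    }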
}

Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)
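updateTransformProb gives zero unless the data point's labels are a subset of the candidate combination; otherwise it multiplies alphas[l] for each candidate label present in the data point and 1 - alphas[l] for each one that is absent. Read as a noise model (an interpretation, not stated in the source), each label of the candidate survives independently with probability alphas[l]. A sketch with Set<Integer> label sets; TransformProbSketch and transformProb are hypothetical names:

```java
import java.util.Set;

class TransformProbSketch {

    /**
     * Probability that the observed label set arises from the candidate combination,
     * assuming each candidate label l is kept independently with probability alphas[l]
     * and no label outside the candidate can appear.
     */
    static double transformProb(Set<Integer> observed, Set<Integer> candidate, double[] alphas) {
        if (!candidate.containsAll(observed)) {
            return 0.0; // observed labels must be a subset of the candidate
        }
        double prod = 1.0;
        for (int l : candidate) {
            // kept with probability alphas[l], dropped with probability 1 - alphas[l]
            prod *= observed.contains(l) ? alphas[l] : (1 - alphas[l]);
        }
        return prod;
    }
}
```

Under that reading, the probabilities over all subsets of a candidate sum to 1; for alphas 0.9 and 0.8 on labels 0 and 1 and candidate {0, 1}, the subsets {}, {0}, {1} and {0, 1} get 0.02, 0.18, 0.08 and 0.72.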

Example 89 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class InstanceF1Predictor, method predict:

@Override
public MultiLabel predict(Vector vector) {
    List<MultiLabel> supports = cmlcrf.getSupportCombinations();
    double[] probs = cmlcrf.predictCombinationProbs(vector);
    GeneralF1Predictor generalF1Predictor = new GeneralF1Predictor();
    return generalF1Predictor.predict(numClasses, supports, probs);
}

Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)
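InstanceF1Predictor takes the CRF's support combinations and their predicted probabilities and hands them to GeneralF1Predictor, whose algorithm is not shown here. As a simpler stand-in (not GeneralF1Predictor's method), one can restrict the candidate predictions to the support set and pick the combination with the highest expected instance F1, computed by brute force as the probability-weighted F1 against every support combination. ExpectedF1Sketch and its methods are hypothetical:

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class ExpectedF1Sketch {

    /** Instance-level F1 between two label sets; defined as 1 when both are empty. */
    static double f1(Set<Integer> a, Set<Integer> b) {
        if (a.isEmpty() && b.isEmpty()) {
            return 1.0;
        }
        Set<Integer> common = new HashSet<>(a);
        common.retainAll(b);
        return 2.0 * common.size() / (a.size() + b.size());
    }

    /** Support combination with the highest expected F1 under probs (brute force, O(n^2) F1 evaluations). */
    static Set<Integer> predict(List<Set<Integer>> supports, double[] probs) {
        Set<Integer> best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Set<Integer> candidate : supports) {
            double expected = 0.0;
            for (int j = 0; j < supports.size(); j++) {
                expected += probs[j] * f1(candidate, supports.get(j));
            }
            if (expected > bestScore) {
                bestScore = expected;
                best = candidate;
            }
        }
        return best;
    }
}
```

With supports {0}, {1}, {0, 1} at probabilities 0.4, 0.35, 0.25, the most probable combination is {0}, yet {0, 1} has the highest expected F1 (0.75 versus about 0.567 for {0}), which is why an F1-oriented predictor can differ from the most probable combination.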

Example 90 with MultiLabel

Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

In class CMLCRFElasticNet, method expandData:

private DataSet expandData(int l) {
    SequentialSparseDataSet newData = new SequentialSparseDataSet(numData, numParameters, false);
    MultiLabel label = supportedCombinations.get(l);
    List<Integer> labelPairForL = combinationToLabelPair.get(l);
    // TODO: parallelism
    for (int i = 0; i < numData; i++) {
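        // label-specific features: for each label y in this combination, a bias followed by the input row, starting at offset (numFeature + 1) * y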
        for (int y : label.getMatchedLabels()) {
            // set bias as 1
            newData.setFeatureValue(i, (numFeature + 1) * y, 1.0);
            for (Vector.Element element : dataSet.getRow(i).nonZeroes()) {
                int index = element.index();
                double value = element.get();
                newData.setFeatureValue(i, (numFeature + 1) * y + index + 1, value);
            }
        }
        // label-pair indicator features for this combination
        for (int y : labelPairForL) {
            newData.setFeatureValue(i, (numWeightsForFeatures + y), 1.0);
        }
    }
    return newData;
}

Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), SequentialSparseDataSet (edu.neu.ccs.pyramid.dataset.SequentialSparseDataSet), DenseVector (org.apache.mahout.math.DenseVector), SequentialAccessSparseVector (org.apache.mahout.math.SequentialAccessSparseVector), Vector (org.apache.mahout.math.Vector)
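expandData lays out the joint feature space in blocks: label y owns numFeature + 1 columns starting at (numFeature + 1) * y, with the bias first and the input features after it, and label-pair indicators start at numWeightsForFeatures. A sketch that expands one dense row into a sparse column-to-value map, assuming numWeightsForFeatures equals (numFeature + 1) * numClasses (the source does not show its definition) and treating label-pair indices as opaque integers; ExpandDataSketch and expandRow are hypothetical names:

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class ExpandDataSketch {

    /**
     * Expands one input row into the joint feature space of a label combination.
     * Each label y owns a block of numFeature + 1 columns (bias first, then the
     * input features); label-pair indicator columns follow all label blocks.
     */
    static Map<Integer, Double> expandRow(double[] row, List<Integer> labels,
                                          List<Integer> labelPairs, int numClasses) {
        int numFeature = row.length;
        int numWeightsForFeatures = (numFeature + 1) * numClasses; // assumed layout
        Map<Integer, Double> expanded = new TreeMap<>();
        for (int y : labels) {
            int blockStart = (numFeature + 1) * y;
            expanded.put(blockStart, 1.0); // bias for label y
            for (int j = 0; j < numFeature; j++) {
                if (row[j] != 0.0) { // keep the result sparse, like iterating nonZeroes()
                    expanded.put(blockStart + j + 1, row[j]);
                }
            }
        }
        for (int p : labelPairs) {
            expanded.put(numWeightsForFeatures + p, 1.0); // label-pair indicator
        }
        return expanded;
    }
}
```

For the row {0.5, 0.0, 2.0} with two classes, the combination {1} and pair index 0, the result is {4=1.0, 5=0.5, 7=2.0, 8=1.0}: label 1's block spans columns 4 to 7, and the pair column comes right after both blocks.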

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3