
Example 16 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CRFLoss method calEmpiricalCountForFeature.

private double calEmpiricalCountForFeature(int parameterIndex) {
    double empiricalCount = 0.0;
    int classIndex = parameterToClass[parameterIndex];
    int featureIndex = parameterToFeature[parameterIndex];
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    if (featureIndex == -1) {
        // No feature column (presumably the bias term): count the data points labeled with this class.
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            if (multiLabels[i].matchClass(classIndex)) {
                empiricalCount += 1;
            }
        }
    } else {
        // Sum the feature's non-zero values over the data points labeled with this class.
        Vector column = dataSet.getColumn(featureIndex);
        for (Vector.Element element : column.nonZeroes()) {
            int dataIndex = element.index();
            double featureValue = element.get();
            if (multiLabels[dataIndex].matchClass(classIndex)) {
                empiricalCount += featureValue;
            }
        }
    }
    return empiricalCount;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector)
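
The counting logic above can be tried out without the pyramid or Mahout classes. The following is a minimal sketch under stated assumptions: `Set<Integer>` stands in for `MultiLabel`, a dense `double[][]` stands in for the data set, and the class name `EmpiricalCountSketch` and the `BIAS` constant are hypothetical names introduced only for illustration. It is not the project's implementation.

```java
import java.util.List;
import java.util.Set;

// Stand-in sketch: Set<Integer> replaces MultiLabel, double[][] replaces the data set.
final class EmpiricalCountSketch {
    // Hypothetical marker mirroring the featureIndex == -1 check in the example above.
    static final int BIAS = -1;

    /** Sums a feature's values over the data points whose label set contains classIndex. */
    static double empiricalCount(double[][] features, List<Set<Integer>> labels,
                                 int classIndex, int featureIndex) {
        double count = 0.0;
        for (int i = 0; i < features.length; i++) {
            if (labels.get(i).contains(classIndex)) {
                // With no feature column, each matching data point contributes 1.
                count += (featureIndex == BIAS) ? 1.0 : features[i][featureIndex];
            }
        }
        return count;
    }

    public static void main(String[] args) {
        double[][] features = { {0.5, 0.0}, {2.0, 1.0}, {0.0, 3.0} };
        List<Set<Integer>> labels = List.of(Set.of(0), Set.of(0, 1), Set.of(1));
        System.out.println(empiricalCount(features, labels, 0, BIAS)); // prints 2.0
        System.out.println(empiricalCount(features, labels, 0, 0));    // prints 2.5
        System.out.println(empiricalCount(features, labels, 1, 1));    // prints 4.0
    }
}
```

The original iterates only over the non-zero entries of a feature column, which is cheaper on sparse data; the dense loop here trades that for brevity.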

Example 17 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CRFLoss method calEmpiricalCountForLabelPair.

private double calEmpiricalCountForLabelPair(int parameterIndex) {
    double empiricalCount = 0.0;
    int start = parameterIndex - numWeightsForFeatures;
    int l1 = parameterToL1[start];
    int l2 = parameterToL2[start];
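    // start % 4 selects one of the four (l1, l2) on/off configurations handled below.
    int featureCase = start % 4;
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        MultiLabel label = multiLabels[i];
        switch (featureCase) {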
            // both l1, l2 equal 0;
            case 0:
                if (!label.matchClass(l1) && !label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 1; l2 = 0;
            case 1:
                if (label.matchClass(l1) && !label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 0; l2 = 1;
            case 2:
                if (!label.matchClass(l1) && label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 1; l2 = 1;
            case 3:
                if (label.matchClass(l1) && label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            default:
                throw new RuntimeException("feature case :" + featureCase + " failed.");
        }
    }
    return empiricalCount;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
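
The four cases in the switch above correspond to a two-bit encoding of whether l1 and l2 are present. As a minimal sketch, assuming `Set<Integer>` as a stand-in for `MultiLabel` and using the hypothetical names `LabelPairCountSketch` and `pairCounts`, all four counts for one label pair can be gathered in a single pass:

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

// Stand-in sketch: Set<Integer> replaces MultiLabel.
final class LabelPairCountSketch {
    /**
     * Returns counts[c] for c in 0..3, where
     * c = (l1 present ? 1 : 0) + (l2 present ? 2 : 0),
     * matching the case numbering of the switch in the example above.
     */
    static double[] pairCounts(List<Set<Integer>> labels, int l1, int l2) {
        double[] counts = new double[4];
        for (Set<Integer> label : labels) {
            int featureCase = (label.contains(l1) ? 1 : 0) + (label.contains(l2) ? 2 : 0);
            counts[featureCase] += 1.0;
        }
        return counts;
    }

    public static void main(String[] args) {
        List<Set<Integer>> labels = List.of(
                Set.of(0), Set.of(0, 1), Set.of(1), Set.<Integer>of(), Set.of(0, 1));
        // Neither: 1, l1 only: 1, l2 only: 1, both: 2.
        System.out.println(Arrays.toString(pairCounts(labels, 0, 1))); // prints [1.0, 1.0, 1.0, 2.0]
    }
}
```

The original computes one case per parameter index instead, because each parameter owns exactly one (l1, l2, case) combination.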

Example 18 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class InstanceF1Predictor method showPredictBySupport.

public GeneralF1Predictor.Analysis showPredictBySupport(Vector vector, MultiLabel truth) {
    List<MultiLabel> support = cmlcrf.getSupportCombinations();
    double[] probs = cmlcrf.predictCombinationProbs(vector);
    int numClasses = cmlcrf.getNumClasses();
    // Pick a prediction from the support combinations and their probabilities.
    GeneralF1Predictor generalF1Predictor = new GeneralF1Predictor();
    MultiLabel prediction = generalF1Predictor.predict(numClasses, support, probs);
    return GeneralF1Predictor.showSupportPrediction(support, probs, truth, prediction, numClasses);
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) GeneralF1Predictor(edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor)

Example 19 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CMLCRFElasticNet method calEmpiricalCountForFeature.

private double calEmpiricalCountForFeature(int parameterIndex) {
    double empiricalCount = 0.0;
    int classIndex = parameterToClass[parameterIndex];
    int featureIndex = parameterToFeature[parameterIndex];
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    if (featureIndex == -1) {
        // No feature column (presumably the bias term): count the data points labeled with this class.
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            if (multiLabels[i].matchClass(classIndex)) {
                empiricalCount += 1;
            }
        }
    } else {
        // Sum the feature's non-zero values over the data points labeled with this class.
        Vector column = dataSet.getColumn(featureIndex);
        for (Vector.Element element : column.nonZeroes()) {
            int dataIndex = element.index();
            double featureValue = element.get();
            if (multiLabels[dataIndex].matchClass(classIndex)) {
                empiricalCount += featureValue;
            }
        }
    }
    return empiricalCount;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) DenseVector(org.apache.mahout.math.DenseVector) SequentialAccessSparseVector(org.apache.mahout.math.SequentialAccessSparseVector) Vector(org.apache.mahout.math.Vector)

Example 20 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class MultiLabelSynthesizer method crfArgmax.

public static MultiLabelClfDataSet crfArgmax() {
    int numData = 1000;
    int numClass = 4;
    int numFeature = 10;
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(numFeature).numClasses(numClass).numDataPoints(numData).build();
    List<MultiLabel> support = Enumerator.enumerate(numClass);
    CMLCRF cmlcrf = new CMLCRF(numClass, numFeature, support);
    // hand-set the weights of features 0 and 1 for each of the four classes
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(0).set(0, 0);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(0).set(1, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(1).set(0, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(1).set(1, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(2).set(0, 10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(2).set(1, 0);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(3).set(0, -10);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(3).set(1, -10);
    // generate features
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    SubsetAccPredictor predictor = new SubsetAccPredictor(cmlcrf);
    // assign labels
    for (int i = 0; i < numData; i++) {
        MultiLabel label = predictor.predict(dataSet.getRow(i));
        dataSet.setLabels(i, label);
    }
    return dataSet;
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) SubsetAccPredictor(edu.neu.ccs.pyramid.multilabel_classification.crf.SubsetAccPredictor) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
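
The example above builds its support list with `Enumerator.enumerate(numClass)`. Assuming that this list contains every subset of the classes (an assumption, not confirmed by the snippet), the idea can be sketched with bitmasks. `Set<Integer>` stands in for `MultiLabel`, and `SubsetEnumerationSketch` is a hypothetical name; this is not the project's `Enumerator`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

// Stand-in sketch: Set<Integer> replaces MultiLabel.
final class SubsetEnumerationSketch {
    /** Lists all 2^numClasses subsets of {0, ..., numClasses - 1}, one per bitmask. */
    static List<Set<Integer>> enumerate(int numClasses) {
        List<Set<Integer>> subsets = new ArrayList<>();
        for (int mask = 0; mask < (1 << numClasses); mask++) {
            Set<Integer> subset = new TreeSet<>();
            for (int k = 0; k < numClasses; k++) {
                // Bit k of the mask decides whether class k is in this subset.
                if ((mask & (1 << k)) != 0) {
                    subset.add(k);
                }
            }
            subsets.add(subset);
        }
        return subsets;
    }

    public static void main(String[] args) {
        List<Set<Integer>> support = enumerate(4);
        System.out.println(support.size()); // prints 16
        System.out.println(support.get(5)); // prints [0, 2]
    }
}
```

With 4 classes this yields 16 combinations, which is small enough to score exhaustively; the count doubles with every added class.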

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel) 101
Vector (org.apache.mahout.math.Vector) 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) 21
File (java.io.File) 14
DenseVector (org.apache.mahout.math.DenseVector) 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) 12
Pair (edu.neu.ccs.pyramid.util.Pair) 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS) 7
ArrayList (java.util.ArrayList) 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures) 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor) 5
Collectors (java.util.stream.Collectors) 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper) 4
java.util (java.util) 4
StopWatch (org.apache.commons.lang3.time.StopWatch) 4
Config (edu.neu.ccs.pyramid.configuration.Config) 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil) 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat) 3