
Example 91 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CMLCRFElasticNet method calEmpiricalCountForLabelPair.

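/**
 * Counts the data points whose labels show the (l1, l2) on/off pattern
 * encoded by the given label-pair parameter.
 */
private double calEmpiricalCountForLabelPair(int parameterIndex) {
    double empiricalCount = 0.0;
    // Offset of this parameter within the label-pair part of the weights.
    int start = parameterIndex - numWeightsForFeatures;
    int l1 = parameterToL1[start];
    int l2 = parameterToL2[start];
    // Selects one of the four on/off patterns handled in the switch below.
    int featureCase = start % 4;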
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        MultiLabel label = dataSet.getMultiLabels()[i];
        switch(featureCase) {
            // both l1, l2 equal 0;
            case 0:
                if (!label.matchClass(l1) && !label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 1; l2 = 0;
            case 1:
                if (label.matchClass(l1) && !label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 0; l2 = 1;
            case 2:
                if (!label.matchClass(l1) && label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            // l1 = 1; l2 = 1;
            case 3:
                if (label.matchClass(l1) && label.matchClass(l2))
                    empiricalCount += 1.0;
                break;
            default:
                throw new RuntimeException("feature case :" + featureCase + " failed.");
        }
    }
    return empiricalCount;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
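
The four-way case split above can be tried outside the project with plain Java collections. The following is a minimal sketch, not the pyramid implementation: the class name `LabelPairCountSketch`, the method `empiricalCount`, and the use of `Set<Integer>` in place of `MultiLabel` are all assumptions made for illustration.

```java
import java.util.List;
import java.util.Set;

public class LabelPairCountSketch {

    /**
     * Counts the label sets that show the requested on/off pattern for (l1, l2).
     * featureCase: 0 = neither label, 1 = l1 only, 2 = l2 only, 3 = both.
     * Set<Integer> is a hypothetical stand-in for MultiLabel.
     */
    static double empiricalCount(List<Set<Integer>> labels, int l1, int l2, int featureCase) {
        double count = 0.0;
        for (Set<Integer> label : labels) {
            // Encode the observed pattern as 0..3: +1 if l1 is present, +2 if l2 is present.
            int observedCase = (label.contains(l1) ? 1 : 0) + (label.contains(l2) ? 2 : 0);
            if (observedCase == featureCase) {
                count += 1.0;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        List<Set<Integer>> labels =
                List.of(Set.of(0, 1), Set.of(0), Set.of(1), Set.of(), Set.of(0, 2));
        for (int featureCase = 0; featureCase < 4; featureCase++) {
            System.out.println("case " + featureCase + ": "
                    + empiricalCount(labels, 0, 1, featureCase));
        }
    }
}
```

Because every data point falls into exactly one of the four patterns, the four counts for a given pair always sum to the number of data points.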

Example 92 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CMLCRF method computeLabelPartScore.

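/**
 * The part of the score that depends only on the labels.
 * For each label pair, exactly one of the four feature functions returns 1.
 * @param labelComIndex index of the label combination in supportCombinations
 * @return the sum of the weights of the active label-pair features
 */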
double computeLabelPartScore(int labelComIndex) {
    MultiLabel label = supportCombinations.get(labelComIndex);
    double score = 0;
    int pos = this.weights.getNumWeightsForFeatures();
    boolean[] matches = new boolean[numClasses];
    for (int match : label.getMatchedLabels()) {
        matches[match] = true;
    }
    for (int l1 = 0; l1 < numClasses; l1++) {
        for (int l2 = l1 + 1; l2 < numClasses; l2++) {
            if (!matches[l1] && !matches[l2]) {
                score += this.weights.getWeightForIndex(pos);
            } else if (matches[l1] && !matches[l2]) {
                score += this.weights.getWeightForIndex(pos + 1);
            } else if (!matches[l1] && matches[l2]) {
                score += this.weights.getWeightForIndex(pos + 2);
            } else {
                score += this.weights.getWeightForIndex(pos + 3);
            }
            pos += 4;
        }
    }
    return score;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
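
The loop above walks the label pairs in nested order and advances `pos` by 4 per pair, so the weight for a given pair and pattern can also be located directly. The sketch below illustrates that layout under stated assumptions: `LabelPairScoreSketch`, `pairIndex`, and the flat `pairWeights` array (holding only the label-pair block, without the feature weights) are hypothetical names, not part of pyramid.

```java
public class LabelPairScoreSketch {

    /** Position of the pair (l1, l2), with l1 < l2, in the nested-loop enumeration order. */
    static int pairIndex(int l1, int l2, int numClasses) {
        // Pairs whose first label is smaller than l1 come first; there are
        // (numClasses - 1) + (numClasses - 2) + ... of them, which sums to the term below.
        return l1 * (2 * numClasses - l1 - 1) / 2 + (l2 - l1 - 1);
    }

    /**
     * Sums, over all label pairs, the one weight whose on/off pattern is active.
     * pairWeights holds 4 consecutive weights per pair:
     * neither, l1 only, l2 only, both.
     */
    static double labelPartScore(boolean[] matches, double[] pairWeights) {
        int numClasses = matches.length;
        double score = 0.0;
        for (int l1 = 0; l1 < numClasses; l1++) {
            for (int l2 = l1 + 1; l2 < numClasses; l2++) {
                int featureCase = (matches[l1] ? 1 : 0) + (matches[l2] ? 2 : 0);
                score += pairWeights[4 * pairIndex(l1, l2, numClasses) + featureCase];
            }
        }
        return score;
    }

    public static void main(String[] args) {
        // Three classes give three pairs, hence 12 label-pair weights.
        double[] pairWeights = new double[12];
        for (int i = 0; i < pairWeights.length; i++) {
            pairWeights[i] = i;
        }
        boolean[] matches = {true, false, true};
        System.out.println(labelPartScore(matches, pairWeights));
    }
}
```

With `numClasses` classes there are `numClasses * (numClasses - 1) / 2` pairs, so the label-pair block holds four times that many weights.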

Example 93 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class KLLoss method initTargMarginals.

private void initTargMarginals(int dataPoint) {
    double[] joint = targetDistribution[dataPoint];
    for (int c = 0; c < joint.length; c++) {
        MultiLabel multiLabel = supportedCombinations.get(c);
        double prob = joint[c];
        for (int l : multiLabel.getMatchedLabels()) {
            targetMarginals[dataPoint][l] += prob;
        }
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel)
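
The method above turns a joint distribution over label combinations into per-label marginals by adding each combination's probability to every label it contains. A self-contained sketch of the same idea follows; `MarginalsSketch`, `marginals`, and the `int[][]` encoding of the combinations are assumptions for illustration. Unlike the original, which accumulates into an existing `targetMarginals` array, this version starts from a fresh zeroed array.

```java
public class MarginalsSketch {

    /**
     * joint[c] is the probability of label combination c;
     * combinations[c] lists the labels that combination c contains.
     * Returns, for each label, the total probability of the combinations containing it.
     */
    static double[] marginals(double[] joint, int[][] combinations, int numLabels) {
        double[] result = new double[numLabels];
        for (int c = 0; c < joint.length; c++) {
            for (int label : combinations[c]) {
                result[label] += joint[c];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[] joint = {0.5, 0.3, 0.2};
        // Combination 0 = {0}, combination 1 = {0, 1}, combination 2 = {} (no labels).
        int[][] combinations = {{0}, {0, 1}, {}};
        double[] m = marginals(joint, combinations, 2);
        System.out.println("P(label 0) = " + m[0] + ", P(label 1) = " + m[1]);
    }
}
```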

Example 94 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CMLCRFTest method test2.

public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // Directory and file name used to load or save the model.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    MultiLabel[] predTrain;
    MultiLabel[] predTest;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        if (config.getBoolean("isLBFGS")) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            for (int i = 0; i < config.getInt("numRounds"); i++) {
                optimizer.iterate();
                predTrain = cmlcrf.predict(trainSet);
                predTest = cmlcrf.predict(testSet);
                System.out.print("iter: " + String.format("%04d", i));
                System.out.print("\tTrain acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
                System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
            }
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            for (int i = 0; i < config.getInt("numRounds"); i++) {
                optimizer.iterate();
                predTrain = cmlcrf.predict(trainSet);
                predTest = cmlcrf.predict(testSet);
                System.out.print("iter: " + String.format("%04d", i));
                System.out.print("\tTrain acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
                System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
                System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
            }
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    predTrain = cmlcrf.predict(trainSet);
    predTest = cmlcrf.predict(testSet);
    System.out.print("Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
    System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) GradientDescent(edu.neu.ccs.pyramid.optimization.GradientDescent) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
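
The test above reports two numbers per iteration, an accuracy and an overlap. The snippet does not show how `Accuracy.accuracy` and `Overlap.overlap` are defined, so the sketch below is only an assumption: it treats accuracy as the exact-match rate of whole label sets and overlap as the mean Jaccard index. The class `MultiLabelMetricsSketch` and its methods are hypothetical, and `Set<Integer>` stands in for `MultiLabel`.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class MultiLabelMetricsSketch {

    /** Fraction of data points whose predicted label set equals the true label set. */
    static double exactMatchAccuracy(List<Set<Integer>> truth, List<Set<Integer>> pred) {
        int correct = 0;
        for (int i = 0; i < truth.size(); i++) {
            if (truth.get(i).equals(pred.get(i))) {
                correct++;
            }
        }
        return (double) correct / truth.size();
    }

    /** Mean over data points of |intersection| / |union| of true and predicted sets. */
    static double meanJaccardOverlap(List<Set<Integer>> truth, List<Set<Integer>> pred) {
        double sum = 0.0;
        for (int i = 0; i < truth.size(); i++) {
            Set<Integer> union = new HashSet<>(truth.get(i));
            union.addAll(pred.get(i));
            Set<Integer> intersection = new HashSet<>(truth.get(i));
            intersection.retainAll(pred.get(i));
            // Convention chosen for this sketch: two empty label sets overlap fully.
            sum += union.isEmpty() ? 1.0 : (double) intersection.size() / union.size();
        }
        return sum / truth.size();
    }

    public static void main(String[] args) {
        List<Set<Integer>> truth = List.of(Set.of(0, 1), Set.of(2));
        List<Set<Integer>> pred = List.of(Set.of(0, 1), Set.of(1, 2));
        System.out.println("exact match: " + exactMatchAccuracy(truth, pred));
        System.out.println("overlap:     " + meanJaccardOverlap(truth, pred));
    }
}
```

Under these definitions overlap is never lower than exact-match accuracy, since an exact match contributes 1 to both and a partial match contributes only to overlap.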

Example 95 with MultiLabel

use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.

the class CMLCRFTest method test7.

private static void test7() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    // Directory and file name used to load or save the model.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf = null;
    if (config.getString("train.warmStart").equals("true")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else if (config.getString("train.warmStart").equals("auto")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("retrain model:");
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
    } else if (config.getString("train.warmStart").equals("false")) {
        cmlcrf = new CMLCRF(trainSet);
        cmlcrf.setConsiderPair(config.getBoolean("considerLabelPair"));
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
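    } else {
        // Fail fast: any other value would leave cmlcrf null and break the evaluation below.
        throw new IllegalArgumentException("train.warmStart must be true, auto, or false");
    }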
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MLMeasures measures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(measures);
    System.out.println("========== Test ==========\n");
    long startTimePred = System.nanoTime();
    MultiLabel[] preds = cmlcrf.predict(testSet);
    long stopTimePred = System.nanoTime();
    long predTime = stopTimePred - startTimePred;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(predTime) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");
    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 101
Vector (org.apache.mahout.math.Vector): 22
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 21
File (java.io.File): 14
DenseVector (org.apache.mahout.math.DenseVector): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 12
Pair (edu.neu.ccs.pyramid.util.Pair): 8
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
ArrayList (java.util.ArrayList): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 6
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 6
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 5
GeneralF1Predictor (edu.neu.ccs.pyramid.multilabel_classification.plugin_rule.GeneralF1Predictor): 5
Collectors (java.util.stream.Collectors): 5
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 4
java.util (java.util): 4
StopWatch (org.apache.commons.lang3.time.StopWatch): 4
Config (edu.neu.ccs.pyramid.configuration.Config): 3
DataSetUtil (edu.neu.ccs.pyramid.dataset.DataSetUtil): 3
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 3