Search in sources :

Example 1 with CRFLoss

use of edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss in project pyramid by cheng-li.

the class CMLCRFTest method test9.

/**
 * Fits a {@code CMLCRF} to the synthetic "independent noise" data with {@code LBFGS},
 * starting from hand-picked weights. Prints the model, the initial loss and the
 * train/test measures up front, then the terminator's last value and the measures
 * after every optimizer iteration, and finally the trained model.
 */
private static void test9() {
    MultiLabelClfDataSet train = MultiLabelSynthesizer.independentNoise();
    MultiLabelClfDataSet test = MultiLabelSynthesizer.independent();
    CMLCRF cmlcrf = new CMLCRF(train);

    // Row k lists the values written to positions 0 and 1 of class k's non-bias weights.
    double[][] startingWeights = { { 0, 1 }, { 1, 1 }, { 1, 0 }, { 1, -1 } };
    for (int classIndex = 0; classIndex < startingWeights.length; classIndex++) {
        for (int position = 0; position < startingWeights[classIndex].length; position++) {
            cmlcrf.getWeights()
                    .getWeightsWithoutBiasForClass(classIndex)
                    .set(position, startingWeights[classIndex][position]);
        }
    }

    CRFLoss crfLoss = new CRFLoss(cmlcrf, train, 1);
    System.out.println(cmlcrf);
    System.out.println("initial loss = " + crfLoss.getValue());
    printTrainAndTestMeasures(cmlcrf, train, test);

    LBFGS optimizer = new LBFGS(crfLoss);
    // Keep stepping until the optimizer's own terminator says to stop.
    while (!optimizer.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        optimizer.iterate();
        System.out.println(optimizer.getTerminator().getLastValue());
        printTrainAndTestMeasures(cmlcrf, train, test);
    }
    System.out.println(cmlcrf);
}

/** Prints the {@code MLMeasures} of the model on the training set, then on the test set. */
private static void printTrainAndTestMeasures(CMLCRF cmlcrf, MultiLabelClfDataSet train, MultiLabelClfDataSet test) {
    System.out.println("training performance");
    System.out.println(new MLMeasures(cmlcrf, train));
    System.out.println("test performance");
    System.out.println(new MLMeasures(cmlcrf, test));
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 2 with CRFLoss

use of edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss in project pyramid by cheng-li.

the class CMLCRFTest method test6.

/**
 * Runs 50 {@code LBFGS} iterations of {@code CMLCRF} training on the "medical" data set.
 * After each iteration it prints the value of the loss object followed by accuracy and
 * overlap on both the training and the test set.
 *
 * @throws Exception propagated from data loading, prediction or evaluation
 */
private static void test6() throws Exception {
    File trainLocation = new File(DATASETS, "medical/train");
    File testLocation = new File(DATASETS, "medical/test");
    MultiLabelClfDataSet trainData = TRECFormat.loadMultiLabelClfDataSet(trainLocation, DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testData = TRECFormat.loadMultiLabelClfDataSet(testLocation, DataSetType.ML_CLF_SPARSE, true);

    CMLCRF model = new CMLCRF(trainData);
    CRFLoss loss = new CRFLoss(model, trainData, 1);
    LBFGS lbfgs = new LBFGS(loss);

    int round = 0;
    while (round < 50) {
        System.out.println("iter: " + round);
        lbfgs.iterate();
        System.out.println(loss.getValue());

        // Re-predict with the freshly updated model before scoring it.
        MultiLabel[] trainPredictions = model.predict(trainData);
        MultiLabel[] testPredictions = model.predict(testData);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(trainData.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain overlap " + Overlap.overlap(trainData.getMultiLabels(), trainPredictions));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testData.getMultiLabels(), testPredictions));
        System.out.println("\tTest overlap " + Overlap.overlap(testData.getMultiLabels(), testPredictions));
        round++;
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 3 with CRFLoss

use of edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss in project pyramid by cheng-li.

the class CMLCRFTest method test1.

/**
 * Trains a {@code CMLCRF} on the "spam" data set for 5000 {@code LBFGS} iterations.
 * {@code setConsiderPair(true)} is applied to the model after the loss object has been
 * created. Each iteration prints the loss value and the train/test accuracy and overlap.
 *
 * <p>An alternative that was previously sketched here: set the terminator's absolute
 * epsilon to 0.01, call {@code optimizer.optimize()} once, and report the same four
 * numbers a single time at the end.
 *
 * @throws Exception propagated from data loading, prediction or evaluation
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "spam/trec_data/train.trec");
    File testFile = new File(DATASETS, "spam/trec_data/test.trec");
    MultiLabelClfDataSet spamTrain = TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet spamTest = TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_SPARSE, true);

    CMLCRF crf = new CMLCRF(spamTrain);
    CRFLoss objective = new CRFLoss(crf, spamTrain, 1);
    crf.setConsiderPair(true);
    LBFGS lbfgs = new LBFGS(objective);

    final int totalIterations = 5000;
    for (int iteration = 0; iteration < totalIterations; iteration++) {
        System.out.println("iter: " + iteration);
        lbfgs.iterate();
        System.out.println(objective.getValue());

        // Score the current weights on both splits.
        MultiLabel[] onTrain = crf.predict(spamTrain);
        MultiLabel[] onTest = crf.predict(spamTest);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(spamTrain.getMultiLabels(), onTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(spamTrain.getMultiLabels(), onTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(spamTest.getMultiLabels(), onTest));
        System.out.println("\tTest overlap " + Overlap.overlap(spamTest.getMultiLabels(), onTest));
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 4 with CRFLoss

use of edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss in project pyramid by cheng-li.

the class CMLCRFTest method test4.

/**
 * Trains a {@code CMLCRF} on split 1 of the "20newsgroup" data set for 50 {@code LBFGS}
 * iterations, printing the loss value and the train/test accuracy and overlap after
 * every iteration.
 *
 * @throws Exception propagated from data loading, prediction or evaluation
 */
private static void test4() throws Exception {
    MultiLabelClfDataSet newsTrain = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "20newsgroup/1/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet newsTest = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "20newsgroup/1/test.trec"), DataSetType.ML_CLF_SPARSE, true);

    CMLCRF classifier = new CMLCRF(newsTrain);
    CRFLoss lossFunction = new CRFLoss(classifier, newsTrain, 1);
    LBFGS solver = new LBFGS(lossFunction);

    final int maxIterations = 50;
    int step = 0;
    for (int remaining = maxIterations; remaining > 0; remaining--) {
        System.out.println("iter: " + step);
        solver.iterate();
        System.out.println(lossFunction.getValue());

        // Predictions are recomputed every round because the weights just changed.
        MultiLabel[] trainOutput = classifier.predict(newsTrain);
        MultiLabel[] testOutput = classifier.predict(newsTest);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(newsTrain.getMultiLabels(), trainOutput));
        System.out.print("\tTrain overlap " + Overlap.overlap(newsTrain.getMultiLabels(), trainOutput));
        System.out.print("\tTest acc: " + Accuracy.accuracy(newsTest.getMultiLabels(), testOutput));
        System.out.println("\tTest overlap " + Overlap.overlap(newsTest.getMultiLabels(), testOutput));
        step++;
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 5 with CRFLoss

use of edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss in project pyramid by cheng-li.

the class CMLCRFTest method test2.

/**
 * Config-driven train/evaluate run for {@code CMLCRF}.
 *
 * <p>Reads the train and test sets named by {@code input.trainData} and
 * {@code input.testData}. When {@code train.warmStart} is set, the model is deserialized
 * from {@code output/modelName} and no optimization is performed. Otherwise a fresh model
 * is optimized for {@code numRounds} iterations, with {@code LBFGS} when {@code isLBFGS}
 * is set and {@code GradientDescent} otherwise, and the train/test accuracy and overlap
 * are printed after every iteration. Final results are always printed, and the model is
 * serialized to {@code output/modelName} when {@code saveModel} is set.
 *
 * @throws Exception propagated from data loading, (de)serialization, prediction or evaluation
 */
public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // Location used both for loading (warm start) and for saving the model.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        boolean useLbfgs = config.getBoolean("isLBFGS");
        // Loop-invariant: read the round count once instead of on every iteration.
        int numRounds = config.getInt("numRounds");
        if (useLbfgs) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                printAccuracyAndOverlap("iter: " + String.format("%04d", i) + "\t", cmlcrf, trainSet, testSet);
            }
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                printAccuracyAndOverlap("iter: " + String.format("%04d", i) + "\t", cmlcrf, trainSet, testSet);
            }
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    printAccuracyAndOverlap("", cmlcrf, trainSet, testSet);
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}

/**
 * Predicts on both data sets and prints one line with the train/test accuracy and overlap,
 * each formatted with four decimals.
 *
 * @param prefix text emitted at the start of the line (e.g. the iteration label); may be empty
 * @param cmlcrf model used for prediction
 * @param trainSet training data and its ground-truth labels
 * @param testSet test data and its ground-truth labels
 * @throws Exception declared so that anything the prediction/evaluation calls may throw
 *         still reaches the caller unchanged
 */
private static void printAccuracyAndOverlap(String prefix, CMLCRF cmlcrf, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) throws Exception {
    MultiLabel[] predTrain = cmlcrf.predict(trainSet);
    MultiLabel[] predTest = cmlcrf.predict(testSet);
    System.out.print(prefix + "Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
    System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) GradientDescent(edu.neu.ccs.pyramid.optimization.GradientDescent) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) 7, CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) 7, CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) 7, LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS) 7, MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel) 6, File (java.io.File) 6, MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures) 1, GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent) 1