Search in sources:

Example 11 with CMLCRF

Use of edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF in project pyramid by cheng-li.

Source: class CMLCRFTest, method test5.

/**
 * Two-phase training run on the ohsumed/3 split.
 *
 * <p>Phase 1 trains the CRF with label-pair features disabled for 5 L-BFGS iterations.
 * Phase 2 enables label-pair features on the same model instance and continues training it
 * for 50 iterations with a fresh loss and optimizer. After every iteration the loss value
 * and the train/test accuracy and overlap are printed.
 *
 * @throws Exception if either data set cannot be loaded
 */
private static void test5() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);

    // Phase 1: no label-pair features. Configure the model BEFORE building the loss that
    // wraps it, so the loss can never observe a pair setting other than the one it optimizes.
    cmlcrf.setConsiderPair(false);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    LBFGS optimizer = new LBFGS(crfLoss);
    for (int i = 0; i < 5; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(crfLoss.getValue());
        MultiLabel[] predTrain = cmlcrf.predict(dataSet);
        MultiLabel[] predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }

    // Phase 2: enable label-pair features on the same model, then build a new loss/optimizer.
    // As above, the flag is set before the loss is constructed.
    cmlcrf.setConsiderPair(true);
    CRFLoss crfLoss2 = new CRFLoss(cmlcrf, dataSet, 1);
    LBFGS optimizer2 = new LBFGS(crfLoss2);
    for (int i = 0; i < 50; i++) {
        System.out.println("consider pairs");
        System.out.println("iter: " + i);
        optimizer2.iterate();
        System.out.println(crfLoss2.getValue());
        MultiLabel[] predTrain = cmlcrf.predict(dataSet);
        MultiLabel[] predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 12 with CMLCRF

Use of edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF in project pyramid by cheng-li.

Source: class CMLCRFTest, method test3.

/**
 * Trains a CRF on the imdb/3 split for 50 L-BFGS iterations.
 *
 * <p>The model's label-pair setting is left at whatever {@code CMLCRF} uses by default. After
 * every iteration the loss value and the train/test accuracy and overlap are printed.
 *
 * @throws Exception if either data set cannot be loaded
 */
private static void test3() throws Exception {
    // Child paths are relative to DATASETS, matching the other tests; a leading "/" would only
    // be tolerated via File's system-dependent absolute-to-relative conversion.
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "imdb/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "imdb/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    LBFGS optimizer = new LBFGS(crfLoss);
    for (int i = 0; i < 50; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(crfLoss.getValue());
        MultiLabel[] predTrain = cmlcrf.predict(dataSet);
        MultiLabel[] predTest = cmlcrf.predict(testSet);
        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), predTrain));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), predTest));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), predTest));
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 13 with CMLCRF

Use of edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF in project pyramid by cheng-li.

Source: class CMLCRFTest, method test7.

/**
 * Config-driven experiment: obtains a {@code CMLCRF} according to {@code train.warmStart},
 * prints train/test measures (plus an {@code InstanceF1Predictor} on the test set) and the
 * test-set prediction time, and optionally serializes the model.
 *
 * <p>Accepted {@code train.warmStart} values:
 * <ul>
 *   <li>{@code "true"}: load the model from {@code output/modelName} without training;</li>
 *   <li>{@code "auto"}: load that model and continue training it with {@code CMLCRFElasticNet};</li>
 *   <li>{@code "false"}: build a fresh model (label pairs per {@code considerLabelPair}) and train it.</li>
 * </ul>
 *
 * @throws IllegalArgumentException if {@code train.warmStart} is not one of the values above
 * @throws IllegalStateException if {@code saveModel} is set and the output directory cannot be created
 * @throws Exception if loading data, (de)serializing the model, or training fails
 */
private static void test7() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    // Model location, used both for warm-start loading and for saving.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    String warmStart = config.getString("train.warmStart");
    CMLCRF cmlcrf;
    switch (warmStart) {
        case "true": {
            cmlcrf = CMLCRF.deserialize(new File(output, modelName));
            System.out.println("loading model:");
            System.out.println(cmlcrf);
            break;
        }
        case "auto": {
            cmlcrf = CMLCRF.deserialize(new File(output, modelName));
            System.out.println("retrain model:");
            CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
            train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
            break;
        }
        case "false": {
            cmlcrf = new CMLCRF(trainSet);
            cmlcrf.setConsiderPair(config.getBoolean("considerLabelPair"));
            CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
            train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
            break;
        }
        default:
            // Fail fast: an unrecognized value used to leave the model null and surface later
            // as an NPE far from the real cause.
            throw new IllegalArgumentException("train.warmStart must be one of \"true\", \"auto\", \"false\" but was: \"" + warmStart + "\"");
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MLMeasures measures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(measures);
    System.out.println("========== Test ==========\n");
    // The predictions themselves are not used; this call exists only to time test-set prediction.
    long startTimePred = System.nanoTime();
    MultiLabel[] preds = cmlcrf.predict(testSet);
    long stopTimePred = System.nanoTime();
    long predTime = stopTimePred - startTimePred;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(predTime) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");
    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));
    if (config.getBoolean("saveModel")) {
        File outputDir = new File(output);
        // mkdirs() returns false both on failure and when the directory already exists.
        if (!outputDir.isDirectory() && !outputDir.mkdirs()) {
            throw new IllegalStateException("could not create output directory: " + outputDir.getAbsolutePath());
        }
        File serializeModel = new File(outputDir, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 13
CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 13
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 12
File (java.io.File): 8
CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7
LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 7
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 3
SubsetAccPredictor (edu.neu.ccs.pyramid.multilabel_classification.crf.SubsetAccPredictor): 3
SamplingPredictor (edu.neu.ccs.pyramid.multilabel_classification.crf.SamplingPredictor): 1
GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent): 1