Search in sources:

Example 16 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.

The class MLLogisticTrainerTest, method test1.

/**
 * Trains an {@code MLLogisticRegression} on the binary spam data set, re-encoded as a 2-class
 * multi-label data set, using 100 L-BFGS iterations.
 *
 * <p>Prints the loss value before every iteration, then the final accuracy and overlap on the
 * training data.
 */
private static void test1() throws Exception {
    ClfDataSet spamData = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int rows = spamData.getNumDataPoints();
    int cols = spamData.getNumFeatures();
    int[] spamLabels = spamData.getLabels();

    // Copy every point into a 2-class multi-label data set: each point carries exactly its
    // original single label, and all feature values are copied one by one.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rows)
            .numFeatures(cols)
            .numClasses(2)
            .build();
    for (int row = 0; row < rows; row++) {
        dataSet.addLabel(row, spamLabels[row]);
        for (int col = 0; col < cols; col++) {
            dataSet.setFeatureValue(row, col, spamData.getRow(row).get(col));
        }
    }

    // The label sets the model is allowed to predict: {0} and {1}.
    List<MultiLabel> assignments = new ArrayList<>();
    assignments.add(new MultiLabel().addLabel(0));
    assignments.add(new MultiLabel().addLabel(1));

    MLLogisticRegression model = new MLLogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures(), assignments);
    // NOTE(review): the meaning of 10000 (likely a regularization/prior parameter) is not visible here.
    MLLogisticLoss loss = new MLLogisticLoss(model, dataSet, 10000);
    LBFGS lbfgs = new LBFGS(loss);
    lbfgs.getTerminator().setRelativeEpsilon(0.01);
    lbfgs.setHistory(5);

    // Exactly 100 iterations; the loop itself does not consult the terminator.
    int iteration = 0;
    while (iteration < 100) {
        System.out.println(loss.getValue());
        lbfgs.iterate();
        iteration++;
    }
    System.out.println(Accuracy.accuracy(model, dataSet));
    System.out.println(Overlap.overlap(model, dataSet));
}
Also used : LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) ArrayList(java.util.ArrayList) File(java.io.File)

Example 17 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.

The class MLLogisticTrainerTest, method test3.

/**
 * Adds a fake third label to the spam data set and trains an {@code MLLogisticRegression} on it.
 *
 * <p>Every point keeps its original label. A point whose label is 1 (presumably "spam" — confirm
 * against the data set) and whose feature 0 is below 0.1 additionally gets label 2. The model may
 * predict {0}, {1} or {1, 2}. Prints the training accuracy before each of 1000 L-BFGS iterations,
 * then the final accuracy and overlap.
 *
 * @throws Exception if loading the data set fails
 */
private static void test3() throws Exception {
    ClfDataSet spamData = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int rows = spamData.getNumDataPoints();
    int cols = spamData.getNumFeatures();
    int[] spamLabels = spamData.getLabels();

    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rows)
            .numFeatures(cols)
            .numClasses(3)
            .build();
    for (int row = 0; row < rows; row++) {
        dataSet.addLabel(row, spamLabels[row]);
        // Short-circuit: feature 0 is only read for label-1 points.
        boolean addFakeLabel = spamLabels[row] == 1 && spamData.getRow(row).get(0) < 0.1;
        if (addFakeLabel) {
            dataSet.addLabel(row, 2);
        }
        for (int col = 0; col < cols; col++) {
            dataSet.setFeatureValue(row, col, spamData.getRow(row).get(col));
        }
    }

    // The label sets the model is allowed to predict: {0}, {1} and {1, 2}.
    List<MultiLabel> assignments = new ArrayList<>();
    assignments.add(new MultiLabel().addLabel(0));
    assignments.add(new MultiLabel().addLabel(1));
    assignments.add(new MultiLabel().addLabel(1).addLabel(2));

    MLLogisticRegression model = new MLLogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures(), assignments);
    // NOTE(review): the meaning of 10000 (likely a regularization/prior parameter) is not visible here.
    MLLogisticLoss loss = new MLLogisticLoss(model, dataSet, 10000);
    LBFGS lbfgs = new LBFGS(loss);
    lbfgs.getTerminator().setRelativeEpsilon(0.01);
    lbfgs.setHistory(5);

    // Exactly 1000 iterations; training accuracy is reported before each step.
    int iteration = 0;
    while (iteration < 1000) {
        System.out.println(Accuracy.accuracy(model, dataSet));
        lbfgs.iterate();
        iteration++;
    }
    System.out.println(Accuracy.accuracy(model, dataSet));
    System.out.println(Overlap.overlap(model, dataSet));
}
Also used : LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) ArrayList(java.util.ArrayList) File(java.io.File)

Example 18 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.

The class CMLCRFTest, method test2.

/**
 * Trains (or warm-starts) a {@code CMLCRF} from the settings in {@code config}. It reports
 * train/test accuracy and overlap during and after training, and optionally saves the model.
 *
 * <p>Config keys read: {@code input.trainData}, {@code input.testData}, {@code gaussianVariance},
 * {@code output}, {@code modelName}, {@code train.warmStart}, {@code isLBFGS}, {@code numRounds}
 * and {@code saveModel}.
 *
 * @throws Exception if loading data, deserializing or serializing the model fails
 */
public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // Directory and file name used both to load a warm-start model and to save the result.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        if (config.getBoolean("isLBFGS")) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            // Read once instead of on every loop test.
            int numRounds = config.getInt("numRounds");
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                printFormattedMetrics("iter: " + String.format("%04d", i) + "\t", cmlcrf, trainSet, testSet);
            }
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            // Read once instead of on every loop test.
            int numRounds = config.getInt("numRounds");
            for (int i = 0; i < numRounds; i++) {
                optimizer.iterate();
                printFormattedMetrics("iter: " + String.format("%04d", i) + "\t", cmlcrf, trainSet, testSet);
            }
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    printFormattedMetrics("", cmlcrf, trainSet, testSet);
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}

/**
 * Predicts {@code trainSet} and {@code testSet} with {@code cmlcrf}. Prints one line containing
 * {@code prefix}, then train accuracy, train overlap, test accuracy and test overlap, each
 * formatted with {@code %.4f}.
 */
private static void printFormattedMetrics(String prefix, CMLCRF cmlcrf, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) {
    MultiLabel[] predTrain = cmlcrf.predict(trainSet);
    MultiLabel[] predTest = cmlcrf.predict(testSet);
    System.out.print(prefix + "Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
    System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) GradientDescent(edu.neu.ccs.pyramid.optimization.GradientDescent) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 19 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.

The class CMLCRFTest, method test5.

/**
 * Two-phase {@code CMLCRF} training on ohsumed/3 with L-BFGS.
 *
 * <p>Phase 1 runs 5 iterations with {@code setConsiderPair(false)}. Phase 2 builds a fresh loss
 * and optimizer with {@code setConsiderPair(true)} and runs 50 iterations. After every iteration
 * it prints the loss value and the train/test accuracy and overlap.
 */
private static void test5() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(trainSet);

    // Phase 1: pair handling switched off. Order kept as loss -> flag -> optimizer.
    CRFLoss firstLoss = new CRFLoss(cmlcrf, trainSet, 1);
    cmlcrf.setConsiderPair(false);
    LBFGS firstOptimizer = new LBFGS(firstLoss);
    for (int iter = 0; iter < 5; iter++) {
        System.out.println("iter: " + iter);
        firstOptimizer.iterate();
        System.out.println(firstLoss.getValue());
        printRawMetrics(cmlcrf, trainSet, testSet);
    }

    // Phase 2: continue from the same model with pair handling switched on.
    CRFLoss secondLoss = new CRFLoss(cmlcrf, trainSet, 1);
    cmlcrf.setConsiderPair(true);
    LBFGS secondOptimizer = new LBFGS(secondLoss);
    for (int iter = 0; iter < 50; iter++) {
        System.out.println("consider pairs");
        System.out.println("iter: " + iter);
        secondOptimizer.iterate();
        System.out.println(secondLoss.getValue());
        printRawMetrics(cmlcrf, trainSet, testSet);
    }
}

/**
 * Predicts both data sets with {@code cmlcrf}. Prints unformatted train/test accuracy and overlap
 * on one tab-separated line.
 */
private static void printRawMetrics(CMLCRF cmlcrf, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) {
    MultiLabel[] trainPredictions = cmlcrf.predict(trainSet);
    MultiLabel[] testPredictions = cmlcrf.predict(testSet);
    System.out.println("\tTrain acc: " + Accuracy.accuracy(trainSet.getMultiLabels(), trainPredictions)
            + "\tTrain overlap " + Overlap.overlap(trainSet.getMultiLabels(), trainPredictions)
            + "\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), testPredictions)
            + "\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), testPredictions));
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 20 with LBFGS

Use of edu.neu.ccs.pyramid.optimization.LBFGS in the project pyramid by cheng-li.

The class CMLCRFTest, method test3.

/**
 * Trains a {@code CMLCRF} on imdb/3 with 50 L-BFGS iterations.
 *
 * <p>After every iteration it prints the loss value and the unformatted train/test accuracy and
 * overlap.
 */
private static void test3() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "/imdb/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "/imdb/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF model = new CMLCRF(trainSet);
    CRFLoss loss = new CRFLoss(model, trainSet, 1);
    LBFGS lbfgs = new LBFGS(loss);
    for (int iter = 0; iter < 50; iter++) {
        System.out.println("iter: " + iter);
        lbfgs.iterate();
        System.out.println(loss.getValue());
        MultiLabel[] trainPredictions = model.predict(trainSet);
        MultiLabel[] testPredictions = model.predict(testSet);
        // One concatenated line; same text as printing each segment separately.
        System.out.println("\tTrain acc: " + Accuracy.accuracy(trainSet.getMultiLabels(), trainPredictions)
                + "\tTrain overlap " + Overlap.overlap(trainSet.getMultiLabels(), trainPredictions)
                + "\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), testPredictions)
                + "\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), testPredictions));
    }
}
Also used : CMLCRF(edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) CRFLoss(edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss) File(java.io.File) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 22, File (java.io.File): 12, MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 9, MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 7, CMLCRF (edu.neu.ccs.pyramid.multilabel_classification.crf.CMLCRF): 7, CRFLoss (edu.neu.ccs.pyramid.multilabel_classification.crf.CRFLoss): 7, GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent): 7, Optimizer (edu.neu.ccs.pyramid.optimization.Optimizer): 6, RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer): 2, MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 2, ArrayList (java.util.ArrayList): 2, ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1, LoggerContext (org.apache.logging.log4j.core.LoggerContext): 1, Configuration (org.apache.logging.log4j.core.config.Configuration): 1, LoggerConfig (org.apache.logging.log4j.core.config.LoggerConfig): 1