Search in sources:

Example 26 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class RidgeLogisticOptimizerTest, method test3.

/**
 * Trains a logistic regression on the 20newsgroup TREC split by running a fixed number of
 * LBFGS iterations over a {@link LogisticLoss}. After every iteration it prints the loss and
 * the train/test accuracy, and at the end it prints the total optimization time.
 *
 * <p>Other TREC datasets (e.g. imdb or spam) can be used by changing the two paths below.
 *
 * @throws Exception if the TREC data cannot be loaded
 */
private static void test3() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "20newsgroup/1/train.trec"), DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(new File(DATASETS, "20newsgroup/1/test.trec"), DataSetType.CLF_SPARSE, true);
    // NOTE(review): presumably the Gaussian prior variance of the loss (cf.
    // RidgeLogisticTrainer.setGaussianPriorVariance) — confirm against LogisticLoss.
    double variance = 1000;
    int numIterations = 200;
    LogisticRegression logisticRegression = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());
    Optimizable.ByGradientValue loss = new LogisticLoss(logisticRegression, dataSet, variance, true);
    LBFGS optimizer = new LBFGS(loss);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));
    // Times the whole optimization loop; the original started this watch but never reported it.
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int i = 0; i < numIterations; i++) {
        optimizer.iterate();
        System.out.println("after iteration " + i);
        System.out.println("loss = " + loss.getValue());
        System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
        System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));
    }
    stopWatch.stop();
    System.out.println("time = " + stopWatch);
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 27 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class RidgeLogisticOptimizerTest, method test2.

/**
 * Trains a logistic regression on the spam TREC split with {@link RidgeLogisticOptimizer}.
 * Every data point gets the same weight (gamma = 1.0), and each target is the one-hot
 * distribution of the point's hard label. The method prints train/test accuracy before and
 * after optimization, followed by the terminator history.
 *
 * @throws Exception if the TREC data cannot be loaded
 */
private static void test2() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/test.trec"), DataSetType.CLF_SPARSE, true);
    int numDataPoints = dataSet.getNumDataPoints();
    int numClasses = dataSet.getNumClasses();
    LogisticRegression logisticRegression = new LogisticRegression(numClasses, dataSet.getNumFeatures());
    // generate equal weights
    double[] gammas = new double[numDataPoints];
    java.util.Arrays.fill(gammas, 1.0);
    // Generate one-hot target distributions. The width follows the number of classes rather
    // than a hard-coded 2; for binary 0/1 labels this matches the previous if/else mapping.
    int[] labels = dataSet.getLabels();
    double[][] targets = new double[numDataPoints][numClasses];
    for (int n = 0; n < numDataPoints; n++) {
        targets[n][labels[n]] = 1;
    }
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(logisticRegression, dataSet, gammas, targets, 500, true);
    // Allow up to 10000 iterations and use the STANDARD termination mode.
    optimizer.getOptimizer().getTerminator().setMaxIteration(10000).setMode(Terminator.Mode.STANDARD);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));
    optimizer.optimize();
    System.out.println("after training");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));
    System.out.println(optimizer.getOptimizer().getTerminator().getHistory());
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File)

Example 28 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class RidgeLogisticTrainerTest, method test1.

/**
 * Loads the imdb TREC split and prints its meta info. It then trains a model with a
 * {@link RidgeLogisticTrainer} (epsilon 1, Gaussian prior variance 0.5, history 5) and
 * reports the train and test accuracy.
 *
 * @throws Exception if the TREC data cannot be loaded
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "/imdb/3/train.trec");
    File testFile = new File(DATASETS, "/imdb/3/test.trec");
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(dataSet.getMetaInfo());

    RidgeLogisticTrainer trainer = RidgeLogisticTrainer.getBuilder()
            .setEpsilon(1)
            .setGaussianPriorVariance(0.5)
            .setHistory(5)
            .build();
    LogisticRegression model = trainer.train(dataSet);

    System.out.println("train: " + Accuracy.accuracy(model, dataSet));
    System.out.println("test: " + Accuracy.accuracy(model, testSet));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File)

Example 29 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class RidgeLogisticTrainerTest, method test3.

/**
 * Loads the 20newsgroup TREC split and trains a model with a {@link RidgeLogisticTrainer}
 * (epsilon 0.01, Gaussian prior variance 0.5, history 5). It prints the time spent in
 * training, then the train and test accuracy.
 *
 * @throws Exception if the TREC data cannot be loaded
 */
private static void test3() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(dataSet.getMetaInfo());

    RidgeLogisticTrainer trainer = RidgeLogisticTrainer.getBuilder()
            .setEpsilon(0.01)
            .setGaussianPriorVariance(0.5)
            .setHistory(5)
            .build();

    // Only the call to train(...) falls inside the timed region.
    StopWatch timer = new StopWatch();
    timer.start();
    LogisticRegression model = trainer.train(dataSet);
    System.out.println(timer);

    System.out.println("train: " + Accuracy.accuracy(model, dataSet));
    System.out.println("test: " + Accuracy.accuracy(model, testSet));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 30 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class ElasticNetLogisticTrainerTest, method test2.

/**
 * Loads the spam TREC split and fits a freshly constructed {@link LogisticRegression} in
 * place with an {@link ElasticNetLogisticTrainer} (epsilon 0.01, L1 ratio 0.5,
 * regularization 0.0001). It then prints the training and test accuracy.
 *
 * @throws Exception if the TREC data cannot be loaded
 */
private static void test2() throws Exception {
    File trainFile = new File(DATASETS, "/spam/trec_data/train.trec");
    File testFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(dataSet.getMetaInfo());

    LogisticRegression model = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());
    ElasticNetLogisticTrainer trainer = ElasticNetLogisticTrainer.newBuilder(model, dataSet)
            .setEpsilon(0.01)
            .setL1Ratio(0.5)
            .setRegularization(0.0001)
            .build();
    // The trainer updates `model` directly; there is no separate return value.
    trainer.optimize();

    System.out.println("training accuracy = " + Accuracy.accuracy(model, dataSet));
    System.out.println("test accuracy = " + Accuracy.accuracy(model, testSet));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File)

Aggregations

ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)35 File (java.io.File)31 StopWatch (org.apache.commons.lang3.time.StopWatch)8 LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)7 RidgeLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer)6 LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost)2 Config (edu.neu.ccs.pyramid.configuration.Config)2 DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType)2 ConjugateGradientDescent (edu.neu.ccs.pyramid.optimization.ConjugateGradientDescent)2 RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)2 RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)2 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)1 LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)1 LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer)1 ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer)1 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)1 GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent)1 LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS)1 BufferedWriter (java.io.BufferedWriter)1 FileWriter (java.io.FileWriter)1