Usage of `edu.neu.ccs.pyramid.dataset.ClfDataSet` in the pyramid project by cheng-li.
Example: class `RidgeLogisticTrainerTest`, method `test2`.
/**
 * Fits a logistic regression with RidgeLogisticTrainer on the spam TREC
 * train split, then prints the dataset meta info and the train/test accuracy.
 *
 * @throws Exception propagated from dataset loading or training
 */
private static void test2() throws Exception {
    File trainFile = new File(DATASETS, "/spam/trec_data/train.trec");
    File testFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    // NOTE(review): the meaning of setHistory(5) is not visible here;
    // presumably an optimizer memory size — confirm in RidgeLogisticTrainer.
    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setEpsilon(0.001)
            .setGaussianPriorVariance(10000)
            .setHistory(5)
            .build();
    LogisticRegression model = ridgeTrainer.train(trainSet);

    System.out.println("train: " + Accuracy.accuracy(model, trainSet));
    System.out.println("test: " + Accuracy.accuracy(model, heldOutSet));
}
Usage of `edu.neu.ccs.pyramid.dataset.ClfDataSet` in the pyramid project by cheng-li.
Example: class `ElasticNetLogisticTrainerTest`, method `test6`.
/**
 * Optimizes an elastic-net logistic regression (L1 ratio 0) on the
 * amazon_book_genre/3 TREC split, timing the optimization. Afterwards it
 * prints the elapsed time, the train/test accuracy and the number of
 * non-zero weights.
 *
 * @throws Exception propagated from dataset loading or training
 */
private static void test6() throws Exception {
    File trainFile = new File(DATASETS, "/amazon_book_genre/3/train.trec");
    File testFile = new File(DATASETS, "/amazon_book_genre/3/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    int numClasses = trainSet.getNumClasses();
    int numFeatures = trainSet.getNumFeatures();
    LogisticRegression model = new LogisticRegression(numClasses, numFeatures);
    // NOTE(review): 0.10000000000000006 looks like a floating-point artifact
    // of a grid value near 0.1; kept verbatim so behavior is unchanged.
    ElasticNetLogisticTrainer elasticNet = ElasticNetLogisticTrainer.newBuilder(model, trainSet)
            .setEpsilon(0.01)
            .setL1Ratio(0)
            .setRegularization(0.10000000000000006)
            .build();

    StopWatch timer = new StopWatch();
    timer.start();
    // The trainer was built around `model`; the same object is evaluated below.
    elasticNet.optimize();
    System.out.println(timer);

    System.out.println("training accuracy = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test accuracy = " + Accuracy.accuracy(model, heldOutSet));
    System.out.println("number of non-zeros= " + model.getWeights().getAllWeights().getNumNonZeroElements());
}
Usage of `edu.neu.ccs.pyramid.dataset.ClfDataSet` in the pyramid project by cheng-li.
Example: class `ElasticNetLogisticTrainerTest`, method `test1`.
/**
 * Runs 100 single iterations of an elastic-net logistic regression trainer
 * (L1 ratio 0.5, line search enabled) on the spam TREC split. After each
 * iteration it prints the iteration index and the train/test accuracy.
 *
 * @throws Exception propagated from dataset loading or training
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "/spam/trec_data/train.trec");
    File testFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    LogisticRegression model = new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());
    ElasticNetLogisticTrainer elasticNet = ElasticNetLogisticTrainer.newBuilder(model, trainSet)
            .setEpsilon(0.01)
            .setL1Ratio(0.5)
            .setRegularization(0.0001)
            .setLineSearch(true)
            .build();

    // Step the trainer manually so accuracy can be reported after every iteration.
    int iteration = 0;
    while (iteration < 100) {
        System.out.println("iteration " + iteration);
        elasticNet.iterate();
        System.out.println("training accuracy = " + Accuracy.accuracy(model, trainSet));
        System.out.println("test accuracy = " + Accuracy.accuracy(model, heldOutSet));
        iteration++;
    }
}
Usage of `edu.neu.ccs.pyramid.dataset.ClfDataSet` in the pyramid project by cheng-li.
Example: class `ElasticNetLogisticTrainerTest`, method `test4`.
/**
 * Optimizes an elastic-net logistic regression (L1 ratio 1) on the cnn/4
 * TREC split, timing the optimization. Afterwards it prints the elapsed
 * time, the train/test accuracy and the number of non-zero weights.
 *
 * @throws Exception propagated from dataset loading or training
 */
private static void test4() throws Exception {
    File trainFile = new File(DATASETS, "/cnn/4/train.trec");
    File testFile = new File(DATASETS, "/cnn/4/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    int numClasses = trainSet.getNumClasses();
    int numFeatures = trainSet.getNumFeatures();
    LogisticRegression model = new LogisticRegression(numClasses, numFeatures);
    // NOTE(review): this regularization value looks like a point picked from a
    // log-uniform grid (compare test5); kept verbatim.
    ElasticNetLogisticTrainer elasticNet = ElasticNetLogisticTrainer.newBuilder(model, trainSet)
            .setEpsilon(0.01)
            .setL1Ratio(1)
            .setRegularization(2.4201282647943795E-4)
            .build();

    StopWatch timer = new StopWatch();
    timer.start();
    elasticNet.optimize();
    System.out.println(timer);

    System.out.println("training accuracy = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test accuracy = " + Accuracy.accuracy(model, heldOutSet));
    System.out.println("number of non-zeros= " + model.getWeights().getAllWeights().getNumNonZeroElements());
}
Usage of `edu.neu.ccs.pyramid.dataset.ClfDataSet` in the pyramid project by cheng-li.
Example: class `ElasticNetLogisticTrainerTest`, method `test5`.
/**
 * Sweeps an L1-only elastic-net logistic regression on the cnn/4 TREC split
 * over 100 log-uniform regularization values between 1e-8 and 0.1, visited
 * from largest to smallest. For each value it prints the timing, the
 * train/test accuracy and the number of non-zero weights.
 *
 * @throws Exception propagated from dataset loading or training
 */
private static void test5() throws Exception {
    File trainFile = new File(DATASETS, "/cnn/4/train.trec");
    File testFile = new File(DATASETS, "/cnn/4/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    // One model object is shared by every trainer below. Each fit therefore
    // presumably starts from the previous solution (warm start) — confirm.
    LogisticRegression model = new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());

    Comparator<Double> ascending = Comparator.comparing(Double::doubleValue);
    Comparator<Double> descending = ascending.reversed();
    List<Double> path = Grid.logUniform(0.00000001, 0.1, 100)
            .stream()
            .sorted(descending)
            .collect(Collectors.toList());

    for (double strength : path) {
        ElasticNetLogisticTrainer elasticNet = ElasticNetLogisticTrainer.newBuilder(model, trainSet)
                .setEpsilon(0.01)
                .setL1Ratio(1)
                .setRegularization(strength)
                .build();
        System.out.println("=================================");
        System.out.println("lambda = " + strength);

        StopWatch timer = new StopWatch();
        timer.start();
        elasticNet.optimize();
        System.out.println(timer);

        System.out.println("training accuracy = " + Accuracy.accuracy(model, trainSet));
        System.out.println("test accuracy = " + Accuracy.accuracy(model, heldOutSet));
        System.out.println("number of non-zeros= " + model.getWeights().getAllWeights().getNumNonZeroElements());
    }
}
Aggregations