use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
the class RidgeLogisticOptimizerTest method test3.
/**
 * Runs L-BFGS for a fixed number of iterations on a {@code LogisticLoss} built over the
 * 20newsgroup training split (variance 1000) and prints the loss plus train/test accuracy
 * after every iteration, followed by the stopwatch reading for the whole loop.
 *
 * @throws Exception propagated from data loading, optimization or evaluation
 */
private static void test3() throws Exception {
    // Other splits this harness has been pointed at (same loader arguments):
    //   "/imdb/3/train.trec"         + "/imdb/3/test.trec"
    //   "/spam/trec_data/train.trec" + "/spam/trec_data/test.trec"
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "20newsgroup/1/train.trec"), DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(new File(DATASETS, "20newsgroup/1/test.trec"), DataSetType.CLF_SPARSE, true);

    final double variance = 1000;
    final int numIterations = 200;

    LogisticRegression logisticRegression = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());
    Optimizable.ByGradientValue loss = new LogisticLoss(logisticRegression, dataSet, variance, true);
    // Alternative optimizer: new GradientDescent(loss)
    LBFGS optimizer = new LBFGS(loss);

    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));

    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int i = 0; i < numIterations; i++) {
        optimizer.iterate();
        System.out.println("after iteration " + i);
        System.out.println("loss = " + loss.getValue());
        System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
        System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));
        // Uncomment to dump the model after each step:
        // System.out.println(logisticRegression);
    }
    // The stopwatch was previously started but never reported. Note that the reading
    // covers the per-iteration accuracy evaluation and printing as well as the
    // optimizer steps themselves.
    System.out.println(stopWatch);
}
use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
the class RidgeLogisticOptimizerTest method test2.
/**
 * Fits a {@code LogisticRegression} on the spam training split through
 * {@code RidgeLogisticOptimizer}, using a weight of 1.0 for every data point and a
 * two-column one-hot target row per data point, then prints train/test accuracy before
 * and after optimization together with the terminator's history.
 *
 * @throws Exception propagated from data loading, optimization or evaluation
 */
private static void test2() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/test.trec"), DataSetType.CLF_SPARSE, true);
    LogisticRegression logisticRegression = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());

    int numDataPoints = dataSet.getNumDataPoints();
    int[] labels = dataSet.getLabels();

    // gammas: equal weight for every data point.
    // targets: one row per data point with a single 1 in column 0 (label 0)
    //          or column 1 (any other label); remaining entries stay 0.
    double[] gammas = new double[numDataPoints];
    double[][] targets = new double[numDataPoints][2];
    for (int n = 0; n < numDataPoints; n++) {
        gammas[n] = 1.0;
        // Labels are ints, so compare against an int literal rather than 0.0.
        int column = (labels[n] == 0) ? 0 : 1;
        targets[n][column] = 1;
    }

    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(logisticRegression, dataSet, gammas, targets, 500, true);
    optimizer.getOptimizer().getTerminator().setMaxIteration(10000).setMode(Terminator.Mode.STANDARD);

    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));

    optimizer.optimize();

    System.out.println("after training");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));
    System.out.println(optimizer.getOptimizer().getTerminator().getHistory());
}
use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
the class RidgeLogisticTrainerTest method test1.
/**
 * Trains a model on the imdb training split with a {@code RidgeLogisticTrainer}
 * (epsilon 1, Gaussian prior variance 0.5, history 5) and prints the training set's
 * meta info followed by train and test accuracy.
 *
 * @throws Exception propagated from data loading, training or evaluation
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "/imdb/3/train.trec");
    File testFile = new File(DATASETS, "/imdb/3/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setEpsilon(1)
            .setGaussianPriorVariance(0.5)
            .setHistory(5)
            .build();
    LogisticRegression model = ridgeTrainer.train(trainSet);

    System.out.println("train: " + Accuracy.accuracy(model, trainSet));
    System.out.println("test: " + Accuracy.accuracy(model, heldOutSet));
}
use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
the class RidgeLogisticTrainerTest method test3.
/**
 * Trains a model on the 20newsgroup training split with a {@code RidgeLogisticTrainer}
 * (epsilon 0.01, Gaussian prior variance 0.5, history 5), printing the training set's
 * meta info, the stopwatch reading taken around the {@code train} call, and finally
 * train and test accuracy.
 *
 * @throws Exception propagated from data loading, training or evaluation
 */
private static void test3() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setEpsilon(0.01)
            .setGaussianPriorVariance(0.5)
            .setHistory(5)
            .build();

    // Only the train() call sits between starting the timer and printing it.
    StopWatch timer = new StopWatch();
    timer.start();
    LogisticRegression model = ridgeTrainer.train(trainSet);
    System.out.println(timer);

    System.out.println("train: " + Accuracy.accuracy(model, trainSet));
    System.out.println("test: " + Accuracy.accuracy(model, heldOutSet));
}
use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
the class ElasticNetLogisticTrainerTest method test2.
/**
 * Optimizes a {@code LogisticRegression} on the spam training split with an
 * {@code ElasticNetLogisticTrainer} (epsilon 0.01, L1 ratio 0.5, regularization 0.0001)
 * and prints the training set's meta info followed by training and test accuracy.
 *
 * @throws Exception propagated from data loading, optimization or evaluation
 */
private static void test2() throws Exception {
    File trainFile = new File(DATASETS, "/spam/trec_data/train.trec");
    File testFile = new File(DATASETS, "/spam/trec_data/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOutSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    int numClasses = trainSet.getNumClasses();
    int numFeatures = trainSet.getNumFeatures();
    LogisticRegression model = new LogisticRegression(numClasses, numFeatures);

    // The trainer is handed the model instance and optimizes it in place;
    // the same instance is evaluated below.
    ElasticNetLogisticTrainer elasticNetTrainer = ElasticNetLogisticTrainer.newBuilder(model, trainSet)
            .setEpsilon(0.01)
            .setL1Ratio(0.5)
            .setRegularization(0.0001)
            .build();
    elasticNetTrainer.optimize();

    System.out.println("training accuracy = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test accuracy = " + Accuracy.accuracy(model, heldOutSet));
}
Aggregations