Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
From class LogisticRegressionTest, method test1:
/**
 * Trains a logistic regression model on the sparse IMDB split 3 with plain gradient descent,
 * then prints the data set meta info and the train/test accuracy.
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "/imdb/3/train.trec");
    File testFile = new File(DATASETS, "/imdb/3/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, false);
    ClfDataSet heldOut = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, false);
    System.out.println(trainSet.getMetaInfo());

    LogisticRegression model =
            new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());
    // The numeric (1000) and boolean (true) arguments are forwarded unchanged to LogisticLoss;
    // their meaning is defined by that class.
    LogisticLoss loss = new LogisticLoss(model, trainSet, 1000, true);
    GradientDescent optimizer = new GradientDescent(loss);
    optimizer.getLineSearcher().setInitialStepLength(1.0E-4);
    optimizer.optimize();

    System.out.println("train: " + Accuracy.accuracy(model, trainSet));
    System.out.println("test: " + Accuracy.accuracy(model, heldOut));
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
From class LogisticRegressionTest, method test2:
/**
 * Trains a logistic regression model on the sparse IMDB split 3 with conjugate gradient descent,
 * then prints the data set meta info and the train/test accuracy.
 */
private static void test2() throws Exception {
    File trainFile = new File(DATASETS, "/imdb/3/train.trec");
    File testFile = new File(DATASETS, "/imdb/3/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOut = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    LogisticRegression model =
            new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());
    // 0.1 and true are passed straight through to LogisticLoss.
    LogisticLoss loss = new LogisticLoss(model, trainSet, 0.1, true);
    ConjugateGradientDescent optimizer = new ConjugateGradientDescent(loss);
    optimizer.getLineSearcher().setInitialStepLength(0.01);
    optimizer.optimize();

    System.out.println("train: " + Accuracy.accuracy(model, trainSet));
    System.out.println("test: " + Accuracy.accuracy(model, heldOut));
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
From class ElasticNetLogisticTrainerTest, method test7:
/**
 * Trains an elastic-net regularized logistic regression on the sparse 20newsgroup split 1 and
 * prints the data set meta info plus training and test accuracy.
 */
private static void test7() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOut = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    LogisticRegression model =
            new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());
    ElasticNetLogisticTrainer trainer = ElasticNetLogisticTrainer
            .newBuilder(model, trainSet)
            .setEpsilon(0.01)
            .setL1Ratio(0.1111111111111111)
            .setRegularization(1.1233240329780266E-6)
            .build();
    trainer.optimize();

    System.out.println("training accuracy = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test accuracy = " + Accuracy.accuracy(model, heldOut));
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
From class ElasticNetLogisticTrainerTest, method test3:
/**
 * Runs ten single iterations of elastic-net logistic regression training on the sparse IMDB
 * split 3, printing the training and test accuracy after each iteration.
 */
private static void test3() throws Exception {
    File trainFile = new File(DATASETS, "/imdb/3/train.trec");
    File testFile = new File(DATASETS, "/imdb/3/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet heldOut = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);
    System.out.println(trainSet.getMetaInfo());

    LogisticRegression model =
            new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());
    ElasticNetLogisticTrainer trainer = ElasticNetLogisticTrainer
            .newBuilder(model, trainSet)
            .setEpsilon(0.01)
            .setL1Ratio(0.1)
            .setRegularization(0.001)
            .build();

    int iteration = 0;
    while (iteration < 10) {
        System.out.println("iteration " + iteration);
        trainer.iterate();
        // Report progress after every single iteration rather than only at the end.
        System.out.println("training accuracy = " + Accuracy.accuracy(model, trainSet));
        System.out.println("test accuracy = " + Accuracy.accuracy(model, heldOut));
        iteration++;
    }
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.
From class NoiseOptimizerLR, method buildLrData:
/**
 * Generates the binary logistic-regression data set for one class.
 *
 * <p>Every data point {@code i} of {@code this.dataSet} is expanded into one row per entry
 * {@code k} of {@code this.combinations}, stored at row {@code i * numCombination + k}.
 * <ul>
 *   <li>Feature {@code j} of that row is {@code 0.5} if combination {@code k} matches class
 *       {@code j}, and {@code -0.5} otherwise.</li>
 *   <li>The row's label is {@code 1} if data point {@code i}'s multi-label matches
 *       {@code classIndex}, and {@code 0} otherwise.</li>
 * </ul>
 *
 * @param classIndex the class whose membership becomes the binary label
 * @return a dense, 2-class data set with one feature per original class
 * @throws ArithmeticException if the number of rows (data points x combinations) does not fit
 *     in an {@code int}
 */
private ClfDataSet buildLrData(int classIndex) {
    final int numCombination = this.combinations.size();
    final int numOriginalPoints = this.dataSet.getNumDataPoints();
    final int numClasses = this.dataSet.getNumClasses();
    // Fail fast instead of silently wrapping to a negative/too-small size on int overflow.
    final int numRows = Math.multiplyExact(numOriginalPoints, numCombination);

    ClfDataSet lrDataSet = ClfDataSetBuilder.getBuilder()
            .numDataPoints(numRows)
            .numFeatures(numClasses)
            .numClasses(2)
            .dense(true)
            .missingValue(false)
            .build();

    for (int i = 0; i < numOriginalPoints; i++) {
        // The label depends only on the original data point, so it is shared by all its rows.
        int labelToSet = this.dataSet.getMultiLabels()[i].matchClass(classIndex) ? 1 : 0;
        for (int k = 0; k < numCombination; k++) {
            int row = i * numCombination + k;
            for (int j = 0; j < numClasses; j++) {
                double featureValue = this.combinations.get(k).matchClass(j) ? 0.5 : -0.5;
                lrDataSet.setFeatureValue(row, j, featureValue);
            }
            lrDataSet.setLabel(row, labelToSet);
        }
    }
    return lrDataSet;
}
Aggregations