Use of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer in project pyramid by cheng-li.
In the class ClassificationSynthesizerTest, method test4:
/**
 * Smoke-runs ridge logistic regression on synthesized data with 3 features.
 *
 * <p>Configures a {@code ClassificationSynthesizer} with 1000 data points, 3 features and a
 * noise standard deviation of 1e-8. It takes a training set and then a test set from two calls
 * to {@code multivarLine()} and saves both in TREC format under {@code TMP/line4/}. It then
 * trains a {@code RidgeLogisticTrainer} (Gaussian prior variance 1) on the training set and
 * prints, in order: training accuracy, test accuracy, and the weights for class 0.
 */
private static void test4() {
    ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(3)
            .setNoiseSD(1e-8)
            .build();

    // First draw is the training set, second draw is the test set.
    ClfDataSet trainData = synthesizer.multivarLine();
    ClfDataSet testData = synthesizer.multivarLine();

    File trainFile = new File(TMP, "line4/train.trec");
    TRECFormat.save(trainData, trainFile);
    File testFile = new File(TMP, "line4/test.trec");
    TRECFormat.save(testData, testFile);

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    LogisticRegression model = ridgeTrainer.train(trainData);

    // Report accuracy on the data the model was fit on, then on the test set it never saw.
    for (ClfDataSet evaluated : new ClfDataSet[] {trainData, testData}) {
        System.out.println(Accuracy.accuracy(model, evaluated));
    }
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Use of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer in project pyramid by cheng-li.
In the class ClassificationSynthesizerTest, method test6:
/**
 * Smoke-runs ridge logistic regression on synthesized data with 20 features.
 *
 * <p>Configures a {@code ClassificationSynthesizer} with 1000 data points, 20 features and a
 * noise standard deviation of 1e-8. It takes a training set and then a test set from two calls
 * to {@code multivarLine()} and saves both in TREC format under {@code TMP/line6/}. It then
 * trains a {@code RidgeLogisticTrainer} (Gaussian prior variance 1) on the training set and
 * prints, in order: training accuracy, test accuracy, and the weights for class 0.
 */
private static void test6() {
    ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(20)
            .setNoiseSD(1e-8)
            .build();

    // First draw is the training set, second draw is the test set.
    ClfDataSet trainData = synthesizer.multivarLine();
    ClfDataSet testData = synthesizer.multivarLine();

    File trainFile = new File(TMP, "line6/train.trec");
    TRECFormat.save(trainData, trainFile);
    File testFile = new File(TMP, "line6/test.trec");
    TRECFormat.save(testData, testFile);

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    LogisticRegression model = ridgeTrainer.train(trainData);

    // Report accuracy on the data the model was fit on, then on the test set it never saw.
    for (ClfDataSet evaluated : new ClfDataSet[] {trainData, testData}) {
        System.out.println(Accuracy.accuracy(model, evaluated));
    }
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Use of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer in project pyramid by cheng-li.
In the class ClassificationSynthesizerTest, method test5:
/**
 * Smoke-runs ridge logistic regression on synthesized data with 10 features.
 *
 * <p>Configures a {@code ClassificationSynthesizer} with 1000 data points, 10 features and a
 * noise standard deviation of 1e-8. It takes a training set and then a test set from two calls
 * to {@code multivarLine()} and saves both in TREC format under {@code TMP/line5/}. It then
 * trains a {@code RidgeLogisticTrainer} (Gaussian prior variance 1) on the training set and
 * prints, in order: training accuracy, test accuracy, and the weights for class 0.
 */
private static void test5() {
    ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(10)
            .setNoiseSD(1e-8)
            .build();

    // First draw is the training set, second draw is the test set.
    ClfDataSet trainData = synthesizer.multivarLine();
    ClfDataSet testData = synthesizer.multivarLine();

    File trainFile = new File(TMP, "line5/train.trec");
    TRECFormat.save(trainData, trainFile);
    File testFile = new File(TMP, "line5/test.trec");
    TRECFormat.save(testData, testFile);

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    LogisticRegression model = ridgeTrainer.train(trainData);

    // Report accuracy on the data the model was fit on, then on the test set it never saw.
    for (ClfDataSet evaluated : new ClfDataSet[] {trainData, testData}) {
        System.out.println(Accuracy.accuracy(model, evaluated));
    }
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Use of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer in project pyramid by cheng-li.
In the class ClassificationSynthesizerTest, method test1:
/**
 * Smoke-runs ridge logistic regression on synthesized data with 2 features.
 *
 * <p>Configures a {@code ClassificationSynthesizer} with 1000 data points, 2 features and a
 * noise standard deviation of 1e-8. It takes a training set and then a test set from two calls
 * to {@code multivarLine()} and saves both in TREC format under {@code TMP/line1/}. It then
 * trains a {@code RidgeLogisticTrainer} (Gaussian prior variance 1) on the training set and
 * prints, in order: training accuracy, test accuracy, and the weights for class 0.
 */
private static void test1() {
    ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(2)
            .setNoiseSD(1e-8)
            .build();

    // First draw is the training set, second draw is the test set.
    ClfDataSet trainData = synthesizer.multivarLine();
    ClfDataSet testData = synthesizer.multivarLine();

    File trainFile = new File(TMP, "line1/train.trec");
    TRECFormat.save(trainData, trainFile);
    File testFile = new File(TMP, "line1/test.trec");
    TRECFormat.save(testData, testFile);

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    LogisticRegression model = ridgeTrainer.train(trainData);

    // Report accuracy on the data the model was fit on, then on the test set it never saw.
    for (ClfDataSet evaluated : new ClfDataSet[] {trainData, testData}) {
        System.out.println(Accuracy.accuracy(model, evaluated));
    }
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Use of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer in project pyramid by cheng-li.
In the class ClassificationSynthesizerTest, method test2:
/**
 * Smoke-runs ridge logistic regression on a small synthesized dataset with 2 features.
 *
 * <p>Unlike the 1000-point variants, this one configures a {@code ClassificationSynthesizer} with
 * only 100 data points. It uses 2 features and a noise standard deviation of 1e-8. It takes a
 * training set and then a test set from two calls to {@code multivarLine()} and saves both in
 * TREC format under {@code TMP/line2/}. It then trains a {@code RidgeLogisticTrainer} (Gaussian
 * prior variance 1) on the training set and prints, in order: training accuracy, test accuracy,
 * and the weights for class 0.
 */
private static void test2() {
    ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(100)
            .setNumFeatures(2)
            .setNoiseSD(1e-8)
            .build();

    // First draw is the training set, second draw is the test set.
    ClfDataSet trainData = synthesizer.multivarLine();
    ClfDataSet testData = synthesizer.multivarLine();

    File trainFile = new File(TMP, "line2/train.trec");
    TRECFormat.save(trainData, trainFile);
    File testFile = new File(TMP, "line2/test.trec");
    TRECFormat.save(testData, testFile);

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    LogisticRegression model = ridgeTrainer.train(trainData);

    // Report accuracy on the data the model was fit on, then on the test set it never saw.
    for (ClfDataSet evaluated : new ClfDataSet[] {trainData, testData}) {
        System.out.println(Accuracy.accuracy(model, evaluated));
    }
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Aggregations