
Example 6 with RidgeLogisticTrainer

Use of edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer in project pyramid by cheng-li.

From class ClassificationSynthesizerTest, method test3.

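private static void test3() {
    // Synthesizer for 1000 data points with 2 features and noise standard deviation 0.1
    ClassificationSynthesizer classificationSynthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000).setNumFeatures(2).setNoiseSD(0.1)
            .build();
    // Generate a training set and a test set from the same synthesizer settings
    ClfDataSet trainSet = classificationSynthesizer.multivarLine();
    ClfDataSet testSet = classificationSynthesizer.multivarLine();
    // Save both sets in TREC format; TMP is a directory defined outside this method (not shown)
    TRECFormat.save(trainSet, new File(TMP, "line3/train.trec"));
    TRECFormat.save(testSet, new File(TMP, "line3/test.trec"));
    // Ridge (L2-regularized) logistic regression: Gaussian prior on the weights with variance 1
    RidgeLogisticTrainer trainer = RidgeLogisticTrainer.getBuilder().setGaussianPriorVariance(1).build();
    LogisticRegression logisticRegression = trainer.train(trainSet);
    // Print training accuracy, test accuracy, and the learned weights for class 0
    System.out.println(Accuracy.accuracy(logisticRegression, trainSet));
    System.out.println(Accuracy.accuracy(logisticRegression, testSet));
    System.out.println(logisticRegression.getWeights().getWeightsForClass(0));
}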
Also used: ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet), RidgeLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer), LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression), File (java.io.File)
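The test above trains a ridge logistic regression with a Gaussian prior of variance 1 on synthetic two-feature data and compares training and test accuracy. For readers without pyramid on the classpath, here is a minimal, self-contained sketch of the same idea for the binary case. It is not pyramid's implementation: the class RidgeLogisticSketch, its batch gradient-descent solver, the learning rate and iteration count, the unpenalized bias term, and the synthesize generator (label 1 when x0 + x1 plus Gaussian noise is positive) are all hypothetical stand-ins, and pyramid's multivarLine() may generate data differently. The prior variance plays the role of setGaussianPriorVariance(1): the negative log prior adds w_j^2 / (2 * variance) to the loss, so each penalized weight gets an extra gradient term w_j / variance.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch (not pyramid's API): binary logistic regression with a Gaussian prior on the weights. */
public class RidgeLogisticSketch {

    // Weights; index 0 is the bias, which this sketch leaves unpenalized (an assumption).
    final double[] w;

    RidgeLogisticSketch(int numFeatures) {
        w = new double[numFeatures + 1];
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    double predictProb(double[] x) {
        double z = w[0];
        for (int j = 0; j < x.length; j++) z += w[j + 1] * x[j];
        return sigmoid(z);
    }

    /** Batch gradient descent on (negative log-likelihood + sum_j w_j^2 / (2 * priorVariance)) / n. */
    void train(double[][] X, int[] y, double priorVariance, double learningRate, int iterations) {
        int n = X.length;
        int d = w.length;
        for (int it = 0; it < iterations; it++) {
            double[] grad = new double[d];
            for (int i = 0; i < n; i++) {
                double err = predictProb(X[i]) - y[i]; // derivative of the log loss w.r.t. the score
                grad[0] += err;
                for (int j = 0; j < X[i].length; j++) grad[j + 1] += err * X[i][j];
            }
            // Gaussian prior: gradient of w_j^2 / (2 * variance) is w_j / variance (bias excluded)
            for (int j = 1; j < d; j++) grad[j] += w[j] / priorVariance;
            for (int j = 0; j < d; j++) w[j] -= learningRate * grad[j] / n;
        }
    }

    double accuracy(double[][] X, int[] y) {
        int correct = 0;
        for (int i = 0; i < X.length; i++) {
            int pred = predictProb(X[i]) >= 0.5 ? 1 : 0;
            if (pred == y[i]) correct++;
        }
        return (double) correct / X.length;
    }

    /** Hypothetical generator: two standard-normal features, label 1 when x0 + x1 + noise > 0. */
    static void synthesize(Random rng, double[][] X, int[] y, double noiseSD) {
        for (int i = 0; i < X.length; i++) {
            X[i] = new double[] {rng.nextGaussian(), rng.nextGaussian()};
            y[i] = X[i][0] + X[i][1] + noiseSD * rng.nextGaussian() > 0 ? 1 : 0;
        }
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        int n = 1000;
        double[][] trainX = new double[n][];
        int[] trainY = new int[n];
        double[][] testX = new double[n][];
        int[] testY = new int[n];
        synthesize(rng, trainX, trainY, 0.1);
        synthesize(rng, testX, testY, 0.1);

        RidgeLogisticSketch model = new RidgeLogisticSketch(2);
        model.train(trainX, trainY, 1.0, 0.5, 500);
        System.out.println(model.accuracy(trainX, trainY));
        System.out.println(model.accuracy(testX, testY));
        System.out.println(Arrays.toString(model.w)); // bias first, then the two feature weights
    }
}
```

Running main prints training accuracy, test accuracy, and the weight vector, loosely mirroring the three println calls in test3. A larger prior variance weakens the penalty and lets the weights grow; a smaller one shrinks them toward zero.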
