Search in sources :

Example 16 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

the class LogisticRegressionTest method test3.

/**
 * Fits a {@link LogisticRegression} on the IMDB "3" training file by running
 * {@link LBFGS} over a {@link LogisticLoss}, then prints the accuracy on the
 * training file and on the matching test file.
 *
 * <p>Both files are read in TREC format from beneath {@code DATASETS} as
 * {@code DataSetType.CLF_SPARSE}. The training set's meta info is printed
 * before fitting.
 *
 * @throws Exception propagated from data loading or optimization
 */
private static void test3() throws Exception {
    final File trainFile = new File(DATASETS, "/imdb/3/train.trec");
    final File testFile = new File(DATASETS, "/imdb/3/test.trec");
    // NOTE(review): the trailing 'false' is presumably a "load settings" switch - confirm against TRECFormat.
    final ClfDataSet trainingData = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, false);
    final ClfDataSet heldOutData = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, false);
    System.out.println(trainingData.getMetaInfo());

    final LogisticRegression model =
            new LogisticRegression(trainingData.getNumClasses(), trainingData.getNumFeatures());
    // NOTE(review): 0.1 looks like a prior variance / regularization strength and 'true' a mode flag - confirm.
    final LogisticLoss objective = new LogisticLoss(model, trainingData, 0.1, true);
    new LBFGS(objective).optimize();

    final String trainReport = "train: " + Accuracy.accuracy(model, trainingData);
    System.out.println(trainReport);
    final String testReport = "test: " + Accuracy.accuracy(model, heldOutData);
    System.out.println(testReport);
}
Also used : LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File)

Example 17 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

the class SplitterTest method test1.

/**
 * Prints the feature indices of the (at most) 100 candidate splits with the
 * largest reduction on the IMDB training set.
 *
 * <p>A {@link PriorProbClassifier} is fitted on the data and its gradient is
 * handed to {@code Splitter.getAllSplits}. The resulting splits are ordered by
 * {@code SplitResult::getReduction}, largest first.
 *
 * @throws Exception propagated from data loading
 */
private static void test1() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/imdb/train.trec"), DataSetType.CLF_SPARSE, true);
    PriorProbClassifier priorProbClassifier = new PriorProbClassifier(dataSet.getNumClasses());
    priorProbClassifier.fit(dataSet);
    // NOTE(review): the second argument is presumably the class index the gradient is taken for - confirm.
    double[] gradient = priorProbClassifier.getGradient(dataSet, 1);
    RegTreeConfig regTreeConfig = new RegTreeConfig();
    // Largest reduction first.
    Comparator<SplitResult> byReductionDescending = Comparator.comparing(SplitResult::getReduction).reversed();
    List<Integer> results = Splitter.getAllSplits(regTreeConfig, dataSet, gradient).stream()
            .sorted(byReductionDescending)
            .map(SplitResult::getFeatureIndex)
            .limit(100)
            .collect(Collectors.toList());
    // Optional debugging aid: print feature names instead of raw indices.
    //        results.stream().forEach(i-> System.out.println(dataSet.getFeatureSetting(i).getFeatureName()));
    System.out.println(results);
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) File(java.io.File)

Example 18 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

the class ClassificationSynthesizerTest method test4.

/**
 * Synthesizes a 3-feature, 1000-point classification problem with
 * {@code multivarLine()}, saves the train and test sets under
 * {@code TMP/line4}, trains a ridge logistic regression on the train set, and
 * prints the train accuracy, the test accuracy and the weights for class 0.
 */
private static void test4() {
    final ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(3)
            .setNoiseSD(1e-8)
            .build();
    // Two separate draws from the same synthesizer: the first for training, the second for evaluation.
    final ClfDataSet training = synthesizer.multivarLine();
    final ClfDataSet heldOut = synthesizer.multivarLine();

    final File trainFile = new File(TMP, "line4/train.trec");
    final File testFile = new File(TMP, "line4/test.trec");
    TRECFormat.save(training, trainFile);
    TRECFormat.save(heldOut, testFile);

    final RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    final LogisticRegression model = ridgeTrainer.train(training);

    System.out.println(Accuracy.accuracy(model, training));
    System.out.println(Accuracy.accuracy(model, heldOut));
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RidgeLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) File(java.io.File)

Example 19 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

the class ClassificationSynthesizerTest method test6.

/**
 * Synthesizes a 20-feature, 1000-point classification problem with
 * {@code multivarLine()}, saves the train and test sets under
 * {@code TMP/line6}, trains a ridge logistic regression on the train set, and
 * prints the train accuracy, the test accuracy and the weights for class 0.
 */
private static void test6() {
    final ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(20)
            .setNoiseSD(1e-8)
            .build();
    // Two separate draws from the same synthesizer: the first for training, the second for evaluation.
    final ClfDataSet training = synthesizer.multivarLine();
    final ClfDataSet heldOut = synthesizer.multivarLine();

    final File trainFile = new File(TMP, "line6/train.trec");
    final File testFile = new File(TMP, "line6/test.trec");
    TRECFormat.save(training, trainFile);
    TRECFormat.save(heldOut, testFile);

    final RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    final LogisticRegression model = ridgeTrainer.train(training);

    System.out.println(Accuracy.accuracy(model, training));
    System.out.println(Accuracy.accuracy(model, heldOut));
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RidgeLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) File(java.io.File)

Example 20 with ClfDataSet

use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

the class ClassificationSynthesizerTest method test5.

/**
 * Synthesizes a 10-feature, 1000-point classification problem with
 * {@code multivarLine()}, saves the train and test sets under
 * {@code TMP/line5}, trains a ridge logistic regression on the train set, and
 * prints the train accuracy, the test accuracy and the weights for class 0.
 */
private static void test5() {
    final ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(10)
            .setNoiseSD(1e-8)
            .build();
    // Two separate draws from the same synthesizer: the first for training, the second for evaluation.
    final ClfDataSet training = synthesizer.multivarLine();
    final ClfDataSet heldOut = synthesizer.multivarLine();

    final File trainFile = new File(TMP, "line5/train.trec");
    final File testFile = new File(TMP, "line5/test.trec");
    TRECFormat.save(training, trainFile);
    TRECFormat.save(heldOut, testFile);

    final RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    final LogisticRegression model = ridgeTrainer.train(training);

    System.out.println(Accuracy.accuracy(model, training));
    System.out.println(Accuracy.accuracy(model, heldOut));
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RidgeLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) File(java.io.File)

Aggregations

ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 35; File (java.io.File): 31; StopWatch (org.apache.commons.lang3.time.StopWatch): 8; LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression): 7; RidgeLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer): 6; LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost): 2; Config (edu.neu.ccs.pyramid.configuration.Config): 2; DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType): 2; ConjugateGradientDescent (edu.neu.ccs.pyramid.optimization.ConjugateGradientDescent): 2; RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig): 2; RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory): 2; PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier): 1; LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator): 1; LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer): 1; ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer): 1; MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 1; GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent): 1; LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS): 1; BufferedWriter (java.io.BufferedWriter): 1; FileWriter (java.io.FileWriter): 1