Example 21 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

From the class ClassificationSynthesizer, method multivarLine.

public ClfDataSet multivarLine() {
    // Dense binary data set; numDataPoints, numFeatures and noise are defined elsewhere in ClassificationSynthesizer.
    ClfDataSet dataSet = ClfDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(numFeatures)
            .numClasses(2)
            .dense(true)
            .missingValue(false)
            .build();
    for (int i = 0; i < numDataPoints; i++) {
        // Draw every feature uniformly between 0 and 1 and accumulate the row sum.
        double sum = 0;
        for (int j = 0; j < numFeatures; j++) {
            double featureValue = Sampling.doubleUniform(0, 1);
            dataSet.setFeatureValue(i, j, featureValue);
            sum += featureValue;
        }
        // Perturb the sum with noise, then label by the side of the hyperplane sum = numFeatures / 2.
        sum += noise.sample();
        dataSet.setLabel(i, sum >= numFeatures / 2.0 ? 1 : 0);
    }
    return dataSet;
}
Also used: ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet)
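To look at the labeling rule of multivarLine in isolation, here is a minimal standalone sketch in plain Java. It is not pyramid code: the class MultivarLineSketch, its generate method and the use of java.util.Random are assumptions for illustration, and the noise is assumed to be zero-mean Gaussian with a given standard deviation (the original draws it from a `noise` field configured elsewhere).

```java
import java.util.Random;

// Hypothetical standalone sketch of the multivarLine labeling rule; not pyramid code.
// Assumes zero-mean Gaussian noise drawn with java.util.Random.
final class MultivarLineSketch {

    // Feature matrix and 0/1 labels produced by generate().
    static final class Data {
        final double[][] features;
        final int[] labels;

        Data(double[][] features, int[] labels) {
            this.features = features;
            this.labels = labels;
        }
    }

    static Data generate(int numDataPoints, int numFeatures, double noiseSD, long seed) {
        Random random = new Random(seed);
        double[][] features = new double[numDataPoints][numFeatures];
        int[] labels = new int[numDataPoints];
        for (int i = 0; i < numDataPoints; i++) {
            double sum = 0;
            for (int j = 0; j < numFeatures; j++) {
                features[i][j] = random.nextDouble(); // uniform in [0, 1)
                sum += features[i][j];
            }
            sum += noiseSD * random.nextGaussian(); // zero-mean Gaussian noise (assumed)
            // The decision boundary is the hyperplane sum(x) = numFeatures / 2.
            labels[i] = sum >= numFeatures / 2.0 ? 1 : 0;
        }
        return new Data(features, labels);
    }

    public static void main(String[] args) {
        Data data = generate(1000, 2, 1e-8, 42L);
        int positives = 0;
        for (int label : data.labels) {
            positives += label;
        }
        System.out.println("positives=" + positives + " of " + data.labels.length);
    }
}
```

With two features the boundary is the line x1 + x2 = 1, so when the noise is negligible a linear classifier can in principle separate the two classes almost perfectly.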

Example 22 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

From the class L2BoostTest, method loadTest.

static void loadTest() throws Exception {
    // Load the spam test split in TREC format as a sparse classification data set.
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "spam/trec_data/test.trec"), DataSetType.CLF_SPARSE, true);
    // Restore the L2Boost model previously serialized to TMP/boost.
    L2Boost boost = (L2Boost) Serialization.deserialize(new File(TMP, "boost"));
    double accuracy = Accuracy.accuracy(boost, dataSet);
    System.out.println("accuracy=" + accuracy);
}
Also used: ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File)
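Accuracy.accuracy(boost, dataSet) reports a single number. Assuming it follows the usual definition (the fraction of points whose predicted label matches the true label), the computation reduces to the sketch below. It works on plain label arrays rather than pyramid's classifier and data set types, and AccuracySketch is a hypothetical name.

```java
// Hypothetical sketch of classification accuracy over label arrays; not pyramid's Accuracy class.
final class AccuracySketch {

    // Fraction of positions where the predicted label equals the gold label.
    static double accuracy(int[] predicted, int[] gold) {
        if (predicted.length != gold.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        if (gold.length == 0) {
            throw new IllegalArgumentException("empty input");
        }
        int correct = 0;
        for (int i = 0; i < gold.length; i++) {
            if (predicted[i] == gold[i]) {
                correct++;
            }
        }
        return (double) correct / gold.length;
    }

    public static void main(String[] args) {
        int[] gold = {1, 0, 1, 1};
        int[] predicted = {1, 0, 0, 1};
        System.out.println("accuracy=" + accuracy(predicted, gold)); // prints accuracy=0.75
    }
}
```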

Example 23 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

From the class ClassificationSynthesizerTest, method test1.

private static void test1() {
    // 1000 points, 2 features, near-zero noise: labels follow x1 + x2 >= 1 almost exactly.
    ClassificationSynthesizer classificationSynthesizer = ClassificationSynthesizer.getBuilder().setNumDataPoints(1000).setNumFeatures(2).setNoiseSD(0.00000001).build();
    // Two independent draws from the same generator serve as the training and test sets.
    ClfDataSet trainSet = classificationSynthesizer.multivarLine();
    ClfDataSet testSet = classificationSynthesizer.multivarLine();
    TRECFormat.save(trainSet, new File(TMP, "line1/train.trec"));
    TRECFormat.save(testSet, new File(TMP, "line1/test.trec"));
    // Logistic regression with a Gaussian prior of variance 1 on the weights.
    RidgeLogisticTrainer trainer = RidgeLogisticTrainer.getBuilder().setGaussianPriorVariance(1).build();
    LogisticRegression logisticRegression = trainer.train(trainSet);
    // Print training accuracy, test accuracy, then the learned weights for class 0.
    System.out.println(Accuracy.accuracy(logisticRegression, trainSet));
    System.out.println(Accuracy.accuracy(logisticRegression, testSet));
    System.out.println(logisticRegression.getWeights().getWeightsForClass(0));
}
Also used: ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RidgeLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) File(java.io.File)
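RidgeLogisticTrainer is configured with setGaussianPriorVariance(1). Under the standard MAP interpretation, a zero-mean Gaussian prior with variance σ² on the weights adds the penalty ||w||² / (2σ²) to the negative log-likelihood, which is what makes this ridge (L2-regularized) logistic regression. The sketch below minimizes that objective for the binary case with plain batch gradient descent. It is an illustration under assumptions, not pyramid's trainer: the class RidgeLogisticSketch is hypothetical, gradient descent is a stand-in optimizer chosen for simplicity, and leaving the bias unpenalized is a common convention assumed here rather than something the source states.

```java
import java.util.Arrays;

// Hypothetical sketch of ridge (L2-regularized) binary logistic regression;
// not pyramid's RidgeLogisticTrainer. Batch gradient descent is a stand-in optimizer.
final class RidgeLogisticSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Minimizes  -sum_i log p(y_i | x_i, w)  +  ||w||^2 / (2 * priorVariance)
    // for labels y_i in {0, 1}. The returned array holds the feature weights
    // followed by the bias, which is left unpenalized (assumed convention).
    static double[] train(double[][] x, int[] y, double priorVariance,
                          int iterations, double learningRate) {
        int n = x.length;
        int d = x[0].length;
        double[] w = new double[d + 1]; // w[d] is the bias
        for (int iter = 0; iter < iterations; iter++) {
            double[] grad = new double[d + 1];
            for (int i = 0; i < n; i++) {
                double z = w[d];
                for (int j = 0; j < d; j++) {
                    z += w[j] * x[i][j];
                }
                double residual = sigmoid(z) - y[i]; // derivative of the point's loss w.r.t. z
                for (int j = 0; j < d; j++) {
                    grad[j] += residual * x[i][j];
                }
                grad[d] += residual;
            }
            for (int j = 0; j < d; j++) {
                grad[j] += w[j] / priorVariance; // gradient of the Gaussian-prior penalty
            }
            // Dividing by n keeps the step size reasonable as the data set grows.
            for (int j = 0; j <= d; j++) {
                w[j] -= learningRate / n * grad[j];
            }
        }
        return w;
    }

    static int predict(double[] w, double[] features) {
        double z = w[features.length];
        for (int j = 0; j < features.length; j++) {
            z += w[j] * features[j];
        }
        return z >= 0 ? 1 : 0;
    }

    public static void main(String[] args) {
        double[][] x = {{0}, {1}, {2}, {3}};
        int[] y = {0, 0, 1, 1};
        double[] w = train(x, y, 1.0, 5000, 1.0);
        System.out.println("weights=" + Arrays.toString(w));
        for (double[] row : x) {
            System.out.println(Arrays.toString(row) + " -> " + predict(w, row));
        }
    }
}
```

A smaller prior variance means a larger penalty and therefore weights pulled harder toward zero; a very large variance approaches unregularized logistic regression.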

Example 24 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

From the class ClassificationSynthesizerTest, method test2.

private static void test2() {
    // Same setup as test1, but with only 100 data points per set.
    ClassificationSynthesizer classificationSynthesizer = ClassificationSynthesizer.getBuilder().setNumDataPoints(100).setNumFeatures(2).setNoiseSD(0.00000001).build();
    ClfDataSet trainSet = classificationSynthesizer.multivarLine();
    ClfDataSet testSet = classificationSynthesizer.multivarLine();
    TRECFormat.save(trainSet, new File(TMP, "line2/train.trec"));
    TRECFormat.save(testSet, new File(TMP, "line2/test.trec"));
    RidgeLogisticTrainer trainer = RidgeLogisticTrainer.getBuilder().setGaussianPriorVariance(1).build();
    LogisticRegression logisticRegression = trainer.train(trainSet);
    System.out.println(Accuracy.accuracy(logisticRegression, trainSet));
    System.out.println(Accuracy.accuracy(logisticRegression, testSet));
    System.out.println(logisticRegression.getWeights().getWeightsForClass(0));
}
Also used: ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RidgeLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) File(java.io.File)

Example 25 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

From the class ClassificationSynthesizerTest, method test3.

private static void test3() {
    // Same setup as test1, but with a noise standard deviation of 0.1 instead of 1e-8.
    ClassificationSynthesizer classificationSynthesizer = ClassificationSynthesizer.getBuilder().setNumDataPoints(1000).setNumFeatures(2).setNoiseSD(0.1).build();
    ClfDataSet trainSet = classificationSynthesizer.multivarLine();
    ClfDataSet testSet = classificationSynthesizer.multivarLine();
    TRECFormat.save(trainSet, new File(TMP, "line3/train.trec"));
    TRECFormat.save(testSet, new File(TMP, "line3/test.trec"));
    RidgeLogisticTrainer trainer = RidgeLogisticTrainer.getBuilder().setGaussianPriorVariance(1).build();
    LogisticRegression logisticRegression = trainer.train(trainSet);
    System.out.println(Accuracy.accuracy(logisticRegression, trainSet));
    System.out.println(Accuracy.accuracy(logisticRegression, testSet));
    System.out.println(logisticRegression.getWeights().getWeightsForClass(0));
}
Also used: ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RidgeLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) File(java.io.File)
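test3 raises the noise standard deviation to 0.1, so points close to the line x1 + x2 = 1 can end up labeled as if they were on the other side. Assuming Gaussian noise, the probability of label 1 increases with the feature sum, so the noiseless rule is still the best possible classifier of the features, and the fraction of labels the noise flips is the error even that rule makes on noisy labels. The sketch below estimates this flip rate by simulation; NoiseFlipSketch is a hypothetical name and the use of java.util.Random is an assumption, not how pyramid samples noise.

```java
import java.util.Random;

// Hypothetical sketch: estimates how often Gaussian noise flips a multivarLine-style
// label relative to the noiseless rule sum(x) >= numFeatures / 2. Not pyramid code.
final class NoiseFlipSketch {

    static double flipRate(int numDataPoints, int numFeatures, double noiseSD, long seed) {
        Random random = new Random(seed);
        double threshold = numFeatures / 2.0;
        int flips = 0;
        for (int i = 0; i < numDataPoints; i++) {
            double sum = 0;
            for (int j = 0; j < numFeatures; j++) {
                sum += random.nextDouble(); // uniform features in [0, 1)
            }
            int clean = sum >= threshold ? 1 : 0;
            int noisy = sum + noiseSD * random.nextGaussian() >= threshold ? 1 : 0;
            if (clean != noisy) {
                flips++;
            }
        }
        return (double) flips / numDataPoints;
    }

    public static void main(String[] args) {
        for (double sd : new double[] {1e-8, 0.1}) {
            System.out.println("noiseSD=" + sd + " flipRate=" + flipRate(100_000, 2, sd, 42L));
        }
    }
}
```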
