Search in sources:

Example 16 with RegDataSet

Use of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.

In class RegressionSynthesizer, method univarStep.

/**
 * Builds a single-feature synthetic regression data set whose target is a noisy step
 * function of the feature.
 *
 * <p>For each of the {@code numDataPoints} rows, a feature value is drawn with
 * {@code Sampling.doubleUniform(0, 1)}. The label is 0.7 when that value is strictly
 * greater than 0.5 and 0.2 otherwise, plus one draw from {@code noise.sample()}.
 *
 * @return a dense data set with {@code numDataPoints} rows, one feature and no missing values
 */
public RegDataSet univarStep() {
    RegDataSet generated = RegDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(1)
            .dense(true)
            .missingValue(false)
            .build();
    int row = 0;
    while (row < numDataPoints) {
        double x = Sampling.doubleUniform(0, 1);
        // Step at 0.5: the upper level applies only strictly above the threshold.
        double y = (x > 0.5 ? 0.7 : 0.2) + noise.sample();
        generated.setFeatureValue(row, 0, x);
        generated.setLabel(row, y);
        row++;
    }
    return generated;
}
Also used : RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet)

Example 17 with RegDataSet

Use of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.

In class LSBoostTest, method test1.

/**
 * Trains an LSBoost model on abalone fold 1 for 100 iterations, printing train and test
 * RMSE at every step.
 *
 * <p>RMSE is printed before each {@code iterate()} call. The line for iteration 0 therefore
 * shows the model right after {@code initialize()}, and the state after the last
 * {@code iterate()} is never reported.
 *
 * @throws Exception if loading the data sets or any of the called routines throws
 */
private static void test1() throws Exception {
    RegDataSet trainSet = TRECFormat.loadRegDataSet(
            new File(DATASETS, "abalone/folds/fold_1/train.trec"), DataSetType.REG_DENSE, true);
    RegDataSet testSet = TRECFormat.loadRegDataSet(
            new File(DATASETS, "abalone/folds/fold_1/test.trec"), DataSetType.REG_DENSE, true);

    LSBoost boost = new LSBoost();
    // Tree factory configured with setMaxNumLeaves(3).
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(3));
    LSBoostOptimizer optimizer = new LSBoostOptimizer(boost, trainSet, treeFactory);
    optimizer.setShrinkage(0.1);
    optimizer.initialize();

    final int numIterations = 100;
    for (int iter = 0; iter < numIterations; iter++) {
        System.out.println("iteration " + iter);
        System.out.println("train RMSE = " + RMSE.rmse(boost, trainSet));
        System.out.println("test RMSE = " + RMSE.rmse(boost, testSet));
        optimizer.iterate();
    }
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File)

Example 18 with RegDataSet

Use of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.

In class ElasticNetLinearRegTrainerTest, method test1.

/**
 * Fits an elastic-net linear regression (regularization 10, L1 ratio 0.5) on the spam
 * training split.
 *
 * <p>Prints train and test RMSE before and after optimization, then the number of
 * non-zero weights (bias excluded).
 *
 * @throws Exception if loading the data sets or any of the called routines throws
 */
private static void test1() throws Exception {
    RegDataSet trainSet = StandardFormat.loadRegDataSet(
            new File(DATASETS, "spam/train_data.txt"),
            new File(DATASETS, "spam/train_label.txt"),
            ",", DataSetType.REG_DENSE, false);
    double[] trainLabels = trainSet.getLabels();
    RegDataSet heldOutSet = StandardFormat.loadRegDataSet(
            new File(DATASETS, "spam/test_data.txt"),
            new File(DATASETS, "spam/test_label.txt"),
            ",", DataSetType.REG_DENSE, false);

    LinearRegression model = new LinearRegression(trainSet.getNumFeatures());
    ElasticNetLinearRegOptimizer optimizer =
            new ElasticNetLinearRegOptimizer(model, trainSet, trainLabels);
    optimizer.setRegularization(10);
    optimizer.setL1Ratio(0.5);

    // Baseline with the freshly constructed model.
    System.out.println("train rmse before training = " + RMSE.rmse(model, trainSet));
    System.out.println("test rmse before training = " + RMSE.rmse(model, heldOutSet));

    optimizer.optimize();

    System.out.println("train rmse after training = " + RMSE.rmse(model, trainSet));
    System.out.println("test rmse after training = " + RMSE.rmse(model, heldOutSet));
    System.out.println("non-zeros = "
            + model.getWeights().getWeightsWithoutBias().getNumNonZeroElements());
}
Also used : RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File)

Example 19 with RegDataSet

Use of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.

In class ElasticNetLinearRegTrainerTest, method test3.

/**
 * Fits an elastic-net linear regression (regularization 0.001, L1 ratio 0.1) on the data
 * set returned by {@code RegressionSynthesizer.linear()}.
 *
 * <p>Prints RMSE before and after training, the number of non-zero weights, and the weight
 * vector. It then prints "correct" if the four largest non-zero weights belong to features
 * 0-3, and "incorrect" otherwise.
 *
 * @throws Exception if any of the called routines throws
 */
private static void test3() throws Exception {
    RegDataSet dataSet = RegressionSynthesizer.linear();
    LinearRegression linearRegression = new LinearRegression(dataSet.getNumFeatures());
    ElasticNetLinearRegOptimizer trainer = new ElasticNetLinearRegOptimizer(linearRegression, dataSet);
    trainer.setRegularization(0.001);
    trainer.setL1Ratio(0.1);
    System.out.println("train rmse before training = " + RMSE.rmse(linearRegression, dataSet));
    trainer.optimize();
    System.out.println("train rmse after training = " + RMSE.rmse(linearRegression, dataSet));

    Vector weights = linearRegression.getWeights().getWeightsWithoutBias();
    // The "non-zeros" label reports a count; the full vector gets its own label.
    System.out.println("non-zeros = " + weights.getNumNonZeroElements());
    System.out.println("weights = " + weights);

    List<Pair<Integer, Double>> pairs = new ArrayList<>();
    for (Vector.Element element : weights.nonZeroes()) {
        // Copy index/value out immediately rather than keeping the Element object.
        pairs.add(new Pair<>(element.index(), element.get()));
    }
    // NOTE(review): ranks by signed weight, so large negative coefficients are never picked.
    // If RegressionSynthesizer.linear() can produce negative true weights, ranking by
    // absolute value may be intended — confirm.
    Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(pair -> pair.getSecond());
    Set<Integer> set = pairs.stream()
            .sorted(comparator.reversed())
            .limit(4)
            .map(pair -> pair.getFirst())
            .collect(Collectors.toSet());

    // Expected signal-carrying features: indices 0..3 (presumably what linear() generates — confirm).
    Set<Integer> trueSet = new HashSet<>();
    for (int feature = 0; feature < 4; feature++) {
        trueSet.add(feature);
    }
    System.out.println(set.equals(trueSet) ? "correct" : "incorrect");
}
Also used: DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType), RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet), java.util (java.util), Grid (edu.neu.ccs.pyramid.util.Grid), StandardFormat (edu.neu.ccs.pyramid.dataset.StandardFormat), Vector (org.apache.mahout.math.Vector), RegressionSynthesizer (edu.neu.ccs.pyramid.simulation.RegressionSynthesizer), RMSE (edu.neu.ccs.pyramid.eval.RMSE), Collectors (java.util.stream.Collectors), Pair (edu.neu.ccs.pyramid.util.Pair), File (java.io.File), Config (edu.neu.ccs.pyramid.configuration.Config)

Example 20 with RegDataSet

Use of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.

In class ElasticNetLinearRegTrainerTest, method test2.

/**
 * Sweeps an elastic-net regularization path (L1 ratio 0.5) on the spam training split.
 *
 * <p>The grid is {@code Grid.logUniform(0.01, 100, 5)}, visited from the largest value to
 * the smallest. A deep copy of the model is kept after each fit. For each grid value the
 * method then prints the regularization, the number of non-zero weights (bias excluded) and
 * the test RMSE.
 *
 * @throws Exception if loading the data sets or any of the called routines throws
 */
private static void test2() throws Exception {
    RegDataSet trainSet = StandardFormat.loadRegDataSet(
            new File(DATASETS, "spam/train_data.txt"),
            new File(DATASETS, "spam/train_label.txt"),
            ",", DataSetType.REG_DENSE, false);
    double[] trainLabels = trainSet.getLabels();
    RegDataSet heldOutSet = StandardFormat.loadRegDataSet(
            new File(DATASETS, "spam/test_data.txt"),
            new File(DATASETS, "spam/test_label.txt"),
            ",", DataSetType.REG_DENSE, false);
    LinearRegression model = new LinearRegression(trainSet.getNumFeatures());

    // Strongest penalty first.
    List<Double> grid = Grid.logUniform(0.01, 100, 5).stream()
            .sorted(Comparator.reverseOrder())
            .collect(Collectors.toList());

    // The same model instance is handed to every optimizer; presumably each fit warm-starts
    // from the previous one's weights — confirm in ElasticNetLinearRegOptimizer.
    List<LinearRegression> pathModels = new ArrayList<>();
    for (double penalty : grid) {
        ElasticNetLinearRegOptimizer optimizer =
                new ElasticNetLinearRegOptimizer(model, trainSet, trainLabels);
        optimizer.setRegularization(penalty);
        optimizer.setL1Ratio(0.5);
        optimizer.optimize();
        // Snapshot, because later fits keep mutating the shared model.
        pathModels.add(model.deepCopy());
    }

    int step = 0;
    for (LinearRegression fitted : pathModels) {
        System.out.println("regularization = " + grid.get(step));
        System.out.println("non-zeros = " + fitted.getWeights().getWeightsWithoutBias().getNumNonZeroElements());
        System.out.println("test rmse  = " + RMSE.rmse(fitted, heldOutSet));
        step++;
    }
}
Also used : RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File)

Aggregations

RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet): 21; File (java.io.File): 9; DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType): 4; RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig): 3; NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution): 3; Vector (org.apache.mahout.math.Vector): 3; Config (edu.neu.ccs.pyramid.configuration.Config): 2; LSBoost (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost): 2; RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory): 2; Pair (edu.neu.ccs.pyramid.util.Pair): 2; ArrayList (java.util.ArrayList): 2; StopWatch (org.apache.commons.lang3.time.StopWatch): 2; ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 1; StandardFormat (edu.neu.ccs.pyramid.dataset.StandardFormat): 1; RMSE (edu.neu.ccs.pyramid.eval.RMSE): 1; LSBoostOptimizer (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer): 1; ElasticNetLinearRegOptimizer (edu.neu.ccs.pyramid.regression.linear_regression.ElasticNetLinearRegOptimizer): 1; LinearRegression (edu.neu.ccs.pyramid.regression.linear_regression.LinearRegression): 1; RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree): 1; TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule): 1