Usage of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.
Class RegressionSynthesizer, method univarStep.
/**
 * Builds a single-feature regression data set whose target is a noisy step function of the feature.
 *
 * <p>For each of the {@code numDataPoints} rows, a feature value is drawn with
 * {@code Sampling.doubleUniform(0, 1)}. The label is 0.7 when that value is above 0.5 and 0.2
 * otherwise, plus one draw from {@code noise}.
 *
 * @return a {@link RegDataSet} built with {@code numFeatures(1)}, {@code dense(true)} and
 *     {@code missingValue(false)}, holding the sampled features and labels
 */
public RegDataSet univarStep() {
    RegDataSet stepData = RegDataSetBuilder.getBuilder()
            .numDataPoints(numDataPoints)
            .numFeatures(1)
            .dense(true)
            .missingValue(false)
            .build();
    for (int row = 0; row < numDataPoints; row++) {
        double x = Sampling.doubleUniform(0, 1);
        // Step at 0.5: high plateau 0.7, low plateau 0.2, then additive noise.
        double y = (x > 0.5 ? 0.7 : 0.2) + noise.sample();
        stepData.setFeatureValue(row, 0, x);
        stepData.setLabel(row, y);
    }
    return stepData;
}
Usage of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.
Class LSBoostTest, method test1.
/**
 * Trains an {@link LSBoost} model on abalone fold 1 and prints train/test RMSE as boosting proceeds.
 *
 * <p>Uses regression trees with at most 3 leaves, a shrinkage of 0.1, and 100 boosting iterations.
 * The line "iteration i" reports the model after {@code i} calls to {@code iterate()} (iteration 0
 * is the model right after {@code initialize()}). A final "iteration 100" report covers the fully
 * trained model, which the loop alone would never evaluate.
 *
 * @throws Exception propagated from loading the data sets or from training
 */
private static void test1() throws Exception {
    RegDataSet trainSet = TRECFormat.loadRegDataSet(new File(DATASETS, "abalone/folds/fold_1/train.trec"), DataSetType.REG_DENSE, true);
    RegDataSet testSet = TRECFormat.loadRegDataSet(new File(DATASETS, "abalone/folds/fold_1/test.trec"), DataSetType.REG_DENSE, true);
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(3);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    int numIterations = 100;
    for (int i = 0; i < numIterations; i++) {
        // Metrics are taken before this step, i.e. for the model after i iterations.
        System.out.println("iteration " + i);
        System.out.println("train RMSE = " + RMSE.rmse(lsBoost, trainSet));
        System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        optimizer.iterate();
    }
    // Report the model produced by the last iterate() call; the loop stops before evaluating it.
    System.out.println("iteration " + numIterations);
    System.out.println("train RMSE = " + RMSE.rmse(lsBoost, trainSet));
    System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
}
Usage of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.
Class ElasticNetLinearRegTrainerTest, method test1.
/**
 * Fits an elastic-net linear regression on the spam training split and prints diagnostics.
 *
 * <p>Settings: regularization 10, L1 ratio 0.5. The labels passed to the optimizer are the
 * training set's own labels. Prints train/test RMSE before and after optimization, then the number
 * of non-zero weights (bias excluded).
 *
 * @throws Exception propagated from loading the data sets or from optimization
 */
private static void test1() throws Exception {
    File trainDataFile = new File(DATASETS, "spam/train_data.txt");
    File trainLabelFile = new File(DATASETS, "spam/train_label.txt");
    File testDataFile = new File(DATASETS, "spam/test_data.txt");
    File testLabelFile = new File(DATASETS, "spam/test_label.txt");

    RegDataSet trainSet = StandardFormat.loadRegDataSet(trainDataFile, trainLabelFile, ",", DataSetType.REG_DENSE, false);
    double[] trainLabels = trainSet.getLabels();
    RegDataSet testSet = StandardFormat.loadRegDataSet(testDataFile, testLabelFile, ",", DataSetType.REG_DENSE, false);

    LinearRegression model = new LinearRegression(trainSet.getNumFeatures());
    ElasticNetLinearRegOptimizer optimizer = new ElasticNetLinearRegOptimizer(model, trainSet, trainLabels);
    optimizer.setRegularization(10);
    optimizer.setL1Ratio(0.5);

    System.out.println("train rmse before training = " + RMSE.rmse(model, trainSet));
    System.out.println("test rmse before training = " + RMSE.rmse(model, testSet));
    optimizer.optimize();
    System.out.println("train rmse after training = " + RMSE.rmse(model, trainSet));
    System.out.println("test rmse after training = " + RMSE.rmse(model, testSet));
    System.out.println("non-zeros = " + model.getWeights().getWeightsWithoutBias().getNumNonZeroElements());
}
Usage of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.
Class ElasticNetLinearRegTrainerTest, method test3.
/**
 * Checks whether elastic net picks out features 0-3 on a synthetic linear data set.
 *
 * <p>Fits on {@code RegressionSynthesizer.linear()} with regularization 0.001 and L1 ratio 0.1, and
 * prints training RMSE before and after optimization. It then prints the non-zero weight count,
 * which is consistent with the other tests' "non-zeros" line, and the full weight vector (bias
 * excluded). Finally it prints "correct" if the four non-zero weights with the largest values are
 * exactly features {0, 1, 2, 3}, and "incorrect" otherwise.
 *
 * @throws Exception propagated from data synthesis or optimization
 */
private static void test3() throws Exception {
    RegDataSet dataSet = RegressionSynthesizer.linear();
    LinearRegression linearRegression = new LinearRegression(dataSet.getNumFeatures());
    ElasticNetLinearRegOptimizer trainer = new ElasticNetLinearRegOptimizer(linearRegression, dataSet);
    trainer.setRegularization(0.001);
    trainer.setL1Ratio(0.1);
    System.out.println("train rmse before training = " + RMSE.rmse(linearRegression, dataSet));
    trainer.optimize();
    System.out.println("train rmse after training = " + RMSE.rmse(linearRegression, dataSet));

    Vector weights = linearRegression.getWeights().getWeightsWithoutBias();
    // The "non-zeros" label previously printed the whole vector; print the count under that label
    // and the vector under its own label.
    System.out.println("non-zeros = " + weights.getNumNonZeroElements());
    System.out.println("weights = " + weights);

    List<Pair<Integer, Double>> pairs = new ArrayList<>();
    for (Vector.Element element : weights.nonZeroes()) {
        pairs.add(new Pair<>(element.index(), element.get()));
    }
    // NOTE(review): this ranks by signed weight, so a large negative coefficient never makes the
    // top four. If RegressionSynthesizer.linear() can use negative true weights, rank by
    // Math.abs(weight) instead.
    Comparator<Pair<Integer, Double>> byWeight = Comparator.comparing(Pair::getSecond);
    Set<Integer> selected = pairs.stream()
            .sorted(byWeight.reversed())
            .limit(4)
            .map(Pair::getFirst)
            .collect(Collectors.toSet());

    Set<Integer> trueSet = new HashSet<>();
    for (int feature = 0; feature < 4; feature++) {
        trueSet.add(feature);
    }
    System.out.println(selected.equals(trueSet) ? "correct" : "incorrect");
}
Usage of edu.neu.ccs.pyramid.dataset.RegDataSet in project pyramid by cheng-li.
Class ElasticNetLinearRegTrainerTest, method test2.
/**
 * Traces an elastic-net regularization path on the spam data set.
 *
 * <p>The values from {@code Grid.logUniform(0.01, 100, 5)} are visited from largest to smallest,
 * with L1 ratio 0.5. The same {@code LinearRegression} instance is handed to every optimizer, and a
 * deep copy is saved after each fit. Each fit may therefore start from the previous solution
 * (warm start); confirm this in ElasticNetLinearRegOptimizer. Afterwards, each regularization value
 * is printed with its non-zero weight count (bias excluded) and test RMSE.
 *
 * @throws Exception propagated from loading the data sets or from optimization
 */
private static void test2() throws Exception {
    RegDataSet trainSet = StandardFormat.loadRegDataSet(
            new File(DATASETS, "spam/train_data.txt"),
            new File(DATASETS, "spam/train_label.txt"),
            ",", DataSetType.REG_DENSE, false);
    double[] trainLabels = trainSet.getLabels();
    RegDataSet testSet = StandardFormat.loadRegDataSet(
            new File(DATASETS, "spam/test_data.txt"),
            new File(DATASETS, "spam/test_label.txt"),
            ",", DataSetType.REG_DENSE, false);
    LinearRegression model = new LinearRegression(trainSet.getNumFeatures());

    // Strongest penalty first, so the path moves from sparse to dense solutions.
    List<Double> penalties = Grid.logUniform(0.01, 100, 5).stream()
            .sorted(Comparator.reverseOrder())
            .collect(Collectors.toList());

    List<LinearRegression> snapshots = new ArrayList<>();
    for (double penalty : penalties) {
        ElasticNetLinearRegOptimizer optimizer = new ElasticNetLinearRegOptimizer(model, trainSet, trainLabels);
        optimizer.setRegularization(penalty);
        optimizer.setL1Ratio(0.5);
        optimizer.optimize();
        // Snapshot, because the shared model keeps being updated by later fits.
        snapshots.add(model.deepCopy());
    }

    for (int k = 0; k < penalties.size(); k++) {
        System.out.println("regularization = " + penalties.get(k));
        LinearRegression snapshot = snapshots.get(k);
        System.out.println("non-zeros = " + snapshot.getWeights().getWeightsWithoutBias().getNumNonZeroElements());
        System.out.println("test rmse = " + RMSE.rmse(snapshot, testSet));
    }
}
Aggregations