Use of edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost in project pyramid by cheng-li.
The class GBRegressor, method test.
private static void test(Config config) throws Exception {
    String output = config.getString("output.folder");
    File serializedModel = new File(output, "model");
    // load the LSBoost model serialized by train()
    LSBoost lsBoost = (LSBoost) Serialization.deserialize(serializedModel);
    // choose dense or sparse storage for the regression data set
    String sparsity = config.getString("input.matrixType");
    DataSetType dataSetType = null;
    switch (sparsity) {
        case "dense":
            dataSetType = DataSetType.REG_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.REG_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }
    RegDataSet testSet = TRECFormat.loadRegDataSet(config.getString("input.testData"), dataSetType, true);
    System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
    // write one prediction per test instance to the report file
    File reportFile = new File(output, "test_predictions.txt");
    report(lsBoost, testSet, reportFile);
    System.out.println("predictions on the test set are written to " + reportFile.getAbsolutePath());
}
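For quick experiments, the same deserialize-and-score flow can also be driven without a Config. The sketch below reuses only calls already shown on this page (Serialization.deserialize, TRECFormat.loadRegDataSet, RMSE.rmse, and per-row predict); the class name, the file paths, and the import statements are assumptions about pyramid's package layout rather than something taken from GBRegressor, so verify them against the sources.

import java.io.File;
import edu.neu.ccs.pyramid.dataset.DataSetType;
import edu.neu.ccs.pyramid.dataset.RegDataSet;
import edu.neu.ccs.pyramid.dataset.TRECFormat;
import edu.neu.ccs.pyramid.eval.RMSE;
import edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost;
import edu.neu.ccs.pyramid.util.Serialization;

public class ScoreSavedLSBoost {
    public static void main(String[] args) throws Exception {
        // hypothetical locations; point these at your own output folder and test set
        File serializedModel = new File("/path/to/output", "model");
        LSBoost lsBoost = (LSBoost) Serialization.deserialize(serializedModel);
        RegDataSet testSet = TRECFormat.loadRegDataSet("/path/to/testData", DataSetType.REG_DENSE, true);
        // aggregate fit on the test set
        System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        // per-instance label vs. prediction, mirroring the report written by test()
        for (int i = 0; i < testSet.getNumDataPoints(); i++) {
            System.out.println(testSet.getLabels()[i] + " " + lsBoost.predict(testSet.getRow(i)));
        }
    }
}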
Use of edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost in project pyramid by cheng-li.
The class LSLogisticBoostTest, method test3.
private static void test3() throws Exception {
    RegDataSet train = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true);
    RegDataSet test = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true);
    // first model: LSLogisticBoost with 10-leaf trees and no shrinkage
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(10);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, train, regTreeFactory);
    optimizer.setShrinkage(1);
    optimizer.initialize();
    for (int i = 0; i < 200; i++) {
        System.out.println("training rmse = " + RMSE.rmse(lsLogisticBoost, train));
        System.out.println("test rmse = " + RMSE.rmse(lsLogisticBoost, test));
        optimizer.iterate();
    }
    for (int i = 0; i < test.getNumDataPoints(); i++) {
        System.out.println(test.getLabels()[i] + " " + lsLogisticBoost.predict(test.getRow(i)));
    }
    System.out.println("********************** LSBOOST **************");
    // second model: plain LSBoost with the same tree factory and shrinkage 0.1
    LSBoost lsBoost = new LSBoost();
    LSBoostOptimizer lsBoostOptimizer = new LSBoostOptimizer(lsBoost, train, regTreeFactory);
    lsBoostOptimizer.setShrinkage(0.1);
    lsBoostOptimizer.initialize();
    for (int i = 0; i < 100; i++) {
        System.out.println("training rmse = " + RMSE.rmse(lsBoost, train));
        System.out.println("test rmse = " + RMSE.rmse(lsBoost, test));
        lsBoostOptimizer.iterate();
    }
    for (int i = 0; i < test.getNumDataPoints(); i++) {
        System.out.println(test.getLabels()[i] + " " + lsBoost.predict(test.getRow(i)));
    }
}
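Both halves of test3 follow the same recipe: create a booster, wrap a RegTreeConfig in a RegTreeFactory, hand both to the matching optimizer, set the shrinkage, initialize, and iterate. A hypothetical helper capturing that recipe for LSBoost (the name trainLSBoost and its parameter list are illustrative, not part of pyramid) could look like:

// Sketch of the recurring LSBoost training recipe; trainLSBoost is a
// hypothetical helper, not a pyramid API.
private static LSBoost trainLSBoost(RegDataSet train, int maxNumLeaves,
                                    double shrinkage, int numIterations) {
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(maxNumLeaves);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, train, regTreeFactory);
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    // each iterate() call adds one more regression tree to the ensemble
    for (int i = 0; i < numIterations; i++) {
        optimizer.iterate();
    }
    return lsBoost;
}

With such a helper, the second half of test3 amounts to LSBoost lsBoost = trainLSBoost(train, 10, 0.1, 100); followed by the per-instance print loop.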
Use of edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost in project pyramid by cheng-li.
The class LSLogisticBoostTest, method test2.
private static void test2() throws Exception {
    RegDataSet train = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true);
    RegDataSet test = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true);
    // remap the labels in both sets: 0 becomes -1, everything else becomes 2
    for (int i = 0; i < train.getNumDataPoints(); i++) {
        if (train.getLabels()[i] == 0) {
            train.setLabel(i, -1);
        } else {
            train.setLabel(i, 2);
        }
    }
    for (int i = 0; i < test.getNumDataPoints(); i++) {
        if (test.getLabels()[i] == 0) {
            test.setLabel(i, -1);
        } else {
            test.setLabel(i, 2);
        }
    }
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(5);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, train, regTreeFactory);
    optimizer.initialize();
    for (int i = 0; i < 100; i++) {
        System.out.println("training rmse = " + RMSE.rmse(lsLogisticBoost, train));
        System.out.println("test rmse = " + RMSE.rmse(lsLogisticBoost, test));
        optimizer.iterate();
    }
    for (int i = 0; i < test.getNumDataPoints(); i++) {
        System.out.println(test.getLabels()[i] + " " + lsLogisticBoost.predict(test.getRow(i)));
    }
    // System.out.println(Arrays.toString(lsLogisticBoost.predict(train)));
    // baseline: plain LSBoost trained on the same remapped labels
    LSBoost lsBoost = new LSBoost();
    LSBoostOptimizer lsBoostOptimizer = new LSBoostOptimizer(lsBoost, train, regTreeFactory);
    System.out.println("LSBOOST");
    lsBoostOptimizer.initialize();
    for (int i = 0; i < 100; i++) {
        System.out.println("training rmse = " + RMSE.rmse(lsBoost, train));
        System.out.println("test rmse = " + RMSE.rmse(lsBoost, test));
        lsBoostOptimizer.iterate();
    }
}
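The two relabeling loops in test2 differ only in the dataset they touch. A small helper would remove the duplication; remapLabels is a hypothetical name for illustration, not a pyramid API, and it assumes RegDataSet.setLabel accepts a double label as the calls above suggest:

// Hypothetical helper: map label 0 to `low` and any other label to `high`.
private static void remapLabels(RegDataSet dataSet, double low, double high) {
    for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
        dataSet.setLabel(i, dataSet.getLabels()[i] == 0 ? low : high);
    }
}

With it, the two loops above collapse to remapLabels(train, -1, 2); and remapLabels(test, -1, 2);.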
Use of edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost in project pyramid by cheng-li.
The class GBRegressor, method train.
private static void train(Config config, Logger logger) throws Exception {
    // choose dense or sparse storage for the regression data sets
    String sparsity = config.getString("input.matrixType");
    DataSetType dataSetType = null;
    switch (sparsity) {
        case "dense":
            dataSetType = DataSetType.REG_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.REG_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }
    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainData"), dataSetType, true);
    RegDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = TRECFormat.loadRegDataSet(config.getString("input.testData"), dataSetType, true);
    }
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves"));
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();
    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        optimizer.iterate();
        // report progress every progressInterval iterations and at the final iteration
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training RMSE = " + RMSE.rmse(lsBoost, trainSet));
        }
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        }
    }
    logger.info("training done!");
    String output = config.getString("output.folder");
    new File(output).mkdirs();
    // serialize the trained model so that test() can reload it later
    File serializedModel = new File(output, "model");
    Serialization.serialize(lsBoost, serializedModel);
    logger.info("model saved to " + serializedModel.getAbsolutePath());
    if (config.getBoolean("output.generatePMML")) {
        File pmmlModel = new File(output, "model.pmml");
        PMMLConverter.savePMML(lsBoost, pmmlModel);
        logger.info("PMML model saved to " + pmmlModel.getAbsolutePath());
    }
    String trainReportName = config.getString("output.trainReportFolderName");
    File reportFile = Paths.get(output, trainReportName, "train_predictions.txt").toFile();
    report(lsBoost, trainSet, reportFile);
    logger.info("predictions on the training set are written to " + reportFile.getAbsolutePath());
}
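train and test read every setting from a Config object. Assuming pyramid configs are flat key=value properties files (an assumption to verify against the example configs shipped with the project), a configuration touching every key referenced above could look like the sketch below; all paths and values are placeholders, only the key names are taken from the code.

# hypothetical config for GBRegressor; keys are from the code above, values are placeholders
input.matrixType=dense
input.trainData=/path/to/trainData
input.testData=/path/to/testData
train.numLeaves=10
train.shrinkage=0.1
train.numIterations=100
train.showProgress.interval=10
train.showTrainProgress=true
train.showTestProgress=true
output.folder=/path/to/output
output.generatePMML=false
output.trainReportFolderName=train_report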
Use of edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost in project pyramid by cheng-li.
The class Calibration, method trainCalibrator.
private static LSBoost trainCalibrator(RegDataSet calib, RegDataSet valid, int[] monotonicity) {
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(10).setMinDataPerLeaf(5);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, calib, regTreeFactory, calib.getLabels());
    // setMonotonicity expects one row of flags per output; here there is a single output
    if (true) {
        int[][] mono = new int[1][calib.getNumFeatures()];
        mono[0] = monotonicity;
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    // early stopping on validation MSE, checked every 10 iterations
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSBoost bestModel = null;
    for (int i = 1; i < 1000; i++) {
        optimizer.iterate();
        if (i % 10 == 0) {
            double mse = MSE.mse(lsBoost, valid);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                try {
                    // keep a snapshot of the best model seen so far
                    bestModel = (LSBoost) Serialization.deepCopy(lsBoost);
                } catch (IOException | ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    return bestModel;
}
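trainCalibrator expects one monotonicity flag per feature of the calibration set. A caller that wants the calibrated score to be monotone in every feature might build the array as below; treating 1 as "increasing" is an assumption about pyramid's convention for setMonotonicity, and calib and valid stand for RegDataSets prepared elsewhere.

// Hypothetical caller; assumes 1 encodes "increasing" in pyramid's monotonicity convention.
int[] monotonicity = new int[calib.getNumFeatures()];
java.util.Arrays.fill(monotonicity, 1);
LSBoost calibrator = trainCalibrator(calib, valid, monotonicity);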