Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in the pyramid project by cheng-li.
Example: method trainWithSigmoid of class RerankerTrainer.
/**
 * Trains a sigmoid-output LSLogisticBoost reranker with early stopping on validation MSE.
 * Every 10 boosting iterations the validation MSE is checked; the best-scoring model
 * snapshot is kept (via deep copy) and training stops after 5 non-improving checks.
 *
 * @param regDataSet training regression data; its labels are the boosting targets
 * @param instanceWeights per-instance training weights
 * @param classProbEstimator class probability estimator wrapped into the returned Reranker
 * @param predictionFeatureExtractor feature extractor; also supplies per-feature monotonicity
 * @param labelCalibrator label calibrator wrapped into the returned Reranker
 * @param validation held-out set used for early stopping
 * @param noiseRates0 noise rates for label 0 passed to the optimizer
 * @param noiseRates1 noise rates for label 1 passed to the optimizer
 * @return a Reranker built around the best model found (never a null model)
 */
public Reranker trainWithSigmoid(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation, double[] noiseRates0, double[] noiseRates1) {
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves).setMinDataPerLeaf(minDataPerLeaf).setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels(), noiseRates0, noiseRates1);
    if (!monotonicityType.equals("none")) {
        // single output dimension: one monotonicity row covering all features
        int[][] mono = new int[1][regDataSet.getNumFeatures()];
        mono[0] = predictionFeatureExtractor.featureMonotonicity();
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    // minimize validation MSE; patience of 5 recorded checkpoints
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSLogisticBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        // evaluate only every 10 iterations to keep validation cost low
        if (i % 10 == 0) {
            double mse = MSE.mse(lsLogisticBoost, validation);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                try {
                    // snapshot now so later (worse) iterations cannot overwrite the best state
                    bestModel = (LSLogisticBoost) Serialization.deepCopy(lsLogisticBoost);
                } catch (IOException | ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    // Fix: bestModel is only assigned at 10-iteration checkpoints. With maxIter < 10
    // (or if every deepCopy failed) it stayed null and the Reranker wrapped a null model.
    // Fall back to the live model in that case.
    if (bestModel == null) {
        bestModel = lsLogisticBoost;
    }
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in the pyramid project by cheng-li.
Example: method defaultFactory of class L2BoostOptimizer.
/**
 * Builds the regressor factory used when the caller supplies none:
 * regression trees with the default configuration and the L2Boost
 * leaf output calculator.
 */
private static RegressorFactory defaultFactory() {
    RegTreeFactory factory = new RegTreeFactory(new RegTreeConfig());
    factory.setLeafOutputCalculator(new L2BLeafOutputCalculator());
    return factory;
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in the pyramid project by cheng-li.
Example: method buildTest of class L2BoostTest.
/**
 * Trains an L2Boost classifier on the spam TREC training set for 200 rounds,
 * prints the elapsed time and training accuracy, and serializes the model
 * to TMP/boost.
 */
static void buildTest() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_SPARSE, true);
    System.out.println(dataSet.getMetaInfo());
    L2Boost boost = new L2Boost();
    // weak learners: regression trees limited to 7 leaves
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(7));
    treeFactory.setLeafOutputCalculator(new L2BLeafOutputCalculator());
    L2BoostOptimizer optimizer = new L2BoostOptimizer(boost, dataSet, treeFactory);
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    int round = 0;
    while (round < 200) {
        System.out.println("round=" + round);
        optimizer.iterate();
        round += 1;
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    // evaluate on the training set itself (smoke test, not generalization)
    System.out.println("accuracy=" + Accuracy.accuracy(boost, dataSet));
    Serialization.serialize(boost, new File(TMP, "boost"));
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in the pyramid project by cheng-li.
Example: method train of class GBClassifier.
/**
 * Trains an LKBoost multi-class classifier as configured, reporting train/test
 * accuracy at the configured interval, then saves the serialized model, a PMML
 * export, and prediction/probability reports for the training set.
 *
 * @param config run configuration (input paths, tree/boosting hyperparameters,
 *               progress flags, output folder)
 * @throws Exception if data loading, training, or output writing fails
 */
private static void train(Config config) throws Exception {
    // resolve the dataset storage format from config
    String matrixType = config.getString("input.matrixType");
    DataSetType dataSetType;
    if (matrixType.equals("dense")) {
        dataSetType = DataSetType.CLF_DENSE;
    } else if (matrixType.equals("sparse")) {
        dataSetType = DataSetType.CLF_SPARSE;
    } else {
        throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(config.getString("input.trainData"), dataSetType, true);
    // test data is only loaded when test progress reporting is enabled
    ClfDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = TRECFormat.loadClfDataSet(config.getString("input.testData"), dataSetType, true);
    }
    int numClasses = trainSet.getNumClasses();
    LKBoost lkBoost = new LKBoost(numClasses);
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves")));
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(numClasses));
    LKBoostOptimizer optimizer = new LKBoostOptimizer(lkBoost, trainSet, treeFactory);
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();
    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    for (int iteration = 1; iteration <= numIterations; iteration++) {
        System.out.println("iteration " + iteration);
        optimizer.iterate();
        // always report on the final iteration, otherwise only at the interval
        boolean atCheckpoint = iteration % progressInterval == 0 || iteration == numIterations;
        if (atCheckpoint && config.getBoolean("train.showTrainProgress")) {
            System.out.println("training accuracy = " + Accuracy.accuracy(lkBoost, trainSet));
        }
        if (atCheckpoint && config.getBoolean("train.showTestProgress")) {
            System.out.println("test accuracy = " + Accuracy.accuracy(lkBoost, testSet));
        }
    }
    System.out.println("training done!");
    String output = config.getString("output.folder");
    new File(output).mkdirs();
    // serialized Java model
    File serializedModel = new File(output, "model");
    Serialization.serialize(lkBoost, serializedModel);
    System.out.println("model saved to " + serializedModel.getAbsolutePath());
    // portable PMML export
    File pmmlModel = new File(output, "model.pmml");
    PMMLConverter.savePMML(lkBoost, pmmlModel);
    System.out.println("PMML model saved to " + pmmlModel.getAbsolutePath());
    // per-instance predictions and probabilities on the training set
    File reportFile = new File(output, "train_predictions.txt");
    report(lkBoost, trainSet, reportFile);
    System.out.println("predictions on the training set are written to " + reportFile.getAbsolutePath());
    File probabilitiesFile = new File(output, "train_predicted_probabilities.txt");
    probabilities(lkBoost, trainSet, probabilitiesFile);
    System.out.println("predicted probabilities on the training set are written to " + probabilitiesFile.getAbsolutePath());
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in the pyramid project by cheng-li.
Example: method train of class GBRegressor.
/**
 * Trains an LSBoost least-squares regressor as configured, reporting train/test
 * RMSE at the configured interval, then saves the serialized model, an optional
 * PMML export, and a prediction report for the training set.
 *
 * @param config run configuration (input paths, tree/boosting hyperparameters,
 *               progress flags, output folders)
 * @param logger destination for progress and output-path messages
 * @throws Exception if data loading, training, or output writing fails
 */
private static void train(Config config, Logger logger) throws Exception {
    // resolve the dataset storage format from config
    String matrixType = config.getString("input.matrixType");
    DataSetType dataSetType;
    if (matrixType.equals("dense")) {
        dataSetType = DataSetType.REG_DENSE;
    } else if (matrixType.equals("sparse")) {
        dataSetType = DataSetType.REG_SPARSE;
    } else {
        throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }
    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainData"), dataSetType, true);
    // test data is only loaded when test progress reporting is enabled
    RegDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = TRECFormat.loadRegDataSet(config.getString("input.testData"), dataSetType, true);
    }
    LSBoost lsBoost = new LSBoost();
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves")));
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, treeFactory);
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();
    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    for (int iteration = 1; iteration <= numIterations; iteration++) {
        logger.info("iteration " + iteration);
        optimizer.iterate();
        // always report on the final iteration, otherwise only at the interval
        boolean atCheckpoint = iteration % progressInterval == 0 || iteration == numIterations;
        if (atCheckpoint && config.getBoolean("train.showTrainProgress")) {
            logger.info("training RMSE = " + RMSE.rmse(lsBoost, trainSet));
        }
        if (atCheckpoint && config.getBoolean("train.showTestProgress")) {
            logger.info("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        }
    }
    logger.info("training done!");
    String output = config.getString("output.folder");
    new File(output).mkdirs();
    // serialized Java model
    File serializedModel = new File(output, "model");
    Serialization.serialize(lsBoost, serializedModel);
    logger.info("model saved to " + serializedModel.getAbsolutePath());
    // PMML export is opt-in for the regressor (unlike the classifier)
    if (config.getBoolean("output.generatePMML")) {
        File pmmlModel = new File(output, "model.pmml");
        PMMLConverter.savePMML(lsBoost, pmmlModel);
        logger.info("PMML model saved to " + pmmlModel.getAbsolutePath());
    }
    // per-instance predictions on the training set
    String trainReportName = config.getString("output.trainReportFolderName");
    File reportFile = Paths.get(output, trainReportName, "train_predictions.txt").toFile();
    report(lsBoost, trainSet, reportFile);
    logger.info("predictions on the training set are written to " + reportFile.getAbsolutePath());
}
Aggregations