Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li: class GBRegressor, method train.
/**
 * Trains an LSBoost regressor according to the given configuration.
 * <p>
 * Loads the training set (and, if test progress reporting is enabled, the test
 * set) in the configured matrix format, runs the requested number of boosting
 * iterations while reporting RMSE at the configured interval, then serializes
 * the model and writes training-set predictions into the output folder.
 *
 * @param config configuration holding input paths, training hyper-parameters
 *               and the output folder
 * @throws Exception if data loading, model serialization or report writing fails
 */
private static void train(Config config) throws Exception {
    String matrixType = config.getString("input.matrixType");
    DataSetType dataSetType;
    // Note: a null matrixType throws NPE here, matching the original switch semantics.
    if (matrixType.equals("dense")) {
        dataSetType = DataSetType.REG_DENSE;
    } else if (matrixType.equals("sparse")) {
        dataSetType = DataSetType.REG_SPARSE;
    } else {
        throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }
    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainData"), dataSetType, true);
    boolean showTestProgress = config.getBoolean("train.showTestProgress");
    // The test set is only needed (and only loaded) when test progress is reported.
    RegDataSet testSet = showTestProgress
            ? TRECFormat.loadRegDataSet(config.getString("input.testData"), dataSetType, true)
            : null;
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig treeConfig = new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves"));
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, new RegTreeFactory(treeConfig));
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();
    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    boolean showTrainProgress = config.getBoolean("train.showTrainProgress");
    for (int iter = 1; iter <= numIterations; iter++) {
        System.out.println("iteration " + iter);
        optimizer.iterate();
        // Report at every progressInterval-th iteration and always at the last one.
        boolean atCheckpoint = iter % progressInterval == 0 || iter == numIterations;
        if (showTrainProgress && atCheckpoint) {
            System.out.println("training RMSE = " + RMSE.rmse(lsBoost, trainSet));
        }
        if (showTestProgress && atCheckpoint) {
            System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        }
    }
    System.out.println("training done!");
    String output = config.getString("output.folder");
    new File(output).mkdirs();
    File serializedModel = new File(output, "model");
    Serialization.serialize(lsBoost, serializedModel);
    System.out.println("model saved to " + serializedModel.getAbsolutePath());
    File reportFile = new File(output, "train_predictions.txt");
    report(lsBoost, trainSet, reportFile);
    System.out.println("predictions on the training set are written to " + reportFile.getAbsolutePath());
}
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li: class RerankerTrainer, method train.
/**
 * Trains an LSBoost model for reranking with early stopping on a validation set.
 * <p>
 * Every 10 iterations (and on the final iteration) the validation MSE is
 * checkpointed; the snapshot with the best validation MSE so far is deep-copied
 * and kept. Training stops once the early stopper reports no further improvement
 * (patience of 5 checkpoints). The best snapshot is wrapped into the returned
 * {@link Reranker}.
 *
 * @param regDataSet                 training data for the reranking regressor
 * @param instanceWeights            per-instance weights passed to the optimizer
 * @param classProbEstimator         class-probability estimator forwarded to the Reranker
 * @param predictionFeatureExtractor feature extractor; also supplies per-feature
 *                                   monotonicity constraints when enabled
 * @param labelCalibrator            label calibrator forwarded to the Reranker
 * @param validation                 held-out data used for early stopping
 * @return a Reranker built around the best model found during training
 */
public Reranker train(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation) {
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves).setMinDataPerLeaf(minDataPerLeaf).setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels());
    if (!monotonicityType.equals("none")) {
        // Regression has a single output, so one row of per-feature constraints.
        int[][] mono = new int[1][regDataSet.getNumFeatures()];
        mono[0] = predictionFeatureExtractor.featureMonotonicity();
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        // Checkpoint every 10 iterations and always on the last iteration.
        if (i % 10 == 0 || i == maxIter) {
            double mse = MSE.mse(lsBoost, validation);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                try {
                    // Snapshot the current (best-so-far) model; training mutates lsBoost in place.
                    bestModel = (LSBoost) Serialization.deepCopy(lsBoost);
                } catch (IOException | ClassNotFoundException e) {
                    // Best-effort: log and keep training; a later checkpoint may still snapshot successfully.
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    // Defensive fallback: if every snapshot copy failed, use the final in-place model
    // instead of handing a null model to the Reranker (which would NPE far from the cause).
    if (bestModel == null) {
        bestModel = lsBoost;
    }
    // System.out.println("best iteration = "+earlyStopper.getBestIteration());
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li: class GBCBMOptimizer, method updateBinaryClassifier.
@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
StopWatch stopWatch = new StopWatch();
stopWatch.start();
if (cbm.binaryClassifiers[component][label] == null || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
cbm.binaryClassifiers[component][label] = new LKBoost(2);
}
int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
double[][] targetsDistributions = DataSetUtil.labelsToDistributions(binaryLabels, 2);
LKBoost boost = (LKBoost) this.cbm.binaryClassifiers[component][label];
RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves);
RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));
LKBoostOptimizer optimizer = new LKBoostOptimizer(boost, activeDataset, regTreeFactory, activeGammas, targetsDistributions);
optimizer.setShrinkage(shrinkage);
optimizer.initialize();
optimizer.iterate(binaryUpdatesPerIter);
if (logger.isDebugEnabled()) {
logger.debug("time spent on updating component " + component + " label " + label + " = " + stopWatch);
}
}
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li: class CBMOptimizer, method updateBinaryBoosting.
// Runs numIterationsBinary boosting rounds on the binary classifier for the given
// (component, label) pair, weighting instances by that component's gammas and
// fitting the precomputed target distributions for that label.
private void updateBinaryBoosting(int componentIndex, int labelIndex) {
    LKBoost booster = (LKBoost) this.cbm.binaryClassifiers[componentIndex][labelIndex];
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesBinary));
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));
    LKBoostOptimizer boostOptimizer = new LKBoostOptimizer(booster, dataSet, treeFactory, gammasT[componentIndex], targetsDistributions[labelIndex]);
    boostOptimizer.setShrinkage(shrinkageBinary);
    boostOptimizer.initialize();
    boostOptimizer.iterate(numIterationsBinary);
}
Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li: class PMMLConverterTest, method main.
/**
 * Manual test: trains a small LSBoost model on the abalone dataset, converts
 * its regression trees to PMML, and writes the PMML document to a local file.
 * Uses hard-coded developer-machine paths; not intended for automated runs.
 */
public static void main(String[] args) throws Exception {
    RegDataSet train = TRECFormat.loadRegDataSet(new File("/Users/chengli/Dropbox/Public/pyramid/abalone//train"), DataSetType.REG_DENSE, true);
    RegDataSet test = TRECFormat.loadRegDataSet(new File("/Users/chengli/Dropbox/Public/pyramid/abalone//test"), DataSetType.REG_DENSE, true);
    LSBoost booster = new LSBoost();
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(3));
    LSBoostOptimizer trainer = new LSBoostOptimizer(booster, train, treeFactory);
    trainer.setShrinkage(0.1);
    trainer.initialize();
    for (int round = 0; round < 10; round++) {
        // NOTE(review): RMSE is printed before iterate(), so the last round's effect
        // is never reported — presumably intentional for this manual test; verify.
        System.out.println("iteration " + round);
        System.out.println("train RMSE = " + RMSE.rmse(booster, train));
        System.out.println("test RMSE = " + RMSE.rmse(booster, test));
        trainer.iterate();
    }
    FeatureList features = train.getFeatureList();
    // Ensemble member 0 holds a constant bias regressor followed by regression trees.
    List<RegressionTree> trees = booster.getEnsemble(0).getRegressors().stream()
            .filter(member -> member instanceof RegressionTree)
            .map(member -> (RegressionTree) member)
            .collect(Collectors.toList());
    System.out.println(trees);
    double bias = ((ConstantRegressor) booster.getEnsemble(0).get(0)).getScore();
    PMML document = PMMLConverter.encodePMML(null, null, features, trees, (float) bias);
    System.out.println(document.toString());
    try (OutputStream out = new FileOutputStream("/Users/chengli/tmp/pmml.xml")) {
        MetroJAXBUtil.marshalPMML(document, out);
    }
}
Aggregations