Search in sources :

Example 21 with RegDataSet

Use of edu.neu.ccs.pyramid.dataset.RegDataSet in the project pyramid by cheng-li.

From the class GBRegressor, the method train.

/**
 * Trains an {@link LSBoost} model on the configured training set, optionally printing
 * training/test RMSE while boosting, then serializes the model and writes a report of the
 * training-set predictions into the configured output folder.
 *
 * <p>Config keys read by this method:
 * <ul>
 *   <li>{@code input.matrixType} - must be {@code "dense"} or {@code "sparse"}</li>
 *   <li>{@code input.trainData}, and {@code input.testData} (the latter only when
 *       {@code train.showTestProgress} is true)</li>
 *   <li>{@code train.numLeaves}, {@code train.shrinkage}, {@code train.numIterations}</li>
 *   <li>{@code train.showTrainProgress}, {@code train.showTestProgress},
 *       {@code train.showProgress.interval}</li>
 *   <li>{@code output.folder}</li>
 * </ul>
 *
 * @param config source of all the settings listed above
 * @throws IllegalArgumentException if {@code input.matrixType} is neither dense nor sparse, or
 *         if progress reporting is enabled while {@code train.showProgress.interval} is 0
 * @throws java.io.IOException if the output folder does not exist and cannot be created
 * @throws Exception whatever the data loading, training, serialization or reporting calls throw
 */
private static void train(Config config) throws Exception {
    String sparsity = config.getString("input.matrixType");
    DataSetType dataSetType;
    switch (sparsity) {
        case "dense":
            dataSetType = DataSetType.REG_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.REG_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }

    // These settings do not change during training, so read them once instead of per iteration.
    boolean showTrainProgress = config.getBoolean("train.showTrainProgress");
    boolean showTestProgress = config.getBoolean("train.showTestProgress");
    int progressInterval = config.getInt("train.showProgress.interval");
    // "i % 0" would otherwise abort with a bare ArithmeticException after the first boosting
    // iteration; reject the bad setting before any expensive work is done.
    if ((showTrainProgress || showTestProgress) && progressInterval == 0) {
        throw new IllegalArgumentException(
                "train.showProgress.interval should not be 0 when progress reporting is enabled");
    }

    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainData"), dataSetType, true);
    // The test set is only needed (and therefore only loaded) for test-progress reporting.
    RegDataSet testSet = null;
    if (showTestProgress) {
        testSet = TRECFormat.loadRegDataSet(config.getString("input.testData"), dataSetType, true);
    }

    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves"));
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();

    int numIterations = config.getInt("train.numIterations");
    for (int i = 1; i <= numIterations; i++) {
        System.out.println("iteration " + i);
        optimizer.iterate();
        if (!showTrainProgress && !showTestProgress) {
            continue;
        }
        // Report every progressInterval iterations and always after the final one.
        boolean reportNow = i % progressInterval == 0 || i == numIterations;
        if (showTrainProgress && reportNow) {
            System.out.println("training RMSE = " + RMSE.rmse(lsBoost, trainSet));
        }
        if (showTestProgress && reportNow) {
            System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        }
    }
    System.out.println("training done!");

    String output = config.getString("output.folder");
    File outputDir = new File(output);
    // mkdirs() also returns false when the folder already exists, hence the isDirectory() check.
    if (!outputDir.mkdirs() && !outputDir.isDirectory()) {
        throw new java.io.IOException("could not create output folder " + outputDir.getAbsolutePath());
    }
    File serializedModel = new File(output, "model");
    Serialization.serialize(lsBoost, serializedModel);
    System.out.println("model saved to " + serializedModel.getAbsolutePath());
    File reportFile = new File(output, "train_predictions.txt");
    report(lsBoost, trainSet, reportFile);
    System.out.println("predictions on the training set are written to " + reportFile.getAbsolutePath());
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) LSBoostOptimizer(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File) LSBoost(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)

Aggregations

RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet): 21; File (java.io.File): 9; DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType): 4; RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig): 3; NormalDistribution (org.apache.commons.math3.distribution.NormalDistribution): 3; Vector (org.apache.mahout.math.Vector): 3; Config (edu.neu.ccs.pyramid.configuration.Config): 2; LSBoost (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost): 2; RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory): 2; Pair (edu.neu.ccs.pyramid.util.Pair): 2; ArrayList (java.util.ArrayList): 2; StopWatch (org.apache.commons.lang3.time.StopWatch): 2; ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 1; StandardFormat (edu.neu.ccs.pyramid.dataset.StandardFormat): 1; RMSE (edu.neu.ccs.pyramid.eval.RMSE): 1; LSBoostOptimizer (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer): 1; ElasticNetLinearRegOptimizer (edu.neu.ccs.pyramid.regression.linear_regression.ElasticNetLinearRegOptimizer): 1; LinearRegression (edu.neu.ccs.pyramid.regression.linear_regression.LinearRegression): 1; RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree): 1; TreeRule (edu.neu.ccs.pyramid.regression.regression_tree.TreeRule): 1