Search in sources :

Example 11 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

In the class RerankerTrainer, the method trainWithSigmoid:

/**
 * Trains an LSLogisticBoost reranking model with early stopping and wraps it in a Reranker.
 *
 * <p>Boosting runs for up to {@code maxIter} rounds; every 10th round the model is
 * evaluated on {@code validation} (MSE, minimized). Whenever the current round is the
 * early stopper's best so far, a deep copy of the model is checkpointed; training stops
 * early when the stopper's patience (5 checks) is exhausted.
 *
 * @param regDataSet                 training data
 * @param instanceWeights            per-instance weights for the optimizer
 * @param classProbEstimator         class-probability estimator passed through to the Reranker
 * @param predictionFeatureExtractor feature extractor; also supplies per-feature monotonicity
 * @param labelCalibrator            label calibrator passed through to the Reranker
 * @param validation                 held-out set used for early stopping
 * @param noiseRates0                noise rates for label 0 (forwarded to the optimizer)
 * @param noiseRates1                noise rates for label 1 (forwarded to the optimizer)
 * @return a Reranker built around the best checkpointed model
 */
public Reranker trainWithSigmoid(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation, double[] noiseRates0, double[] noiseRates1) {
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves).setMinDataPerLeaf(minDataPerLeaf).setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels(), noiseRates0, noiseRates1);
    if (!monotonicityType.equals("none")) {
        // A single output dimension: one monotonicity constraint vector over all features.
        int[][] mono = new int[1][regDataSet.getNumFeatures()];
        mono[0] = predictionFeatureExtractor.featureMonotonicity();
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    // Minimize validation MSE; stop after 5 non-improving checks.
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSLogisticBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        // Validation is only evaluated every 10 iterations to keep training cheap.
        if (i % 10 == 0) {
            double mse = MSE.mse(lsLogisticBoost, validation);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                try {
                    // Snapshot the model; later iterations keep mutating lsLogisticBoost.
                    bestModel = (LSLogisticBoost) Serialization.deepCopy(lsLogisticBoost);
                } catch (IOException | ClassNotFoundException e) {
                    // Best-effort checkpoint: a failed copy is logged and training continues.
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    // Fix: if maxIter < 10 (no checkpoint ever taken) or every deep copy failed,
    // bestModel would be null and the Reranker would wrap a null model.
    // Fall back to the final trained model in that case.
    if (bestModel == null) {
        bestModel = lsLogisticBoost;
    }
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) LSLogisticBoostOptimizer(edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LSLogisticBoost(edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoost) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) IOException(java.io.IOException)

Example 12 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

In the class L2BoostOptimizer, the method defaultFactory:

/**
 * Builds the regressor factory used when the caller supplies none:
 * regression trees with the default configuration, using the L2Boost
 * leaf-output calculator.
 *
 * @return a tree factory configured for L2 boosting
 */
private static RegressorFactory defaultFactory() {
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig());
    treeFactory.setLeafOutputCalculator(new L2BLeafOutputCalculator());
    return treeFactory;
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)

Example 13 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

In the class L2BoostTest, the method buildTest:

/**
 * Trains an L2Boost classifier on the spam TREC dataset for 200 rounds,
 * reports timing and training accuracy, and serializes the model to TMP.
 */
static void buildTest() throws Exception {
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_SPARSE, true);
    System.out.println(dataSet.getMetaInfo());

    // Boosting model: 7-leaf regression trees with L2Boost leaf outputs.
    L2Boost boost = new L2Boost();
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(7));
    treeFactory.setLeafOutputCalculator(new L2BLeafOutputCalculator());

    L2BoostOptimizer trainer = new L2BoostOptimizer(boost, dataSet, treeFactory);
    trainer.setShrinkage(0.1);
    trainer.initialize();

    // Time 200 boosting rounds.
    StopWatch timer = new StopWatch();
    timer.start();
    int round = 0;
    while (round < 200) {
        System.out.println("round=" + round);
        trainer.iterate();
        round++;
    }
    timer.stop();
    System.out.println(timer);

    double accuracy = Accuracy.accuracy(boost, dataSet);
    System.out.println("accuracy=" + accuracy);
    Serialization.serialize(boost, new File(TMP, "boost"));
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 14 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

In the class GBClassifier, the method train:

/**
 * Trains an LKBoost multi-class gradient-boosting classifier according to the
 * given config, then saves the model (binary and PMML) and writes the
 * training-set predictions and predicted probabilities to the output folder.
 *
 * @param config run configuration (input paths, tree/boosting hyperparameters,
 *               progress-reporting flags, output folder)
 * @throws Exception on data-loading or serialization failure
 */
private static void train(Config config) throws Exception {
    // Map the configured matrix type onto the corresponding dataset loader type.
    String matrixType = config.getString("input.matrixType");
    DataSetType loadType;
    if ("dense".equals(matrixType)) {
        loadType = DataSetType.CLF_DENSE;
    } else if ("sparse".equals(matrixType)) {
        loadType = DataSetType.CLF_SPARSE;
    } else {
        throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }

    ClfDataSet trainSet = TRECFormat.loadClfDataSet(config.getString("input.trainData"), loadType, true);
    // The test set is loaded (and later used) only when test progress is requested.
    ClfDataSet testSet = config.getBoolean("train.showTestProgress")
            ? TRECFormat.loadClfDataSet(config.getString("input.testData"), loadType, true)
            : null;

    int numClasses = trainSet.getNumClasses();
    LKBoost lkBoost = new LKBoost(numClasses);
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves")));
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(numClasses));

    LKBoostOptimizer trainer = new LKBoostOptimizer(lkBoost, trainSet, treeFactory);
    trainer.setShrinkage(config.getDouble("train.shrinkage"));
    trainer.initialize();

    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    for (int i = 1; i <= numIterations; i++) {
        System.out.println("iteration " + i);
        trainer.iterate();
        // Progress is reported at each interval boundary and on the final iteration.
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            System.out.println("training accuracy = " + Accuracy.accuracy(lkBoost, trainSet));
        }
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            System.out.println("test accuracy = " + Accuracy.accuracy(lkBoost, testSet));
        }
    }
    System.out.println("training done!");

    String outputFolder = config.getString("output.folder");
    new File(outputFolder).mkdirs();

    File serializedModel = new File(outputFolder, "model");
    Serialization.serialize(lkBoost, serializedModel);
    System.out.println("model saved to " + serializedModel.getAbsolutePath());

    File pmmlModel = new File(outputFolder, "model.pmml");
    PMMLConverter.savePMML(lkBoost, pmmlModel);
    System.out.println("PMML model saved to " + pmmlModel.getAbsolutePath());

    File reportFile = new File(outputFolder, "train_predictions.txt");
    report(lkBoost, trainSet, reportFile);
    System.out.println("predictions on the training set are written to " + reportFile.getAbsolutePath());

    File probabilitiesFile = new File(outputFolder, "train_predicted_probabilities.txt");
    probabilities(lkBoost, trainSet, probabilitiesFile);
    System.out.println("predicted probabilities on the training set are written to " + probabilitiesFile.getAbsolutePath());
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) File(java.io.File) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 15 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

In the class GBRegressor, the method train:

/**
 * Trains an LSBoost gradient-boosting regressor according to the given
 * config, then saves the model (binary, plus PMML when enabled) and writes
 * the training-set predictions to the configured report folder.
 *
 * @param config run configuration (input paths, tree/boosting hyperparameters,
 *               progress-reporting flags, output locations)
 * @param logger destination for progress and completion messages
 * @throws Exception on data-loading or serialization failure
 */
private static void train(Config config, Logger logger) throws Exception {
    // Map the configured matrix type onto the corresponding dataset loader type.
    String matrixType = config.getString("input.matrixType");
    DataSetType loadType;
    if ("dense".equals(matrixType)) {
        loadType = DataSetType.REG_DENSE;
    } else if ("sparse".equals(matrixType)) {
        loadType = DataSetType.REG_SPARSE;
    } else {
        throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }

    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainData"), loadType, true);
    // The test set is loaded (and later used) only when test progress is requested.
    RegDataSet testSet = config.getBoolean("train.showTestProgress")
            ? TRECFormat.loadRegDataSet(config.getString("input.testData"), loadType, true)
            : null;

    LSBoost lsBoost = new LSBoost();
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves")));

    LSBoostOptimizer trainer = new LSBoostOptimizer(lsBoost, trainSet, treeFactory);
    trainer.setShrinkage(config.getDouble("train.shrinkage"));
    trainer.initialize();

    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        trainer.iterate();
        // Progress is reported at each interval boundary and on the final iteration.
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training RMSE = " + RMSE.rmse(lsBoost, trainSet));
        }
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        }
    }
    logger.info("training done!");

    String outputFolder = config.getString("output.folder");
    new File(outputFolder).mkdirs();

    File serializedModel = new File(outputFolder, "model");
    Serialization.serialize(lsBoost, serializedModel);
    logger.info("model saved to " + serializedModel.getAbsolutePath());

    // PMML export is optional for the regressor (unlike GBClassifier).
    if (config.getBoolean("output.generatePMML")) {
        File pmmlModel = new File(outputFolder, "model.pmml");
        PMMLConverter.savePMML(lsBoost, pmmlModel);
        logger.info("PMML model saved to " + pmmlModel.getAbsolutePath());
    }

    String trainReportName = config.getString("output.trainReportFolderName");
    File reportFile = Paths.get(outputFolder, trainReportName, "train_predictions.txt").toFile();
    report(lsBoost, trainSet, reportFile);
    logger.info("predictions on the training set are written to " + reportFile.getAbsolutePath());
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) LSBoostOptimizer(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File) LSBoost(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)

Aggregations

RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)26 RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)23 LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)10 LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost)10 LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer)10 RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet)9 File (java.io.File)7 LSBoost (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)6 LSBoostOptimizer (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer)6 DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType)4 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)4 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)3 IOException (java.io.IOException)3 StopWatch (org.apache.commons.lang3.time.StopWatch)3 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)2 ArrayList (java.util.ArrayList)2 List (java.util.List)2 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)1 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)1 TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat)1