
Example 16 with RegTreeFactory

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li.

The class GBRegressor, method train:

private static void train(Config config) throws Exception {
    String sparsity = config.getString("input.matrixType");
    DataSetType dataSetType = null;
    switch(sparsity) {
        case "dense":
            dataSetType = DataSetType.REG_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.REG_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }
    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainData"), dataSetType, true);
    RegDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = TRECFormat.loadRegDataSet(config.getString("input.testData"), dataSetType, true);
    }
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves"));
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();
    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    for (int i = 1; i <= numIterations; i++) {
        System.out.println("iteration " + i);
        optimizer.iterate();
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            System.out.println("training RMSE = " + RMSE.rmse(lsBoost, trainSet));
        }
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        }
    }
    System.out.println("training done!");
    String output = config.getString("output.folder");
    new File(output).mkdirs();
    File serializedModel = new File(output, "model");
    Serialization.serialize(lsBoost, serializedModel);
    System.out.println("model saved to " + serializedModel.getAbsolutePath());
    File reportFile = new File(output, "train_predictions.txt");
    report(lsBoost, trainSet, reportFile);
    System.out.println("predictions on the training set are written to " + reportFile.getAbsolutePath());
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) LSBoostOptimizer(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File) LSBoost(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)
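For reference, the model file written above can be loaded back and scored later. The sketch below is not from the project sources: it assumes the Serialization helper used in train() also exposes a deserialize(File) counterpart to serialize(Object, File) (its package path, edu.neu.ccs.pyramid.util, is likewise assumed), and the file and data-set paths are purely illustrative.

import edu.neu.ccs.pyramid.dataset.DataSetType;
import edu.neu.ccs.pyramid.dataset.RegDataSet;
import edu.neu.ccs.pyramid.dataset.TRECFormat;
import edu.neu.ccs.pyramid.eval.RMSE;
import edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost;
// assumption: package path of the Serialization helper used in train()
import edu.neu.ccs.pyramid.util.Serialization;
import java.io.File;

public class LoadAndEvaluate {

    public static void main(String[] args) throws Exception {
        // the "model" file written by GBRegressor.train (path illustrative)
        File serializedModel = new File("output/model");
        // assumption: Serialization.deserialize(File) mirrors the serialize(Object, File) call above
        LSBoost lsBoost = (LSBoost) Serialization.deserialize(serializedModel);
        // test data in the same TREC format used during training
        RegDataSet testSet = TRECFormat.loadRegDataSet("testData", DataSetType.REG_DENSE, true);
        System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
    }
}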

Example 17 with RegTreeFactory

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li.

The class RerankerTrainer, method train:

public Reranker train(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation) {
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves).setMinDataPerLeaf(minDataPerLeaf).setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels());
    if (!monotonicityType.equals("none")) {
        int[][] mono = new int[1][regDataSet.getNumFeatures()];
        mono[0] = predictionFeatureExtractor.featureMonotonicity();
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        if (i % 10 == 0 || i == maxIter) {
            double mse = MSE.mse(lsBoost, validation);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                try {
                    bestModel = (LSBoost) Serialization.deepCopy(lsBoost);
                } catch (IOException | ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    // System.out.println("best iteration = "+earlyStopper.getBestIteration());
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) LSBoostOptimizer(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) IOException(java.io.IOException) LSBoost(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)
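The early-stopping pattern in this method can be read on its own: add a checkpoint value, remember the best checkpoint, and stop once the metric has failed to improve for a while. The sketch below exercises exactly the EarlyStopper calls used above; the metric values are made up, and reading the second constructor argument as a patience of 5 checkpoints is an assumption based on this usage.

import edu.neu.ccs.pyramid.optimization.EarlyStopper;

public class EarlyStopperSketch {

    public static void main(String[] args) {
        // goal = MINIMIZE; the second argument is treated as patience (assumed), as in RerankerTrainer.train
        EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
        // made-up validation errors: improving until checkpoint 2, then stagnating
        double[] validationMse = {0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.69, 0.70, 0.71};
        for (int checkpoint = 0; checkpoint < validationMse.length; checkpoint++) {
            earlyStopper.add(checkpoint, validationMse[checkpoint]);
            if (earlyStopper.getBestIteration() == checkpoint) {
                // in RerankerTrainer.train this is where the model snapshot (deepCopy) is taken
                System.out.println("new best checkpoint = " + checkpoint);
            }
            if (earlyStopper.shouldStop()) {
                System.out.println("stopping early; best checkpoint = " + earlyStopper.getBestIteration());
                break;
            }
        }
    }
}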

Example 18 with RegTreeFactory

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li.

The class GBCBMOptimizer, method updateBinaryClassifier:

@Override
protected void updateBinaryClassifier(int component, int label, MultiLabelClfDataSet activeDataset, double[] activeGammas) {
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    if (cbm.binaryClassifiers[component][label] == null || cbm.binaryClassifiers[component][label] instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[component][label] = new LKBoost(2);
    }
    int[] binaryLabels = DataSetUtil.toBinaryLabels(activeDataset.getMultiLabels(), label);
    double[][] targetsDistributions = DataSetUtil.labelsToDistributions(binaryLabels, 2);
    LKBoost boost = (LKBoost) this.cbm.binaryClassifiers[component][label];
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));
    LKBoostOptimizer optimizer = new LKBoostOptimizer(boost, activeDataset, regTreeFactory, activeGammas, targetsDistributions);
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    optimizer.iterate(binaryUpdatesPerIter);
    if (logger.isDebugEnabled()) {
        logger.debug("time spent on updating component " + component + " label " + label + " = " + stopWatch);
    }
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) StopWatch(org.apache.commons.lang3.time.StopWatch) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)
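The two DataSetUtil calls are what turn the multi-label problem into one binary boosting problem per (component, label) pair: toBinaryLabels marks the instances that carry the given label, and labelsToDistributions expands those 0/1 labels into per-instance class distributions for LKBoostOptimizer. A small sketch of the second step with hand-made labels follows; the DataSetUtil package path (edu.neu.ccs.pyramid.dataset) and the one-hot row layout noted in the comment are assumptions, so check DataSetUtil if they matter.

// assumption: DataSetUtil lives in edu.neu.ccs.pyramid.dataset
import edu.neu.ccs.pyramid.dataset.DataSetUtil;
import java.util.Arrays;

public class BinaryTargetsSketch {

    public static void main(String[] args) {
        // hand-made 0/1 labels, standing in for DataSetUtil.toBinaryLabels(multiLabels, label)
        int[] binaryLabels = {1, 0, 0, 1};
        // one row per instance, one column per class; assumed one-hot layout, e.g. {0.0, 1.0} for label 1
        double[][] targetsDistributions = DataSetUtil.labelsToDistributions(binaryLabels, 2);
        System.out.println(Arrays.deepToString(targetsDistributions));
    }
}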

Example 19 with RegTreeFactory

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li.

The class CBMOptimizer, method updateBinaryBoosting:

private void updateBinaryBoosting(int componentIndex, int labelIndex) {
    int numIterations = numIterationsBinary;
    double shrinkage = shrinkageBinary;
    LKBoost boost = (LKBoost) this.cbm.binaryClassifiers[componentIndex][labelIndex];
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeavesBinary);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));
    LKBoostOptimizer optimizer = new LKBoostOptimizer(boost, dataSet, regTreeFactory, gammasT[componentIndex], targetsDistributions[labelIndex]);
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    optimizer.iterate(numIterations);
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 20 with RegTreeFactory

Use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory in project pyramid by cheng-li.

The class PMMLConverterTest, method main:

public static void main(String[] args) throws Exception {
    RegDataSet trainSet = TRECFormat.loadRegDataSet(new File("/Users/chengli/Dropbox/Public/pyramid/abalone//train"), DataSetType.REG_DENSE, true);
    RegDataSet testSet = TRECFormat.loadRegDataSet(new File("/Users/chengli/Dropbox/Public/pyramid/abalone//test"), DataSetType.REG_DENSE, true);
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(3);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    for (int i = 0; i < 10; i++) {
        System.out.println("iteration " + i);
        System.out.println("train RMSE = " + RMSE.rmse(lsBoost, trainSet));
        System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        optimizer.iterate();
    }
    FeatureList featureList = trainSet.getFeatureList();
    List<RegressionTree> regressionTrees = lsBoost.getEnsemble(0).getRegressors().stream().filter(a -> a instanceof RegressionTree).map(a -> (RegressionTree) a).collect(Collectors.toList());
    System.out.println(regressionTrees);
    double constant = ((ConstantRegressor) lsBoost.getEnsemble(0).get(0)).getScore();
    PMML pmml = PMMLConverter.encodePMML(null, null, featureList, regressionTrees, (float) constant);
    System.out.println(pmml.toString());
    try (OutputStream os = new FileOutputStream("/Users/chengli/tmp/pmml.xml")) {
        MetroJAXBUtil.marshalPMML(pmml, os);
    }
}
Also used : OutputStream(java.io.OutputStream) DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) PMML(org.dmg.pmml.PMML) FileOutputStream(java.io.FileOutputStream) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) Collectors(java.util.stream.Collectors) File(java.io.File) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) List(java.util.List) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) MetroJAXBUtil(org.jpmml.model.MetroJAXBUtil) ConstantRegressor(edu.neu.ccs.pyramid.regression.ConstantRegressor) RMSE(edu.neu.ccs.pyramid.eval.RMSE)
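The constant passed to encodePMML is the intercept of the boosted model, and the exported trees are the remaining members of ensemble 0. If shrinkage is already folded into the tree leaf values (which the export above appears to assume), an LSBoost prediction should equal that constant plus the summed tree outputs. The fragment below is a sketch meant to sit inside the main method above, after constant is computed and before the marshalling step; it assumes RegDataSet.getRow(int) returns an org.apache.mahout.math.Vector and that both LSBoost and RegressionTree expose predict(Vector), neither of which is shown in this example.

    // sanity-check sketch: uses trainSet, constant, regressionTrees and lsBoost from main() above
    // assumptions: getRow(int) returns an org.apache.mahout.math.Vector; predict(Vector) exists on both models
    Vector firstRow = trainSet.getRow(0);
    double reconstructed = constant;
    for (RegressionTree tree : regressionTrees) {
        reconstructed += tree.predict(firstRow);
    }
    System.out.println("lsBoost.predict = " + lsBoost.predict(firstRow)
            + ", constant + trees = " + reconstructed);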

Aggregations

RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig): 23 uses
RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory): 23 uses
LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator): 10 uses
LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost): 10 uses
LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer): 10 uses
RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet): 8 uses
LSBoost (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost): 6 uses
LSBoostOptimizer (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer): 6 uses
File (java.io.File): 6 uses
DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType): 4 uses
EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 3 uses
IOException (java.io.IOException): 3 uses
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 2 uses
List (java.util.List): 2 uses
StopWatch (org.apache.commons.lang3.time.StopWatch): 2 uses
PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier): 1 use
TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat): 1 use
RMSE (edu.neu.ccs.pyramid.eval.RMSE): 1 use
FeatureList (edu.neu.ccs.pyramid.feature.FeatureList): 1 use
ConstantRegressor (edu.neu.ccs.pyramid.regression.ConstantRegressor): 1 use