Search in sources :

Example 21 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class CBMOptimizer method updateBinaryBoosting.

/**
 * Runs additional boosting rounds on the binary classifier stored at
 * {@code cbm.binaryClassifiers[componentIndex][labelIndex]}.
 *
 * <p>The optimizer is built from the shared data set, the row
 * {@code gammasT[componentIndex]} and the entry
 * {@code targetsDistributions[labelIndex]}; it is then initialized and run
 * for {@code numIterationsBinary} iterations with shrinkage
 * {@code shrinkageBinary}.
 *
 * @param componentIndex first index into {@code cbm.binaryClassifiers} and index into {@code gammasT}
 * @param labelIndex     second index into {@code cbm.binaryClassifiers} and index into {@code targetsDistributions}
 */
private void updateBinaryBoosting(int componentIndex, int labelIndex) {
    // The stored classifier is expected to be an LKBoost; the cast fails otherwise.
    LKBoost binaryBoost = (LKBoost) this.cbm.binaryClassifiers[componentIndex][labelIndex];

    // Trees are capped at numLeavesBinary leaves; the leaf calculator is built with 2
    // (presumably the number of classes of a binary problem -- confirm against LKBOutputCalculator).
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesBinary));
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));

    LKBoostOptimizer boostOptimizer = new LKBoostOptimizer(
            binaryBoost,
            dataSet,
            treeFactory,
            gammasT[componentIndex],
            targetsDistributions[labelIndex]);
    boostOptimizer.setShrinkage(shrinkageBinary);
    boostOptimizer.initialize();
    boostOptimizer.iterate(numIterationsBinary);
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 22 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class PMMLConverterTest method main.

/**
 * Manual driver: fits an {@code LSBoost} model on the abalone training set,
 * prints train/test RMSE before each of 10 boosting iterations, then encodes
 * the regression trees of ensemble 0 as PMML, prints it, and writes it to
 * {@code /Users/chengli/tmp/pmml.xml}.
 *
 * @param args unused
 * @throws Exception if loading the data sets or writing the PMML file fails
 */
public static void main(String[] args) throws Exception {
    File trainFile = new File("/Users/chengli/Dropbox/Public/pyramid/abalone//train");
    RegDataSet training = TRECFormat.loadRegDataSet(trainFile, DataSetType.REG_DENSE, true);
    File testFile = new File("/Users/chengli/Dropbox/Public/pyramid/abalone//test");
    RegDataSet heldOut = TRECFormat.loadRegDataSet(testFile, DataSetType.REG_DENSE, true);

    LSBoost model = new LSBoost();
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(3));
    LSBoostOptimizer trainer = new LSBoostOptimizer(model, training, treeFactory);
    trainer.setShrinkage(0.1);
    trainer.initialize();

    // Metrics are reported before each boosting step, so "iteration 0"
    // reflects the model right after initialize().
    int round = 0;
    while (round < 10) {
        System.out.println("iteration " + round);
        System.out.println("train RMSE = " + RMSE.rmse(model, training));
        System.out.println("test RMSE = " + RMSE.rmse(model, heldOut));
        trainer.iterate();
        round++;
    }

    FeatureList features = training.getFeatureList();

    // Keep only the tree regressors of ensemble 0; anything else is dropped.
    List<RegressionTree> trees = model.getEnsemble(0)
            .getRegressors()
            .stream()
            .filter(RegressionTree.class::isInstance)
            .map(RegressionTree.class::cast)
            .collect(Collectors.toList());
    System.out.println(trees);

    // The first regressor of ensemble 0 is expected to be the constant term;
    // the cast fails otherwise.
    ConstantRegressor bias = (ConstantRegressor) model.getEnsemble(0).get(0);
    double biasScore = bias.getScore();

    PMML document = PMMLConverter.encodePMML(null, null, features, trees, (float) biasScore);
    System.out.println(document.toString());

    try (OutputStream sink = new FileOutputStream("/Users/chengli/tmp/pmml.xml")) {
        MetroJAXBUtil.marshalPMML(document, sink);
    }
}
Also used : OutputStream(java.io.OutputStream) DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) TRECFormat(edu.neu.ccs.pyramid.dataset.TRECFormat) PMML(org.dmg.pmml.PMML) FileOutputStream(java.io.FileOutputStream) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) Collectors(java.util.stream.Collectors) File(java.io.File) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) List(java.util.List) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) MetroJAXBUtil(org.jpmml.model.MetroJAXBUtil) ConstantRegressor(edu.neu.ccs.pyramid.regression.ConstantRegressor) RMSE(edu.neu.ccs.pyramid.eval.RMSE)

Example 23 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class LSLogisticBoostTest method test1.

/**
 * Trains an {@code LSLogisticBoost} on a synthetic data set of 10000 points
 * with two random features and prints the training RMSE before each of 100
 * boosting iterations, followed by label/prediction pairs for the first
 * 1000 points.
 *
 * <p>Each point gets two independent {@code Math.random()} features
 * {@code r} and {@code s}; its label is {@code 2} when {@code r + s < 0.1}
 * and {@code -2} otherwise.
 */
private static void test1() {
    RegDataSet regDataSet = RegDataSetBuilder.getBuilder().numDataPoints(10000).numFeatures(2).build();
    for (int i = 0; i < regDataSet.getNumDataPoints(); i++) {
        double r = Math.random();
        double s = Math.random();
        // Store r and s in separate columns. Writing both to column 0 would
        // discard r and leave feature 1 unset even though the label depends on r + s.
        regDataSet.setFeatureValue(i, 0, r);
        regDataSet.setFeatureValue(i, 1, s);
        if (r + s < 0.1) {
            regDataSet.setLabel(i, 2);
        } else {
            regDataSet.setLabel(i, -2);
        }
    }
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(10);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, regDataSet, regTreeFactory);
    // NOTE(review): a shrinkage of 100 is unusually large for boosting -- confirm it is intentional.
    optimizer.setShrinkage(100);
    optimizer.initialize();
    // RMSE is printed before each step, so the first line reflects the freshly initialized model.
    for (int i = 0; i < 100; i++) {
        System.out.println("training rmse = " + RMSE.rmse(lsLogisticBoost, regDataSet));
        optimizer.iterate();
    }
    // Print "<label> <prediction>" for the first 1000 of the 10000 points.
    for (int i = 0; i < 1000; i++) {
        System.out.println(regDataSet.getLabels()[i] + " " + lsLogisticBoost.predict(regDataSet.getRow(i)));
    }
// Disabled comparison run with plain LSBoost on the same data set:
// System.out.println("********************** LSBOOST **************");
// LSBoost lsBoost = new LSBoost();
// LSBoostOptimizer lsBoostOptimizer = new LSBoostOptimizer(lsBoost,regDataSet,regTreeFactory);
// lsBoostOptimizer.initialize();
// for (int i=0;i<100;i++){
// System.out.println("training rmse = "+ RMSE.rmse(lsBoost,regDataSet));
// lsBoostOptimizer.iterate();
// }
// 
// for (int i=0;i<1000;i++){
// System.out.println(regDataSet.getLabels()[i]+" "+lsBoost.predict(regDataSet.getRow(i)));
// }
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet)

Example 24 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class LSBoostTest method test1.

/**
 * Fits an {@code LSBoost} model on the abalone training set and prints the
 * train and test RMSE before each of 10 boosting iterations.
 *
 * @throws Exception if either data set cannot be loaded
 */
private static void test1() throws Exception {
    File trainFile = new File("/Users/chengli/Dropbox/Public/pyramid/abalone//train");
    RegDataSet training = TRECFormat.loadRegDataSet(trainFile, DataSetType.REG_DENSE, true);
    File testFile = new File("/Users/chengli/Dropbox/Public/pyramid/abalone//test");
    RegDataSet heldOut = TRECFormat.loadRegDataSet(testFile, DataSetType.REG_DENSE, true);

    LSBoost model = new LSBoost();
    // Trees are limited to 3 leaves each.
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(3));

    LSBoostOptimizer trainer = new LSBoostOptimizer(model, training, treeFactory);
    trainer.setShrinkage(0.1);
    trainer.initialize();

    // Report first, then boost: the numbers for round k describe the model
    // after k completed iterations.
    int round = 0;
    while (round < 10) {
        System.out.println("iteration " + round);
        System.out.println("train RMSE = " + RMSE.rmse(model, training));
        System.out.println("test RMSE = " + RMSE.rmse(model, heldOut));
        trainer.iterate();
        round++;
    }
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File)

Example 25 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class CBMOptimizer method updateMultiClassBoost.

/**
 * Runs additional boosting rounds on {@code cbm.multiClassClassifier}.
 *
 * <p>The optimizer is built from the shared data set and {@code gammas},
 * then initialized and run for {@code numIterationsMultiClass} iterations
 * with shrinkage {@code shrinkageMultiClass}. Trees are capped at
 * {@code numLeavesMultiClass} leaves and the leaf output calculator is
 * constructed with {@code cbm.numComponents}.
 */
private void updateMultiClassBoost() {
    int componentCount = cbm.numComponents;

    // The stored classifier is expected to be an LKBoost; the cast fails otherwise.
    LKBoost multiClassBoost = (LKBoost) this.cbm.multiClassClassifier;

    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesMultiClass));
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(componentCount));

    LKBoostOptimizer boostOptimizer = new LKBoostOptimizer(multiClassBoost, dataSet, treeFactory, gammas);
    boostOptimizer.setShrinkage(shrinkageMultiClass);
    boostOptimizer.initialize();
    boostOptimizer.iterate(numIterationsMultiClass);
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Aggregations

RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)26 RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)23 LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)10 LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost)10 LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer)10 RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet)9 File (java.io.File)7 LSBoost (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)6 LSBoostOptimizer (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer)6 DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType)4 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)4 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)3 IOException (java.io.IOException)3 StopWatch (org.apache.commons.lang3.time.StopWatch)3 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)2 ArrayList (java.util.ArrayList)2 List (java.util.List)2 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)1 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)1 TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat)1