Search in sources :

Example 6 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class GBCBMOptimizer method updateMultiClassClassifier.

/**
 * Updates the CBM's multi-class classifier, which is cast to {@link LKBoost}.
 *
 * <p>Builds a regression-tree factory whose trees are capped at {@code numLeaves}
 * leaves and whose leaf outputs come from an {@link LKBOutputCalculator} constructed
 * with the number of CBM components, then drives an {@link LKBoostOptimizer} over
 * {@code dataSet} and {@code gammas}, passing {@code multiclassUpdatesPerIter} to
 * {@code iterate}.
 */
@Override
protected void updateMultiClassClassifier() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateMultiClassClassifier");
    }

    LKBoost multiClassBoost = (LKBoost) this.cbm.multiClassClassifier;

    // Tree factory: leaf-count limit plus the LKBoost-specific leaf output calculator.
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeaves));
    LKBOutputCalculator leafCalculator = new LKBOutputCalculator(cbm.getNumComponents());
    treeFactory.setLeafOutputCalculator(leafCalculator);

    // The shrinkage must be set before initialize(), as in the original call sequence.
    LKBoostOptimizer boostOptimizer = new LKBoostOptimizer(multiClassBoost, dataSet, treeFactory, gammas);
    boostOptimizer.setShrinkage(shrinkage);
    boostOptimizer.initialize();
    boostOptimizer.iterate(multiclassUpdatesPerIter);

    if (logger.isDebugEnabled()) {
        logger.debug("finish updateMultiClassClassifier");
    }
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 7 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class RulesTest method test1.

/**
 * Manual smoke test for rule serialization.
 *
 * <p>Loads the slice-location regression data set, fits a regression tree with at most
 * four leaves, builds two tree rules (from rows 100 and 1), a constant rule and a linear
 * rule, and writes the list of rules as JSON to {@code decision.json} under {@code TMP}.
 *
 * @throws Exception propagated from data loading, tree fitting or JSON serialization
 */
static void test1() throws Exception {
    int numLeaves = 4;
    RegDataSet dataSet = StandardFormat.loadRegDataSet("/Users/chengli/Datasets/slice_location/standard/featureList.txt", "/Users/chengli/Datasets/slice_location/standard/labels.txt", ",", DataSetType.REG_DENSE, false);
    System.out.println(dataSet.isDense());

    RegTreeConfig regTreeConfig = new RegTreeConfig();
    regTreeConfig.setMaxNumLeaves(numLeaves);
    regTreeConfig.setMinDataPerLeaf(5);
    regTreeConfig.setNumSplitIntervals(100);

    // NOTE(review): the stopwatch is started but never stopped or read in this method.
    // TODO: either report the elapsed fitting time or remove it.
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    RegressionTree regressionTree = RegTreeTrainer.fit(regTreeConfig, dataSet);

    // One rule of each kind, so the JSON output covers every Rule implementation used here.
    TreeRule rule1 = new TreeRule(regressionTree, dataSet.getRow(100));
    TreeRule rule2 = new TreeRule(regressionTree, dataSet.getRow(1));
    ConstantRule rule3 = new ConstantRule(0.8);
    Rule rule4 = new LinearRule();
    List<Rule> rules = new ArrayList<>();
    rules.add(rule1);
    rules.add(rule2);
    rules.add(rule3);
    rules.add(rule4);

    ObjectMapper mapper = new ObjectMapper();
    mapper.writeValue(new File(TMP, "decision.json"), rules);
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) TreeRule(edu.neu.ccs.pyramid.regression.regression_tree.TreeRule) RegressionTree(edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree) ArrayList(java.util.ArrayList) StopWatch(org.apache.commons.lang3.time.StopWatch) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 8 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class LSLogisticBoostTest method test3.

/**
 * Manual experiment on the spam data set comparing two boosters that share one
 * regression-tree factory (trees with at most 10 leaves).
 *
 * <p>First an {@code LSLogisticBoost} is trained with shrinkage 1 for 200 rounds, then an
 * {@code LSBoost} with shrinkage 0.1 for 100 rounds. Before every round the training and
 * test RMSE are printed; after each training run the test label and the prediction are
 * printed for every test row.
 *
 * @throws Exception propagated from data loading or training
 */
private static void test3() throws Exception {
    final RegDataSet trainSet = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true);
    final RegDataSet testSet = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true);

    // ---- logistic variant ----
    final LSLogisticBoost logisticModel = new LSLogisticBoost();
    final RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(10));
    final LSLogisticBoostOptimizer logisticOptimizer = new LSLogisticBoostOptimizer(logisticModel, trainSet, treeFactory);
    logisticOptimizer.setShrinkage(1);
    logisticOptimizer.initialize();
    int round = 0;
    while (round < 200) {
        // Errors are reported for the model as it stands before this round's update.
        System.out.println("training rmse = " + RMSE.rmse(logisticModel, trainSet));
        System.out.println("test rmse = " + RMSE.rmse(logisticModel, testSet));
        logisticOptimizer.iterate();
        round++;
    }
    for (int row = 0; row < testSet.getNumDataPoints(); row++) {
        String line = testSet.getLabels()[row] + " " + logisticModel.predict(testSet.getRow(row));
        System.out.println(line);
    }

    // ---- plain least-squares variant ----
    System.out.println("********************** LSBOOST **************");
    final LSBoost squaresModel = new LSBoost();
    final LSBoostOptimizer squaresOptimizer = new LSBoostOptimizer(squaresModel, trainSet, treeFactory);
    squaresOptimizer.setShrinkage(0.1);
    squaresOptimizer.initialize();
    round = 0;
    while (round < 100) {
        System.out.println("training rmse = " + RMSE.rmse(squaresModel, trainSet));
        System.out.println("test rmse = " + RMSE.rmse(squaresModel, testSet));
        squaresOptimizer.iterate();
        round++;
    }
    for (int row = 0; row < testSet.getNumDataPoints(); row++) {
        String line = testSet.getLabels()[row] + " " + squaresModel.predict(testSet.getRow(row));
        System.out.println(line);
    }
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) LSBoostOptimizer(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) LSBoost(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)

Example 9 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class LSLogisticBoostTest method test2.

/**
 * Manual experiment on the spam data set with remapped labels.
 *
 * <p>In both the training and the test set, label 0 is replaced by -1 and every other
 * label by 2. An {@code LSLogisticBoost} is then trained for 100 rounds (printing
 * training and test RMSE before each round) and its per-row test predictions are
 * printed, followed by an {@code LSBoost} trained for 100 rounds on the same data with
 * the same tree factory (trees with at most 5 leaves).
 *
 * @throws Exception propagated from data loading or training
 */
private static void test2() throws Exception {
    final RegDataSet trainSet = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true);
    final RegDataSet testSet = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true);

    // Remap labels: 0 -> -1, anything else -> 2.
    for (int row = 0; row < trainSet.getNumDataPoints(); row++) {
        trainSet.setLabel(row, trainSet.getLabels()[row] == 0 ? -1 : 2);
    }
    for (int row = 0; row < testSet.getNumDataPoints(); row++) {
        testSet.setLabel(row, testSet.getLabels()[row] == 0 ? -1 : 2);
    }

    final LSLogisticBoost logisticModel = new LSLogisticBoost();
    final RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(5));
    final LSLogisticBoostOptimizer logisticOptimizer = new LSLogisticBoostOptimizer(logisticModel, trainSet, treeFactory);
    logisticOptimizer.initialize();
    int round = 0;
    while (round < 100) {
        // Errors are reported for the model as it stands before this round's update.
        System.out.println("training rmse = " + RMSE.rmse(logisticModel, trainSet));
        System.out.println("test rmse = " + RMSE.rmse(logisticModel, testSet));
        logisticOptimizer.iterate();
        round++;
    }
    for (int row = 0; row < testSet.getNumDataPoints(); row++) {
        String line = testSet.getLabels()[row] + " " + logisticModel.predict(testSet.getRow(row));
        System.out.println(line);
    }

    final LSBoost squaresModel = new LSBoost();
    final LSBoostOptimizer squaresOptimizer = new LSBoostOptimizer(squaresModel, trainSet, treeFactory);
    System.out.println("LSBOOST");
    squaresOptimizer.initialize();
    round = 0;
    while (round < 100) {
        System.out.println("training rmse = " + RMSE.rmse(squaresModel, trainSet));
        System.out.println("test rmse = " + RMSE.rmse(squaresModel, testSet));
        squaresOptimizer.iterate();
        round++;
    }
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) LSBoostOptimizer(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) LSBoost(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)

Example 10 with RegTreeConfig

use of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid by cheng-li.

the class LambdaMARTOptimizerTest method test1.

/**
 * Manual experiment: trains a {@code LambdaMART} model on the shuffled spam training set
 * and prints the NDCG on the shuffled test set before each of 500 iterations.
 *
 * <p>Every training row index is placed into a single list, which is the only entry of
 * the {@code instancesInQuery} argument handed to the optimizer.
 *
 * @throws Exception propagated from data loading or training
 */
private static void test1() throws Exception {
    // Both sets are shuffled with the second argument 0, as in the original calls.
    final RegDataSet trainSet = DataSetUtil.shuffleRows(
            TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true), 0);
    final RegDataSet testSet = DataSetUtil.shuffleRows(
            TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true), 0);

    final LambdaMART ranker = new LambdaMART();
    final RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(5));

    final List<List<Integer>> queryGroups = new ArrayList<>();
    final List<Integer> allRows = IntStream.range(0, trainSet.getNumDataPoints())
            .boxed()
            .collect(Collectors.toList());
    queryGroups.add(allRows);

    final LambdaMARTOptimizer rankOptimizer = new LambdaMARTOptimizer(ranker, trainSet, trainSet.getLabels(), treeFactory, queryGroups);
    rankOptimizer.initialize();

    int iteration = 0;
    while (iteration < 500) {
        System.out.println("==================================");
        System.out.println("iter " + iteration);
        System.out.println("ndcg = " + NDCG.ndcg(testSet.getLabels(), ranker.predict(testSet)));
        // Per-row training labels and predictions can be printed here when debugging.
        rankOptimizer.iterate();
        iteration++;
    }
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) RegDataSet(edu.neu.ccs.pyramid.dataset.RegDataSet) ArrayList(java.util.ArrayList) List(java.util.List)

Aggregations

RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)26 RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)23 LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)10 LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost)10 LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer)10 RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet)9 File (java.io.File)7 LSBoost (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)6 LSBoostOptimizer (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer)6 DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType)4 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)4 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)3 IOException (java.io.IOException)3 StopWatch (org.apache.commons.lang3.time.StopWatch)3 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)2 ArrayList (java.util.ArrayList)2 List (java.util.List)2 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)1 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)1 TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat)1