Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid (by cheng-li).
Example from class GBCBMOptimizer, method updateMultiClassClassifier:
/**
 * Refits the multi-class component of the CBM model, which is expected to be an LKBoost.
 * <p>
 * Builds a regression-tree factory whose trees have at most {@code numLeaves} leaves and whose
 * leaf values come from an {@code LKBOutputCalculator} sized to the model's number of
 * components. It then runs {@code multiclassUpdatesPerIter} boosting iterations of an
 * {@code LKBoostOptimizer} over {@code dataSet}, using {@code gammas} and {@code shrinkage}.
 * <p>
 * The cast below throws {@code ClassCastException} if {@code cbm.multiClassClassifier} is
 * not an {@code LKBoost}.
 */
@Override
protected void updateMultiClassClassifier() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateMultiClassClassifier");
    }

    // NOTE(review): the original comment here just said "parallel"; presumably the
    // optimizer parallelizes internally -- confirm against LKBoostOptimizer.
    LKBoost lkBoost = (LKBoost) this.cbm.multiClassClassifier;

    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeaves));
    LKBOutputCalculator leafOutputCalculator = new LKBOutputCalculator(this.cbm.getNumComponents());
    treeFactory.setLeafOutputCalculator(leafOutputCalculator);

    LKBoostOptimizer lkBoostOptimizer = new LKBoostOptimizer(lkBoost, dataSet, treeFactory, gammas);
    lkBoostOptimizer.setShrinkage(shrinkage);
    lkBoostOptimizer.initialize();
    lkBoostOptimizer.iterate(multiclassUpdatesPerIter);

    if (logger.isDebugEnabled()) {
        logger.debug("finish updateMultiClassClassifier");
    }
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid (by cheng-li).
Example from class RulesTest, method test1:
/**
 * Fits a small regression tree on the slice_location data set and serializes a mixed list
 * of rules to {@code decision.json} in the {@code TMP} directory with Jackson.
 * <p>
 * The serialized list holds, in order:
 * <ul>
 *   <li>a {@code TreeRule} built from the fitted tree and data point 100,</li>
 *   <li>a {@code TreeRule} built from the fitted tree and data point 1,</li>
 *   <li>a {@code ConstantRule} with value 0.8, and</li>
 *   <li>a default-constructed {@code LinearRule}.</li>
 * </ul>
 *
 * @throws Exception if loading the data set or writing the JSON file fails
 */
static void test1() throws Exception {
    int numLeaves = 4;
    RegDataSet dataSet = StandardFormat.loadRegDataSet("/Users/chengli/Datasets/slice_location/standard/featureList.txt", "/Users/chengli/Datasets/slice_location/standard/labels.txt", ",", DataSetType.REG_DENSE, false);
    System.out.println(dataSet.isDense());

    // Tree hyper-parameters: at most numLeaves leaves, at least 5 data points per leaf,
    // 100 split intervals.
    RegTreeConfig regTreeConfig = new RegTreeConfig();
    regTreeConfig.setMaxNumLeaves(numLeaves);
    regTreeConfig.setMinDataPerLeaf(5);
    regTreeConfig.setNumSplitIntervals(100);
    RegressionTree regressionTree = RegTreeTrainer.fit(regTreeConfig, dataSet);

    // NOTE(review): presumably each TreeRule records the path the given row takes through
    // the tree -- confirm against TreeRule.
    TreeRule rule1 = new TreeRule(regressionTree, dataSet.getRow(100));
    TreeRule rule2 = new TreeRule(regressionTree, dataSet.getRow(1));
    ConstantRule rule3 = new ConstantRule(0.8);
    Rule rule4 = new LinearRule();

    List<Rule> rules = new ArrayList<>();
    rules.add(rule1);
    rules.add(rule2);
    rules.add(rule3);
    rules.add(rule4);

    ObjectMapper mapper = new ObjectMapper();
    mapper.writeValue(new File(TMP, "decision.json"), rules);
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid (by cheng-li).
Example from class LSLogisticBoostTest, method test3:
/**
 * Compares LSLogisticBoost and LSBoost on the spam train/test split.
 * <p>
 * Both models share one regression-tree factory (trees with at most 10 leaves).
 * <ul>
 *   <li>LSLogisticBoost runs 200 iterations with shrinkage 1.</li>
 *   <li>LSBoost runs 100 iterations with shrinkage 0.1.</li>
 * </ul>
 * Before every iteration, the training and test RMSE are printed. After each model finishes
 * training, every test label is printed next to that model's prediction.
 *
 * @throws Exception if either data set cannot be loaded
 */
private static void test3() throws Exception {
    RegDataSet train = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true);
    RegDataSet test = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true);

    // --- LSLogisticBoost ---
    LSLogisticBoost logisticModel = new LSLogisticBoost();
    RegTreeFactory sharedTreeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(10));
    LSLogisticBoostOptimizer logisticOptimizer = new LSLogisticBoostOptimizer(logisticModel, train, sharedTreeFactory);
    logisticOptimizer.setShrinkage(1);
    logisticOptimizer.initialize();
    for (int iter = 0; iter < 200; iter++) {
        // RMSE is reported for the model state *before* this iteration's update.
        System.out.println("training rmse = " + RMSE.rmse(logisticModel, train));
        System.out.println("test rmse = " + RMSE.rmse(logisticModel, test));
        logisticOptimizer.iterate();
    }
    for (int row = 0; row < test.getNumDataPoints(); row++) {
        System.out.println(test.getLabels()[row] + " " + logisticModel.predict(test.getRow(row)));
    }

    // --- LSBoost baseline, reusing the same tree factory ---
    System.out.println("********************** LSBOOST **************");
    LSBoost leastSquaresModel = new LSBoost();
    LSBoostOptimizer leastSquaresOptimizer = new LSBoostOptimizer(leastSquaresModel, train, sharedTreeFactory);
    leastSquaresOptimizer.setShrinkage(0.1);
    leastSquaresOptimizer.initialize();
    for (int iter = 0; iter < 100; iter++) {
        System.out.println("training rmse = " + RMSE.rmse(leastSquaresModel, train));
        System.out.println("test rmse = " + RMSE.rmse(leastSquaresModel, test));
        leastSquaresOptimizer.iterate();
    }
    for (int row = 0; row < test.getNumDataPoints(); row++) {
        System.out.println(test.getLabels()[row] + " " + leastSquaresModel.predict(test.getRow(row)));
    }
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid (by cheng-li).
Example from class LSLogisticBoostTest, method test2:
/**
 * Trains LSLogisticBoost and LSBoost on the spam data after remapping its labels.
 * <p>
 * In both the train and test sets, a label equal to 0 becomes -1 and any other label
 * becomes 2. Both models then share one regression-tree factory (trees with at most 5 leaves)
 * and run 100 iterations each, using the optimizers' default shrinkage.
 * <p>
 * Before every iteration, the training and test RMSE are printed. After the logistic model
 * finishes training, every test label is printed next to its prediction.
 *
 * @throws Exception if either data set cannot be loaded
 */
private static void test2() throws Exception {
    RegDataSet train = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true);
    RegDataSet test = TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true);

    // Remap labels in place: 0 -> -1, everything else -> 2 (train first, then test).
    for (RegDataSet data : new RegDataSet[]{train, test}) {
        for (int idx = 0; idx < data.getNumDataPoints(); idx++) {
            data.setLabel(idx, data.getLabels()[idx] == 0 ? -1 : 2);
        }
    }

    // --- LSLogisticBoost ---
    LSLogisticBoost logisticModel = new LSLogisticBoost();
    RegTreeFactory sharedTreeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(5));
    LSLogisticBoostOptimizer logisticOptimizer = new LSLogisticBoostOptimizer(logisticModel, train, sharedTreeFactory);
    logisticOptimizer.initialize();
    for (int iter = 0; iter < 100; iter++) {
        System.out.println("training rmse = " + RMSE.rmse(logisticModel, train));
        System.out.println("test rmse = " + RMSE.rmse(logisticModel, test));
        logisticOptimizer.iterate();
    }
    for (int row = 0; row < test.getNumDataPoints(); row++) {
        System.out.println(test.getLabels()[row] + " " + logisticModel.predict(test.getRow(row)));
    }

    // --- LSBoost baseline, reusing the same tree factory ---
    LSBoost leastSquaresModel = new LSBoost();
    LSBoostOptimizer leastSquaresOptimizer = new LSBoostOptimizer(leastSquaresModel, train, sharedTreeFactory);
    System.out.println("LSBOOST");
    leastSquaresOptimizer.initialize();
    for (int iter = 0; iter < 100; iter++) {
        System.out.println("training rmse = " + RMSE.rmse(leastSquaresModel, train));
        System.out.println("test rmse = " + RMSE.rmse(leastSquaresModel, test));
        leastSquaresOptimizer.iterate();
    }
}
Usage of edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig in project pyramid (by cheng-li).
Example from class LambdaMARTOptimizerTest, method test1:
/**
 * Trains LambdaMART for 500 iterations on the shuffled spam training set and reports the
 * test-set NDCG before every iteration.
 * <p>
 * Both data sets are row-shuffled with seed 0. The whole training set is treated as a single
 * query, and trees have at most 5 leaves.
 *
 * @throws Exception if either data set cannot be loaded
 */
private static void test1() throws Exception {
    RegDataSet train = DataSetUtil.shuffleRows(TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/train", DataSetType.REG_DENSE, true), 0);
    RegDataSet test = DataSetUtil.shuffleRows(TRECFormat.loadRegDataSet("/Users/chengli/Downloads/spam/test", DataSetType.REG_DENSE, true), 0);

    LambdaMART model = new LambdaMART();
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(5));

    // One query containing every training instance: indices 0 .. numDataPoints-1.
    int numTrain = train.getNumDataPoints();
    List<Integer> singleQuery = new ArrayList<>(numTrain);
    for (int idx = 0; idx < numTrain; idx++) {
        singleQuery.add(idx);
    }
    List<List<Integer>> instancesInQuery = new ArrayList<>();
    instancesInQuery.add(singleQuery);

    LambdaMARTOptimizer lambdaOptimizer = new LambdaMARTOptimizer(model, train, train.getLabels(), treeFactory, instancesInQuery);
    lambdaOptimizer.initialize();
    for (int iter = 0; iter < 500; iter++) {
        System.out.println("==================================");
        System.out.println("iter " + iter);
        // NDCG is reported for the model state *before* this iteration's update.
        System.out.println("ndcg = " + NDCG.ndcg(test.getLabels(), model.predict(test)));
        lambdaOptimizer.iterate();
    }
}
Aggregations