Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.
Class SparkCBMOptimizer, method updateMultiClassBoost.
/**
 * Refits the multi-class LKBoost classifier stored in {@code cbm.multiClassClassifier}.
 *
 * <p>Trees are grown with at most {@code numLeavesMultiClass} leaves and leaf values come from an
 * {@link LKBOutputCalculator} sized to {@code cbm.numComponents}. A fresh {@link LKBoostOptimizer}
 * is built over {@code dataSet} and {@code gammas}. NOTE(review): {@code gammas} is presumably the
 * per-instance component responsibilities used as soft targets; confirm against LKBoostOptimizer.
 * The optimizer is configured with shrinkage {@code shrinkageMultiClass}, initialized, and then
 * {@code iterate(numIterationsMultiClass)} is called on it.
 *
 * <p>{@code cbm.multiClassClassifier} must be an {@link LKBoost}; otherwise the cast throws
 * {@link ClassCastException}.
 */
private void updateMultiClassBoost() {
    LKBoost multiClassBoost = (LKBoost) this.cbm.multiClassClassifier;

    RegTreeFactory treeFactory =
            new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesMultiClass));
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(cbm.numComponents));

    LKBoostOptimizer boostOptimizer =
            new LKBoostOptimizer(multiClassBoost, dataSet, treeFactory, gammas);
    boostOptimizer.setShrinkage(shrinkageMultiClass);
    boostOptimizer.initialize();
    boostOptimizer.iterate(numIterationsMultiClass);
}
Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.
Class GBCBMOptimizer, method updateMultiClassClassifier.
/**
 * Refits the multi-class LKBoost classifier stored in {@code cbm.multiClassClassifier}.
 *
 * <p>Trees are grown with at most {@code numLeaves} leaves and leaf values come from an
 * {@link LKBOutputCalculator} sized to {@code cbm.getNumComponents()}. A fresh
 * {@link LKBoostOptimizer} over {@code dataSet} and {@code gammas} is configured with
 * {@code shrinkage}, initialized, and then {@code iterate(multiclassUpdatesPerIter)} is called.
 * Start and finish messages are logged at debug level.
 */
@Override
protected void updateMultiClassClassifier() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateMultiClassClassifier");
    }

    // NOTE(review): the original carried a bare "parallel" comment here, presumably meaning the
    // optimizer parallelizes internally; confirm against LKBoostOptimizer.
    LKBoost multiClassBoost = (LKBoost) this.cbm.multiClassClassifier;
    RegTreeFactory factory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeaves));
    factory.setLeafOutputCalculator(new LKBOutputCalculator(cbm.getNumComponents()));

    LKBoostOptimizer boostOptimizer =
            new LKBoostOptimizer(multiClassBoost, dataSet, factory, gammas);
    boostOptimizer.setShrinkage(shrinkage);
    boostOptimizer.initialize();
    boostOptimizer.iterate(multiclassUpdatesPerIter);

    if (logger.isDebugEnabled()) {
        logger.debug("finish updateMultiClassClassifier");
    }
}
Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.
Class CBMUtilityOptimizer, method updateMultiClassBoost.
/**
 * Refits the multi-class LKBoost classifier stored in {@code cbm.multiClassClassifier}.
 *
 * <p>Regression trees are limited to {@code numLeavesMultiClass} leaves. Leaf outputs are computed
 * by an {@link LKBOutputCalculator} over {@code cbm.numComponents} classes. The boosting run uses
 * {@code dataSet} and {@code gammas}, where {@code gammas} is presumably the component
 * responsibilities used as soft targets (TODO confirm). It uses shrinkage
 * {@code shrinkageMultiClass}, and {@code iterate(numIterationsMultiClass)} is called after
 * {@code initialize()}.
 *
 * <p>Throws {@link ClassCastException} if {@code cbm.multiClassClassifier} is not an
 * {@link LKBoost}.
 */
private void updateMultiClassBoost() {
    LKBoost target = (LKBoost) this.cbm.multiClassClassifier;

    RegTreeConfig treeConfig = new RegTreeConfig().setMaxNumLeaves(numLeavesMultiClass);
    RegTreeFactory treeFactory = new RegTreeFactory(treeConfig);
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(cbm.numComponents));

    LKBoostOptimizer trainer = new LKBoostOptimizer(target, dataSet, treeFactory, gammas);
    trainer.setShrinkage(shrinkageMultiClass);
    trainer.initialize();
    trainer.iterate(numIterationsMultiClass);
}
Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.
Class CBMUtilityOptimizer, method updateBinaryBoosting.
/**
 * Refits the binary LKBoost classifier for one (component, label) pair, stored in
 * {@code cbm.binaryClassifiers[componentIndex][labelIndex]}.
 *
 * <p>Trees have at most {@code numLeavesBinary} leaves. Leaf outputs come from a two-class
 * {@link LKBOutputCalculator}. The optimizer is fed {@code gammasT[componentIndex]} and
 * {@code binaryTargetsDistributions[labelIndex]}. NOTE(review): presumably these are the
 * per-instance weights for this component and the soft binary targets for this label; confirm
 * against the five-argument LKBoostOptimizer constructor. It runs with shrinkage
 * {@code shrinkageBinary} and {@code iterate(numIterationsBinary)} is called after
 * {@code initialize()}.
 *
 * @param componentIndex index of the mixture component whose binary classifier is refit
 * @param labelIndex index of the label whose binary classifier is refit
 */
private void updateBinaryBoosting(int componentIndex, int labelIndex) {
    LKBoost binaryBoost = (LKBoost) this.cbm.binaryClassifiers[componentIndex][labelIndex];

    RegTreeFactory factory =
            new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesBinary));
    // Two output classes: each binary classifier separates "label on" from "label off".
    factory.setLeafOutputCalculator(new LKBOutputCalculator(2));

    LKBoostOptimizer binaryOptimizer = new LKBoostOptimizer(
            binaryBoost,
            dataSet,
            factory,
            gammasT[componentIndex],
            binaryTargetsDistributions[labelIndex]);
    binaryOptimizer.setShrinkage(shrinkageBinary);
    binaryOptimizer.initialize();
    binaryOptimizer.iterate(numIterationsBinary);
}
Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.
Class GBClassifier, method train.
/**
 * Trains an LKBoost classifier as described by {@code config}, printing progress to stdout.
 * It then writes the serialized model, the training-set predictions and the training-set
 * predicted probabilities into the output folder.
 *
 * <p>Config keys read:
 * <ul>
 *   <li>{@code input.matrixType}: "dense" or "sparse"</li>
 *   <li>{@code input.trainData}</li>
 *   <li>{@code input.testData}: only when {@code train.showTestProgress} is true</li>
 *   <li>{@code train.numLeaves}, {@code train.shrinkage}, {@code train.numIterations}</li>
 *   <li>{@code train.showProgress.interval}, {@code train.showTrainProgress},
 *       {@code train.showTestProgress}</li>
 *   <li>{@code output.folder}</li>
 * </ul>
 *
 * <p>When progress reporting is enabled, accuracy is printed every
 * {@code train.showProgress.interval} iterations and after the final iteration.
 *
 * @param config training configuration
 * @throws IllegalArgumentException if {@code input.matrixType} is neither "dense" nor "sparse",
 *     or if progress reporting is enabled and {@code train.showProgress.interval} is not positive
 * @throws java.io.IOException if the output folder does not exist and cannot be created
 * @throws Exception propagated from data loading, serialization, or report writing
 */
private static void train(Config config) throws Exception {
    String sparsity = config.getString("input.matrixType");
    DataSetType dataSetType;
    switch (sparsity) {
        case "dense":
            dataSetType = DataSetType.CLF_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.CLF_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }

    ClfDataSet trainSet =
            TRECFormat.loadClfDataSet(config.getString("input.trainData"), dataSetType, true);
    boolean showTestProgress = config.getBoolean("train.showTestProgress");
    ClfDataSet testSet = null;
    if (showTestProgress) {
        testSet = TRECFormat.loadClfDataSet(config.getString("input.testData"), dataSetType, true);
    }

    int numClasses = trainSet.getNumClasses();
    LKBoost lkBoost = new LKBoost(numClasses);
    RegTreeConfig regTreeConfig =
            new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves"));
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(numClasses));
    LKBoostOptimizer optimizer = new LKBoostOptimizer(lkBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();

    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    // Loop-invariant: read once instead of on every iteration.
    boolean showTrainProgress = config.getBoolean("train.showTrainProgress");
    boolean showAnyProgress = showTrainProgress || showTestProgress;
    // Fail fast with a clear message instead of an ArithmeticException from "i % 0" below.
    if (showAnyProgress && progressInterval < 1) {
        throw new IllegalArgumentException(
                "train.showProgress.interval must be positive, got " + progressInterval);
    }

    for (int i = 1; i <= numIterations; i++) {
        System.out.println("iteration " + i);
        optimizer.iterate();
        boolean reportNow =
                showAnyProgress && (i % progressInterval == 0 || i == numIterations);
        if (reportNow && showTrainProgress) {
            System.out.println("training accuracy = " + Accuracy.accuracy(lkBoost, trainSet));
        }
        if (reportNow && showTestProgress) {
            System.out.println("test accuracy = " + Accuracy.accuracy(lkBoost, testSet));
        }
    }
    System.out.println("training done!");

    String output = config.getString("output.folder");
    File outputDir = new File(output);
    // mkdirs() reports failure only through its return value; surface it here rather than as an
    // obscure failure inside serialization.
    if (!outputDir.isDirectory() && !outputDir.mkdirs()) {
        throw new java.io.IOException(
                "could not create output folder " + outputDir.getAbsolutePath());
    }

    File serializedModel = new File(outputDir, "model");
    Serialization.serialize(lkBoost, serializedModel);
    System.out.println("model saved to " + serializedModel.getAbsolutePath());

    File reportFile = new File(outputDir, "train_predictions.txt");
    report(lkBoost, trainSet, reportFile);
    System.out.println("predictions on the training set are written to " + reportFile.getAbsolutePath());

    File probabilitiesFile = new File(outputDir, "train_predicted_probabilities.txt");
    probabilities(lkBoost, trainSet, probabilitiesFile);
    System.out.println("predicted probabilities on the training set are written to " + probabilitiesFile.getAbsolutePath());
}
Aggregations