Search in sources:

Example 1 with LKBOutputCalculator

Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.

From class SparkCBMOptimizer, method updateMultiClassBoost.

/**
 * Refits the multi-class LKBoost classifier stored in {@code cbm.multiClassClassifier}.
 *
 * <p>Regression trees are capped at {@code numLeavesMultiClass} leaves, and their leaf values
 * come from an {@link LKBOutputCalculator} sized to {@code cbm.numComponents}. The optimizer is
 * built over {@code dataSet} and {@code gammas}. It then runs {@code numIterationsMultiClass}
 * boosting iterations with shrinkage {@code shrinkageMultiClass}.
 */
private void updateMultiClassBoost() {
    LKBoost multiClassBoost = (LKBoost) this.cbm.multiClassClassifier;

    RegTreeFactory treeFactory =
            new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesMultiClass));
    // one leaf output per mixture component
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(cbm.numComponents));

    LKBoostOptimizer boostOptimizer =
            new LKBoostOptimizer(multiClassBoost, dataSet, treeFactory, gammas);
    boostOptimizer.setShrinkage(shrinkageMultiClass);
    boostOptimizer.initialize();
    boostOptimizer.iterate(numIterationsMultiClass);
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 2 with LKBOutputCalculator

Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.

From class GBCBMOptimizer, method updateMultiClassClassifier.

/**
 * Refits the multi-class LKBoost classifier stored in {@code cbm.multiClassClassifier}.
 *
 * <p>Regression trees are capped at {@code numLeaves} leaves, and their leaf values come from an
 * {@link LKBOutputCalculator} sized to {@code cbm.getNumComponents()}. The optimizer is built over
 * {@code dataSet} and {@code gammas}. It then runs {@code multiclassUpdatesPerIter} boosting
 * iterations with the configured {@code shrinkage}. Start and finish are logged at debug level.
 */
@Override
protected void updateMultiClassClassifier() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateMultiClassClassifier");
    }

    // NOTE(review): the original "parallel" remark presumably refers to work done inside
    // LKBoostOptimizer; this method itself does not start any threads.
    LKBoost multiClassBoost = (LKBoost) this.cbm.multiClassClassifier;
    RegTreeFactory treeFactory = new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeaves));
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(cbm.getNumComponents()));

    LKBoostOptimizer boostOptimizer =
            new LKBoostOptimizer(multiClassBoost, dataSet, treeFactory, gammas);
    boostOptimizer.setShrinkage(shrinkage);
    boostOptimizer.initialize();
    boostOptimizer.iterate(multiclassUpdatesPerIter);

    if (logger.isDebugEnabled()) {
        logger.debug("finish updateMultiClassClassifier");
    }
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 3 with LKBOutputCalculator

Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.

From class CBMUtilityOptimizer, method updateMultiClassBoost.

/**
 * Refits the multi-class LKBoost classifier stored in {@code cbm.multiClassClassifier}.
 *
 * <p>Regression trees are capped at {@code numLeavesMultiClass} leaves, and their leaf values
 * come from an {@link LKBOutputCalculator} sized to {@code cbm.numComponents}. The optimizer is
 * built over {@code dataSet} and {@code gammas}. It then runs {@code numIterationsMultiClass}
 * boosting iterations with shrinkage {@code shrinkageMultiClass}.
 */
private void updateMultiClassBoost() {
    LKBoost multiClassBoost = (LKBoost) this.cbm.multiClassClassifier;

    RegTreeFactory treeFactory =
            new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesMultiClass));
    // one leaf output per mixture component
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(cbm.numComponents));

    LKBoostOptimizer boostOptimizer =
            new LKBoostOptimizer(multiClassBoost, dataSet, treeFactory, gammas);
    boostOptimizer.setShrinkage(shrinkageMultiClass);
    boostOptimizer.initialize();
    boostOptimizer.iterate(numIterationsMultiClass);
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 4 with LKBOutputCalculator

Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.

From class CBMUtilityOptimizer, method updateBinaryBoosting.

/**
 * Refits the binary LKBoost classifier stored at
 * {@code cbm.binaryClassifiers[componentIndex][labelIndex]}.
 *
 * <p>Regression trees are capped at {@code numLeavesBinary} leaves. The optimizer is built over
 * {@code dataSet}, {@code gammasT[componentIndex]} and
 * {@code binaryTargetsDistributions[labelIndex]}. It then runs {@code numIterationsBinary}
 * boosting iterations with shrinkage {@code shrinkageBinary}.
 *
 * @param componentIndex first index into {@code cbm.binaryClassifiers}; also selects the row of
 *     {@code gammasT}
 * @param labelIndex second index into {@code cbm.binaryClassifiers}; also selects the entry of
 *     {@code binaryTargetsDistributions}
 */
private void updateBinaryBoosting(int componentIndex, int labelIndex) {
    LKBoost binaryBoost = (LKBoost) this.cbm.binaryClassifiers[componentIndex][labelIndex];

    RegTreeFactory treeFactory =
            new RegTreeFactory(new RegTreeConfig().setMaxNumLeaves(numLeavesBinary));
    // 2 = number of classes for a binary LKBoost model
    treeFactory.setLeafOutputCalculator(new LKBOutputCalculator(2));

    LKBoostOptimizer boostOptimizer = new LKBoostOptimizer(
            binaryBoost,
            dataSet,
            treeFactory,
            gammasT[componentIndex],
            binaryTargetsDistributions[labelIndex]);
    boostOptimizer.setShrinkage(shrinkageBinary);
    boostOptimizer.initialize();
    boostOptimizer.iterate(numIterationsBinary);
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Example 5 with LKBOutputCalculator

Use of edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator in project pyramid by cheng-li.

From class GBClassifier, method train.

/**
 * Trains an LKBoost classifier from {@code config} and writes the model and training-set reports
 * to {@code output.folder}.
 *
 * <p>Config keys read:
 * <ul>
 *   <li>{@code input.matrixType} ("dense" or "sparse")</li>
 *   <li>{@code input.trainData}</li>
 *   <li>{@code input.testData}, only when {@code train.showTestProgress} is true</li>
 *   <li>{@code train.numLeaves}, {@code train.shrinkage}, {@code train.numIterations}</li>
 *   <li>{@code train.showProgress.interval}, {@code train.showTrainProgress},
 *       {@code train.showTestProgress}</li>
 *   <li>{@code output.folder}</li>
 * </ul>
 * Accuracy is printed every {@code train.showProgress.interval} iterations and after the last
 * iteration, for whichever of the train and test sets is enabled.
 *
 * <p>Files written under {@code output.folder}:
 * <ul>
 *   <li>{@code model} (the serialized classifier)</li>
 *   <li>{@code train_predictions.txt}</li>
 *   <li>{@code train_predicted_probabilities.txt}</li>
 * </ul>
 *
 * @param config training configuration
 * @throws IllegalArgumentException if {@code input.matrixType} is neither "dense" nor "sparse",
 *     or if progress reporting is enabled with a non-positive {@code train.showProgress.interval}
 * @throws java.io.IOException if the output folder does not exist and cannot be created
 * @throws Exception propagated from data loading, serialization or report writing
 */
private static void train(Config config) throws Exception {
    String sparsity = config.getString("input.matrixType");
    DataSetType dataSetType;
    switch (sparsity) {
        case "dense":
            dataSetType = DataSetType.CLF_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.CLF_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("input.matrixType should be dense or sparse");
    }

    // Read the loop settings once, and validate them before the (potentially slow) data loading.
    boolean showTrainProgress = config.getBoolean("train.showTrainProgress");
    boolean showTestProgress = config.getBoolean("train.showTestProgress");
    boolean showAnyProgress = showTrainProgress || showTestProgress;
    int progressInterval = config.getInt("train.showProgress.interval");
    int numIterations = config.getInt("train.numIterations");
    if (showAnyProgress && progressInterval <= 0) {
        // Without this check, an interval of 0 fails later with a bare "/ by zero" inside the loop.
        throw new IllegalArgumentException(
                "train.showProgress.interval must be positive when progress reporting is enabled, got "
                        + progressInterval);
    }

    ClfDataSet trainSet = TRECFormat.loadClfDataSet(config.getString("input.trainData"), dataSetType, true);
    ClfDataSet testSet = null;
    if (showTestProgress) {
        testSet = TRECFormat.loadClfDataSet(config.getString("input.testData"), dataSetType, true);
    }

    int numClasses = trainSet.getNumClasses();
    LKBoost lkBoost = new LKBoost(numClasses);
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(config.getInt("train.numLeaves"));
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    regTreeFactory.setLeafOutputCalculator(new LKBOutputCalculator(numClasses));
    LKBoostOptimizer optimizer = new LKBoostOptimizer(lkBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(config.getDouble("train.shrinkage"));
    optimizer.initialize();

    for (int i = 1; i <= numIterations; i++) {
        System.out.println("iteration " + i);
        optimizer.iterate();
        // Short-circuit keeps the modulo from being evaluated when no progress is requested.
        boolean reportNow = showAnyProgress && (i % progressInterval == 0 || i == numIterations);
        if (reportNow && showTrainProgress) {
            System.out.println("training accuracy = " + Accuracy.accuracy(lkBoost, trainSet));
        }
        if (reportNow && showTestProgress) {
            System.out.println("test accuracy = " + Accuracy.accuracy(lkBoost, testSet));
        }
    }
    System.out.println("training done!");

    File outputFolder = new File(config.getString("output.folder"));
    // mkdirs() also returns false when the folder already exists, so confirm with isDirectory().
    if (!outputFolder.mkdirs() && !outputFolder.isDirectory()) {
        throw new java.io.IOException("could not create output folder " + outputFolder.getAbsolutePath());
    }
    File serializedModel = new File(outputFolder, "model");
    Serialization.serialize(lkBoost, serializedModel);
    System.out.println("model saved to " + serializedModel.getAbsolutePath());
    File reportFile = new File(outputFolder, "train_predictions.txt");
    report(lkBoost, trainSet, reportFile);
    System.out.println("predictions on the training set are written to " + reportFile.getAbsolutePath());
    File probabilitiesFile = new File(outputFolder, "train_predicted_probabilities.txt");
    probabilities(lkBoost, trainSet, probabilitiesFile);
    System.out.println("predicted probabilities on the training set are written to " + probabilitiesFile.getAbsolutePath());
}
Also used : LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) DataSetType(edu.neu.ccs.pyramid.dataset.DataSetType) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) File(java.io.File) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost)

Aggregations

LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator): 10
LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost): 10
LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer): 10
RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig): 10
RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory): 10
PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier): 1
ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1
DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType): 1
File (java.io.File): 1
StopWatch (org.apache.commons.lang3.time.StopWatch): 1