Search in sources:

Example 1 with LSLogisticBoost

Use of edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoost in project pyramid by cheng-li.

Class RerankerTrainer, method trainWithSigmoid.

/**
 * Trains an {@link LSLogisticBoost} model with validation-based early stopping and wraps the
 * best model found in a {@link Reranker}.
 *
 * <p>Validation MSE is computed every {@code evalInterval} iterations and also at the final
 * iteration. Without the final check, iterations after the last multiple of the interval would be
 * trained but never considered. Whenever the early stopper reports the current iteration as the
 * best one, a deep copy of the model is kept.
 *
 * @param regDataSet training data; its labels are passed to the optimizer as targets
 * @param instanceWeights per-instance weights passed to the optimizer
 * @param classProbEstimator passed through to the returned {@link Reranker}
 * @param predictionFeatureExtractor provides per-feature monotonicity constraints (used when
 *     {@code monotonicityType} is not {@code "none"}) and is passed to the {@link Reranker}
 * @param labelCalibrator passed through to the returned {@link Reranker}
 * @param validation data used to compute the MSE that drives early stopping
 * @param noiseRates0 passed to the {@link LSLogisticBoostOptimizer} constructor
 * @param noiseRates1 passed to the {@link LSLogisticBoostOptimizer} constructor
 * @return a reranker wrapping the best validated model; if no validation step ran
 *     (e.g. {@code maxIter < 1}), the current model is used instead of {@code null}
 * @throws IllegalStateException if the model cannot be deep-copied via serialization
 */
public Reranker trainWithSigmoid(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation, double[] noiseRates0, double[] noiseRates1) {
    // How often (in boosting iterations) validation MSE is computed.
    final int evalInterval = 10;
    // Number of non-improving evaluations tolerated before stopping.
    final int patience = 5;

    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig()
            .setMaxNumLeaves(numLeaves)
            .setMinDataPerLeaf(minDataPerLeaf)
            .setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels(), noiseRates0, noiseRates1);
    if (!monotonicityType.equals("none")) {
        // A single row of per-feature constraints; no need to pre-allocate a row that is replaced.
        int[][] mono = {predictionFeatureExtractor.featureMonotonicity()};
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();

    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, patience);
    LSLogisticBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        boolean evaluate = i % evalInterval == 0 || i == maxIter;
        if (!evaluate) {
            continue;
        }
        double mse = MSE.mse(lsLogisticBoost, validation);
        earlyStopper.add(i, mse);
        if (earlyStopper.getBestIteration() == i) {
            try {
                // Snapshot the model: training keeps mutating lsLogisticBoost in place.
                bestModel = (LSLogisticBoost) Serialization.deepCopy(lsLogisticBoost);
            } catch (IOException | ClassNotFoundException e) {
                // Continuing would silently return a stale or null model; fail loudly instead.
                throw new IllegalStateException("Failed to deep-copy LSLogisticBoost at iteration " + i, e);
            }
        }
        if (earlyStopper.shouldStop()) {
            break;
        }
    }
    if (bestModel == null) {
        // No validation step ran (e.g. maxIter < 1); never hand a null model to the Reranker.
        bestModel = lsLogisticBoost;
    }
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Also used: RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) LSLogisticBoostOptimizer(edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LSLogisticBoost(edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoost) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) IOException(java.io.IOException)

Aggregations

EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 1; LSLogisticBoost (edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoost): 1; LSLogisticBoostOptimizer (edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoostOptimizer): 1; RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig): 1; RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory): 1; IOException (java.io.IOException): 1