
Example 1 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li, from class RerankerTrainer, method trainWithSigmoid.

public Reranker trainWithSigmoid(RegDataSet regDataSet, double[] instanceWeights,
                                 MultiLabelClassifier.ClassProbEstimator classProbEstimator,
                                 PredictionFeatureExtractor predictionFeatureExtractor,
                                 LabelCalibrator labelCalibrator, RegDataSet validation,
                                 double[] noiseRates0, double[] noiseRates1) {
    // numLeaves, minDataPerLeaf, monotonicityType, shrinkage, maxIter and numCandidates
    // are not declared in this method; they are members of RerankerTrainer.
    LSLogisticBoost lsLogisticBoost = new LSLogisticBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig()
            .setMaxNumLeaves(numLeaves)
            .setMinDataPerLeaf(minDataPerLeaf)
            .setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSLogisticBoostOptimizer optimizer = new LSLogisticBoostOptimizer(lsLogisticBoost, regDataSet,
            regTreeFactory, instanceWeights, regDataSet.getLabels(), noiseRates0, noiseRates1);
    if (!monotonicityType.equals("none")) {
        // A single row of per-feature monotonicity constraints from the feature extractor.
        int[][] mono = new int[1][];
        mono[0] = predictionFeatureExtractor.featureMonotonicity();
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    // Validation MSE (lower is better) is reported to the stopper every 10 iterations.
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSLogisticBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        if (i % 10 == 0) {
            double mse = MSE.mse(lsLogisticBoost, validation);
            earlyStopper.add(i, mse);
            if (earlyStopper.getBestIteration() == i) {
                // Snapshot the model at the best validation iteration seen so far.
                try {
                    bestModel = (LSLogisticBoost) Serialization.deepCopy(lsLogisticBoost);
                } catch (IOException | ClassNotFoundException e) {
                    // Fail loudly instead of silently keeping a stale (or null) best model.
                    throw new RuntimeException("failed to copy the best model", e);
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    if (bestModel == null) {
        // No validation check ran (maxIter < 10): fall back to the last trained model.
        bestModel = lsLogisticBoost;
    }
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Also used: RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) LSLogisticBoostOptimizer(edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LSLogisticBoost(edu.neu.ccs.pyramid.regression.ls_logistic_boost.LSLogisticBoost) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) IOException(java.io.IOException)
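Example 1 keeps a snapshot of the model at the best validation iteration by calling Serialization.deepCopy. Judging by the IOException and ClassNotFoundException it throws, that helper probably round-trips the object through Java serialization, but this is an inference. The sketch below shows the serialization deep-copy technique in isolation; DeepCopy and ToyModel are hypothetical names, not pyramid classes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration; not pyramid's Serialization class.
final class DeepCopy {
    private DeepCopy() {}

    // Writes obj to an in-memory byte array and reads it back as an independent copy.
    static <T extends Serializable> T deepCopy(T obj) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(obj);
            }
            try (ObjectInputStream in =
                         new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                @SuppressWarnings("unchecked")
                T copy = (T) in.readObject();
                return copy;
            }
        } catch (IOException | ClassNotFoundException e) {
            // Surface copy failures instead of returning a stale snapshot.
            throw new IllegalStateException("deep copy failed", e);
        }
    }
}

// Hypothetical stand-in for a boosted model that grows one tree per iteration.
final class ToyModel implements Serializable {
    private static final long serialVersionUID = 1L;
    final List<Double> trees = new ArrayList<>();
}
```

Because the snapshot is a separate object graph, later iterations that keep adding trees to the live model leave it unchanged. The trade-off is a full serialization pass every time the validation score improves.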

Example 2 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li, from class CBMEN, method loadNewEarlyStopper.

private static EarlyStopper loadNewEarlyStopper(Config config) {
    String earlyStopMetric = config.getString("tune.targetMetric");
    int patience = config.getInt("tune.earlyStop.patience");
    // Set accuracy, F1 and label_map are higher-is-better; Hamming loss is lower-is-better.
    EarlyStopper.Goal earlyStopGoal;
    switch(earlyStopMetric) {
        case "instance_set_accuracy":
        case "instance_f1":
        case "label_map":
            earlyStopGoal = EarlyStopper.Goal.MAXIMIZE;
            break;
        case "instance_hamming_loss":
            earlyStopGoal = EarlyStopper.Goal.MINIMIZE;
            break;
        default:
            throw new IllegalArgumentException("unsupported tune.targetMetric " + earlyStopMetric);
    }
    EarlyStopper earlyStopper = new EarlyStopper(earlyStopGoal, patience);
    earlyStopper.setMinimumIterations(config.getInt("tune.earlyStop.minIterations"));
    return earlyStopper;
}
Also used: EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper)
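The stopper is built from a goal, a patience value and a minimum number of iterations. The internals of pyramid's EarlyStopper are not shown on this page, so the following is a minimal sketch under stated assumptions: patience counts consecutive add() calls without strict improvement, and shouldStop() never fires before the most recently recorded iteration reaches the minimum. PatienceStopper is a hypothetical class that only mirrors the method names used above.

```java
// Hypothetical re-implementation for illustration; pyramid's EarlyStopper may behave differently.
final class PatienceStopper {
    enum Goal { MINIMIZE, MAXIMIZE }

    private final Goal goal;
    private final int patience;          // assumed: number of checks allowed without improvement
    private int minimumIterations = 0;   // assumed: never stop before this iteration
    private int bestIteration = -1;
    private double bestValue = Double.NaN;
    private int checksSinceBest = 0;
    private int lastIteration = 0;

    PatienceStopper(Goal goal, int patience) {
        this.goal = goal;
        this.patience = patience;
    }

    void setMinimumIterations(int minimumIterations) {
        this.minimumIterations = minimumIterations;
    }

    // Records the validation value observed at the given iteration.
    void add(int iteration, double value) {
        lastIteration = iteration;
        boolean improved = bestIteration < 0
                || (goal == Goal.MINIMIZE ? value < bestValue : value > bestValue);
        if (improved) {
            bestIteration = iteration;
            bestValue = value;
            checksSinceBest = 0;
        } else {
            checksSinceBest++;
        }
    }

    // Stops only after the minimum iteration and after `patience` checks without improvement.
    boolean shouldStop() {
        return lastIteration >= minimumIterations && checksSinceBest >= patience;
    }

    int getBestIteration() { return bestIteration; }

    double getBestValue() { return bestValue; }
}
```

Ties are treated as "no improvement" here because the comparison is strict; that choice is also an assumption rather than documented behavior.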

Example 3 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li, from class CBMGB, method loadNewEarlyStopper.

private static EarlyStopper loadNewEarlyStopper(Config config) {
    String earlyStopMetric = config.getString("tune.targetMetric");
    int patience = config.getInt("tune.earlyStop.patience");
    // Set accuracy, F1 and log likelihood are higher-is-better; Hamming loss is lower-is-better.
    EarlyStopper.Goal earlyStopGoal;
    switch(earlyStopMetric) {
        case "instance_set_accuracy":
        case "instance_f1":
        case "instance_log_likelihood":
            earlyStopGoal = EarlyStopper.Goal.MAXIMIZE;
            break;
        case "instance_hamming_loss":
            earlyStopGoal = EarlyStopper.Goal.MINIMIZE;
            break;
        default:
            throw new IllegalArgumentException("unsupported tune.targetMetric " + earlyStopMetric);
    }
    EarlyStopper earlyStopper = new EarlyStopper(earlyStopGoal, patience);
    earlyStopper.setMinimumIterations(config.getInt("tune.earlyStop.minIterations"));
    return earlyStopper;
}
Also used: EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper)
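Examples 2, 3 and 5 each map tune.targetMetric to an optimization goal with a switch statement. A lookup table is one alternative design that keeps the metric list in one place. The sketch below mirrors the metrics accepted by CBMGB and CBMLR; StopGoal is a hypothetical stand-in for EarlyStopper.Goal, and Map.of assumes Java 9 or later.

```java
import java.util.Map;

// Hypothetical stand-in for EarlyStopper.Goal.
enum StopGoal { MINIMIZE, MAXIMIZE }

final class MetricGoals {
    private MetricGoals() {}

    // Same metric-to-goal pairs as the CBMGB and CBMLR switch statements.
    private static final Map<String, StopGoal> GOALS = Map.of(
            "instance_set_accuracy", StopGoal.MAXIMIZE,
            "instance_f1", StopGoal.MAXIMIZE,
            "instance_hamming_loss", StopGoal.MINIMIZE,
            "instance_log_likelihood", StopGoal.MAXIMIZE);

    // Returns the goal for a metric, rejecting names the table does not know.
    static StopGoal goalFor(String metric) {
        StopGoal goal = GOALS.get(metric);
        if (goal == null) {
            throw new IllegalArgumentException("unsupported tune.targetMetric " + metric);
        }
        return goal;
    }
}
```

Adding a metric then means adding one map entry rather than a new case block, at the cost of losing the compiler's exhaustiveness help that a switch over an enum would give.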

Example 4 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li, from class CBMLR, method tune.

private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    LRCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    optimizer.initialize();
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch(predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        case "instance_log_likelihood":
            // acc predictor seems to be the best match for log likelihood
            AccPredictor accPredictor1 = new AccPredictor(cbm);
            accPredictor1.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor1;
            break;
        default:
            throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or instance_log_likelihood");
    }
    int interval = config.getInt("tune.monitorInterval");
    // There is no iteration cap: the loop exits only when the early stopper fires.
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            switch(predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                case "instance_log_likelihood":
                    earlyStopper.add(iter, LogLikelihood.averageLogLikelihood(cbm, validSet, unobservedLabels));
                    break;
                default:
                    throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or instance_log_likelihood");
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
Also used: EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)
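In tune, the validation metric is evaluated every tune.monitorInterval iterations, the loop runs until the early stopper fires, and the best iteration is stored in hyperParameters.iterations. The sketch below reproduces that control flow with a simulated validation curve and adds a hard iteration cap as a safeguard. IntervalTuner, validScore and maxIter are hypothetical; higher scores are assumed to be better, so a lower-is-better metric such as Hamming loss would need to be negated first.

```java
import java.util.function.IntToDoubleFunction;

// Hypothetical driver; the real tune() uses LRCBMOptimizer, MLMeasures and EarlyStopper.
final class IntervalTuner {
    private IntervalTuner() {}

    // Evaluates every `interval` iterations, stops after `patience` evaluations without
    // improvement or at maxIter, and returns the best evaluated iteration (0 if none ran).
    static int bestIteration(IntToDoubleFunction validScore, int interval, int patience, int maxIter) {
        int bestIter = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        int checksSinceBest = 0;
        for (int iter = 1; iter <= maxIter; iter++) {
            // optimizer.iterate() would run here in the real code.
            if (iter % interval != 0) {
                continue;
            }
            double score = validScore.applyAsDouble(iter);
            if (score > bestScore) {
                bestScore = score;
                bestIter = iter;
                checksSinceBest = 0;
            } else if (++checksSinceBest >= patience) {
                break;
            }
        }
        return bestIter;
    }
}
```

For a curve that peaks at iteration 30, checked every 10 iterations with a patience of 2, the loop evaluates iterations 10 through 50, stops, and reports 30 as the iteration count to retrain with.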

Example 5 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li, from class CBMLR, method loadNewEarlyStopper.

private static EarlyStopper loadNewEarlyStopper(Config config) {
    String earlyStopMetric = config.getString("tune.targetMetric");
    int patience = config.getInt("tune.earlyStop.patience");
    // Set accuracy, F1 and log likelihood are higher-is-better; Hamming loss is lower-is-better.
    EarlyStopper.Goal earlyStopGoal;
    switch(earlyStopMetric) {
        case "instance_set_accuracy":
        case "instance_f1":
        case "instance_log_likelihood":
            earlyStopGoal = EarlyStopper.Goal.MAXIMIZE;
            break;
        case "instance_hamming_loss":
            earlyStopGoal = EarlyStopper.Goal.MINIMIZE;
            break;
        default:
            throw new IllegalArgumentException("unsupported tune.targetMetric " + earlyStopMetric);
    }
    EarlyStopper earlyStopper = new EarlyStopper(earlyStopGoal, patience);
    earlyStopper.setMinimumIterations(config.getInt("tune.earlyStop.minIterations"));
    return earlyStopper;
}
Also used: EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper)

Aggregations

EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 13
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 5
IOException (java.io.IOException): 5
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 4
ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 3
Config (edu.neu.ccs.pyramid.configuration.Config): 3
edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset): 3
Feature (edu.neu.ccs.pyramid.feature.Feature): 3
TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures): 3
Paths (java.nio.file.Paths): 3
FileHandler (java.util.logging.FileHandler): 3
Logger (java.util.logging.Logger): 3
SimpleFormatter (java.util.logging.SimpleFormatter): 3
Collectors (java.util.stream.Collectors): 3
IntStream (java.util.stream.IntStream): 3
FileUtils (org.apache.commons.io.FileUtils): 3
StopWatch (org.apache.commons.lang3.time.StopWatch): 3
JsonEncoding (com.fasterxml.jackson.core.JsonEncoding): 2
JsonFactory (com.fasterxml.jackson.core.JsonFactory): 2
JsonGenerator (com.fasterxml.jackson.core.JsonGenerator): 2