Search in sources :

Example 6 with EarlyStopper

use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

In class Calibration, method trainCalibrator.

/**
 * Fits a monotone LSBoost calibrator on {@code calib}, using {@code valid} for early stopping.
 *
 * <p>Every 10 boosting iterations the validation MSE is recorded in an EarlyStopper with goal
 * {@code MINIMIZE}. Whenever the current checkpoint is reported as the best one, a deep copy of the
 * model is taken. The returned model is therefore the snapshot at the best checkpoint, not the
 * model after the last iteration.
 *
 * @param calib data the calibrator is fit on; its labels are the regression targets
 * @param valid held-out data used to compute the monitored MSE
 * @param monotonicity per-feature monotonicity constraints handed to the optimizer
 * @return deep copy of the model at the best checkpoint ({@code null} only if no checkpoint was
 *     ever reported as best)
 * @throws IllegalStateException if the best model cannot be deep-copied
 */
private static LSBoost trainCalibrator(RegDataSet calib, RegDataSet valid, int[] monotonicity) {
    final int maxIterations = 1000;
    final int checkInterval = 10;
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(10).setMinDataPerLeaf(5);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, calib, regTreeFactory, calib.getLabels());
    // setMonotonicity takes an int[][]; wrap the constraints as its single row.
    optimizer.setMonotonicity(new int[][] { monotonicity });
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSBoost bestModel = null;
    for (int i = 1; i < maxIterations; i++) {
        optimizer.iterate();
        if (i % checkInterval != 0) {
            continue;
        }
        double mse = MSE.mse(lsBoost, valid);
        earlyStopper.add(i, mse);
        if (earlyStopper.getBestIteration() == i) {
            // Snapshot now: lsBoost keeps changing as boosting continues.
            try {
                bestModel = (LSBoost) Serialization.deepCopy(lsBoost);
            } catch (IOException | ClassNotFoundException e) {
                // Without the snapshot we would silently return a stale or null model.
                throw new IllegalStateException("failed to deep-copy the calibrator at iteration " + i, e);
            }
        }
        if (earlyStopper.shouldStop()) {
            break;
        }
    }
    return bestModel;
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) LSBoostOptimizer(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) IOException(java.io.IOException) LSBoost(edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)

Example 7 with EarlyStopper

use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

In class App2, method train.

/**
 * Trains an IMLGradientBoosting multi-label model from the settings in {@code config}, serializes it
 * to {@code output.folder/model_app3}, and writes per-label top features as JSON and text.
 *
 * <p>When {@code train.warmStart} is true, training continues from the model previously saved under
 * the same file name. When {@code train.earlyStop} is true, each label gets its own EarlyStopper and
 * Terminator, both fed with the test-set KL value for that label. A label whose stopper or
 * terminator fires is marked via {@code trainer.setShouldStop(l)}, and the loop exits early once
 * every label has been marked.
 *
 * @param config training configuration (keys under {@code input.*}, {@code train.*},
 *     {@code output.*} and {@code report.*})
 * @param logger destination for progress and performance messages
 * @throws Exception propagated from e.g. data loading, model (de)serialization or report writing
 */
static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    // the same file name is used both for warm-start loading and for the final save
    String modelName = "model_app3";
    //        double featureSamplingRate = config.getDouble("train.featureSamplingRate");
    //        double dataSamplingRate = config.getDouble("train.dataSamplingRate");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
    // the test set is loaded only when test progress is shown; the early-stopping code below relies on it
    MultiLabelClfDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = loadData(config, config.getString("input.testData"));
    }
    int numClasses = dataSet.getNumClasses();
    logger.info("number of class = " + numClasses);
    IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet).learningRate(learningRate).minDataPerLeaf(minDataPerLeaf).numLeaves(numLeaves).numSplitIntervals(config.getInt("train.numSplitIntervals")).usePrior(config.getBoolean("train.usePrior")).build();
    IMLGradientBoosting boosting;
    if (config.getBoolean("train.warmStart")) {
        boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    } else {
        boosting = new IMLGradientBoosting(numClasses);
    }
    logger.info("During training, the performance is reported using Hamming loss optimal predictor");
    logger.info("initialing trainer");
    IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
    boolean earlyStop = config.getBoolean("train.earlyStop");
    // one EarlyStopper and one Terminator per label, indexed by label id
    List<EarlyStopper> earlyStoppers = new ArrayList<>();
    List<Terminator> terminators = new ArrayList<>();
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
            earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
            earlyStoppers.add(earlyStopper);
        }
        for (int l = 0; l < numClasses; l++) {
            Terminator terminator = new Terminator();
            // the terminator receives one value per progress check rather than one per boosting
            // iteration, so its minimum iterations are divided by the progress interval
            terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience")).setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval")).setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange")).setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange")).setOperation(Terminator.Operation.OR);
            terminators.add(terminator);
        }
    }
    logger.info("trainer initialized");
    int numLabelsLeftToTrain = numClasses;
    int progressInterval = config.getInt("train.showProgress.interval");
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        trainer.iterate();
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance");
            logger.info(new MLMeasures(boosting, dataSet).toString());
        }
        // NOTE: early stopping is evaluated only inside this block, so it has no effect unless
        // train.showTestProgress is also true
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test set performance");
            logger.info(new MLMeasures(boosting, testSet).toString());
            if (earlyStop) {
                for (int l = 0; l < numClasses; l++) {
                    EarlyStopper earlyStopper = earlyStoppers.get(l);
                    Terminator terminator = terminators.get(l);
                    // labels already marked as stopped are no longer monitored or counted
                    if (!trainer.getShouldStop()[l]) {
                        double kl = KL(boosting, testSet, l);
                        earlyStopper.add(i, kl);
                        terminator.add(kl);
                        if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                            logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
                            logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                            trainer.setShouldStop(l);
                            numLabelsLeftToTrain -= 1;
                            logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                        }
                    }
                }
            }
        }
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    File serializedModel = new File(output, modelName);
    // NOTE: this saves the model as of the last iteration, not each label's best iteration logged above
    //todo pick best models
    boosting.serialize(serializedModel);
    logger.info(stopWatch.toString());
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            logger.info("----------------------------------------------------");
            logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
            // "- 1" presumably excludes an initial/prior regressor per label — TODO confirm
            logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
        }
    }
    // always true; kept as a switch for disabling the top-feature report
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        int limit = config.getInt("report.topFeatures.limit");
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit)).collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(output, file), topFeaturesList);
        // human-readable variant: one section per label with comma-separated features
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
Also used : IntStream(java.util.stream.IntStream) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb) ArrayList(java.util.ArrayList) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) FileHandler(java.util.logging.FileHandler) FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) edu.neu.ccs.pyramid.eval(edu.neu.ccs.pyramid.eval) Collection(java.util.Collection) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FileWriter(java.io.FileWriter) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Progress(edu.neu.ccs.pyramid.util.Progress) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) Feature(edu.neu.ccs.pyramid.feature.Feature) MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) SetUtil(edu.neu.ccs.pyramid.util.SetUtil) ArrayList(java.util.ArrayList) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) 
Terminator(edu.neu.ccs.pyramid.optimization.Terminator) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) Feature(edu.neu.ccs.pyramid.feature.Feature) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 8 with EarlyStopper

use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

In class CBMGB, method tune.

/**
 * Trains a CBM on {@code trainSet} and uses early stopping on {@code validSet} to pick the number of
 * boosting iterations for the given hyper-parameters.
 *
 * <p>The monitored metric is chosen by {@code tune.targetMetric}. It is evaluated every
 * {@code tune.monitorInterval} iterations and fed to the early stopper. Training only ends when the
 * early stopper signals a stop.
 *
 * @param config configuration (reads {@code tune.targetMetric}, {@code tune.monitorInterval} and
 *     {@code predict.piThreshold}, plus whatever the helper factories read)
 * @param hyperParameters hyper-parameters to evaluate; its {@code iterations} field is overwritten
 *     with the best iteration found
 * @param trainSet training data
 * @param validSet validation data for the monitored metric
 * @return the updated hyper-parameters together with the best monitored value
 * @throws IllegalArgumentException if {@code tune.targetMetric} is not a supported target
 */
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    final String supportedTargets = "instance_set_accuracy, instance_f1, instance_hamming_loss or instance_log_likelihood";
    List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    // NOTE(review): the stopper's goal must be MINIMIZE for instance_hamming_loss and MAXIMIZE for
    // the other targets; loadNewEarlyStopper(config) is defined elsewhere — confirm it does this.
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    GBCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    optimizer.initialize();
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    // Pick the predictor whose decision rule matches the target metric; this also validates the
    // target before any training iteration runs.
    switch(predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        case "instance_log_likelihood":
            // acc predictor seems to be the best match for log likelihood
            AccPredictor accPredictor1 = new AccPredictor(cbm);
            accPredictor1.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor1;
            break;
        default:
            throw new IllegalArgumentException("predictTarget should be " + supportedTargets + ", but got " + predictTarget);
    }
    int interval = config.getInt("tune.monitorInterval");
    // No upper bound on iterations: the loop only exits when the early stopper fires.
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            switch(predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                case "instance_log_likelihood":
                    // scored on the CBM directly; labels unseen in training are passed along
                    earlyStopper.add(iter, LogLikelihood.averageLogLikelihood(cbm, validSet, unobservedLabels));
                    break;
                default:
                    // unreachable: the target was already validated above
                    throw new IllegalArgumentException("predictTarget should be " + supportedTargets + ", but got " + predictTarget);
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Example 9 with EarlyStopper

use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

In class BRLREN, method train.

/**
 * Trains a CBM with the elastic-net CBM optimizer and early stopping on validation set accuracy.
 * It then writes the best model, the training-set label support and a top-feature report.
 *
 * <p>After every iteration the model is scored on {@code validSet} with an AccPredictor. Whenever
 * the score is reported as the best so far, the current model is serialized to
 * {@code <output.dir>/model_predictions/<output.modelFolder>/models/classifier}. That file is
 * re-loaded after training, so the top-feature report describes the best checkpoint rather than
 * the last one.
 *
 * @param config configuration (reads {@code output.dir}, {@code output.modelFolder},
 *     {@code train.randomInitialize} and {@code predict.piThreshold}, plus whatever helpers read)
 * @param hyperParameters hyper-parameters passed to model and optimizer construction
 * @param trainSet training data
 * @param validSet validation data monitored by the early stopper
 * @param logger destination for progress messages
 * @throws Exception propagated from e.g. (de)serialization or report writing
 */
private static void train(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet, Logger logger) throws Exception {
    // Resolve output locations once and up front, so a missing key fails before the (potentially
    // long) training loop instead of after it.
    String output = config.getString("output.dir");
    String modelFolder = config.getString("output.modelFolder");
    List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
    if (!unobservedLabels.isEmpty()) {
        logger.info("The following labels do not actually appear in the training set and therefore cannot be learned:");
        logger.info(ListUtil.toSimpleString(unobservedLabels));
        FileUtils.writeStringToFile(Paths.get(output, "model_predictions", modelFolder, "analysis", "unobserved_labels.txt").toFile(), ListUtil.toSimpleString(unobservedLabels));
    }
    EarlyStopper earlyStopper = loadNewEarlyStopper();
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    CBM cbm = newCBM(config, trainSet, hyperParameters, logger);
    ENCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    logger.info("Initializing the model");
    if (config.getBoolean("train.randomInitialize")) {
        optimizer.randInitialize();
    } else {
        optimizer.initialize();
    }
    logger.info("Initialization done");
    AccPredictor accPredictor = new AccPredictor(cbm);
    accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    // No upper bound on iterations: the loop only exits when the early stopper fires. The model is
    // evaluated after every iteration.
    for (int iter = 1; ; iter++) {
        logger.info("Training progress: iteration " + iter);
        optimizer.iterate();
        MLMeasures validMeasures = new MLMeasures(accPredictor, validSet);
        if (VERBOSE) {
            logger.info("validation performance");
            logger.info(validMeasures.toString());
        }
        earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
        // Checkpoint whenever this iteration is the best so far; re-loaded after training.
        if (earlyStopper.getBestIteration() == iter) {
            Serialization.serialize(cbm, Paths.get(output, "model_predictions", modelFolder, "models", "classifier"));
        }
        if (earlyStopper.shouldStop()) {
            if (VERBOSE) {
                logger.info("Early Stopper: the training should stop now!");
                logger.info("Early Stopper: best iteration found = " + earlyStopper.getBestIteration());
                logger.info("Early Stopper: best validation performance = " + earlyStopper.getBestValue());
            }
            break;
        }
    }
    logger.info("training done!");
    logger.info("time spent on training = " + stopWatch);
    List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
    Serialization.serialize(support, Paths.get(output, "model_predictions", modelFolder, "models", "support"));
    CBM bestModel = (CBM) Serialization.deserialize(Paths.get(output, "model_predictions", modelFolder, "models", "classifier"));
    File analysisFolder = Paths.get(output, "model_predictions", modelFolder, "analysis").toFile();
    analysisFolder.mkdirs();
    logger.info("start writing top features");
    List<TopFeatures> topFeaturesList = IntStream.range(0, bestModel.getNumClasses()).mapToObj(k -> topFeatures(bestModel, trainSet.getFeatureList(), trainSet.getLabelTranslator(), k, 100)).collect(Collectors.toList());
    ObjectMapper mapper = new ObjectMapper();
    String file = "top_features.json";
    mapper.writeValue(new File(analysisFolder, file), topFeaturesList);
    // Human-readable variant: one section per label with comma-separated features.
    StringBuilder sb = new StringBuilder();
    for (int l = 0; l < bestModel.getNumClasses(); l++) {
        sb.append("-------------------------").append("\n");
        sb.append(bestModel.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
        for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
            sb.append(feature.simpleString()).append(", ");
        }
        sb.append("\n");
    }
    FileUtils.writeStringToFile(new File(analysisFolder, "top_features.txt"), sb.toString());
    logger.info("finish writing top features");
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) SimpleFormatter(java.util.logging.SimpleFormatter) FileHandler(java.util.logging.FileHandler) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) MAP(edu.neu.ccs.pyramid.eval.MAP) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) IMLGBInspector(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBInspector) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) LogisticRegressionInspector(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegressionInspector) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) Feature(edu.neu.ccs.pyramid.feature.Feature) StopWatch(org.apache.commons.lang3.time.StopWatch) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 10 with EarlyStopper

use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

In class BRLREN, method loadNewEarlyStopper.

/**
 * Builds a fresh early stopper with fixed settings: goal {@code MAXIMIZE}, patience 5 and minimum
 * iterations 5. See EarlyStopper for the exact meaning of patience and minimum iterations.
 *
 * @return a newly created EarlyStopper instance
 */
private static EarlyStopper loadNewEarlyStopper() {
    final int minIterations = 5;
    final int patience = 5;
    final EarlyStopper stopper = new EarlyStopper(EarlyStopper.Goal.MAXIMIZE, patience);
    stopper.setMinimumIterations(minIterations);
    return stopper;
}
Also used : EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper)

Aggregations

EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)13 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)5 IOException (java.io.IOException)5 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)4 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)3 Config (edu.neu.ccs.pyramid.configuration.Config)3 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)3 Feature (edu.neu.ccs.pyramid.feature.Feature)3 TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures)3 Paths (java.nio.file.Paths)3 FileHandler (java.util.logging.FileHandler)3 Logger (java.util.logging.Logger)3 SimpleFormatter (java.util.logging.SimpleFormatter)3 Collectors (java.util.stream.Collectors)3 IntStream (java.util.stream.IntStream)3 FileUtils (org.apache.commons.io.FileUtils)3 StopWatch (org.apache.commons.lang3.time.StopWatch)3 JsonEncoding (com.fasterxml.jackson.core.JsonEncoding)2 JsonFactory (com.fasterxml.jackson.core.JsonFactory)2 JsonGenerator (com.fasterxml.jackson.core.JsonGenerator)2