
Example 11 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

The example is from the class RerankerTrainer, method train:

public Reranker train(RegDataSet regDataSet, double[] instanceWeights, MultiLabelClassifier.ClassProbEstimator classProbEstimator, PredictionFeatureExtractor predictionFeatureExtractor, LabelCalibrator labelCalibrator, RegDataSet validation) {
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(numLeaves).setMinDataPerLeaf(minDataPerLeaf).setMonotonicityType(monotonicityType);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, regDataSet, regTreeFactory, instanceWeights, regDataSet.getLabels());
    if (!monotonicityType.equals("none")) {
        int[][] mono = new int[1][regDataSet.getNumFeatures()];
        mono[0] = predictionFeatureExtractor.featureMonotonicity();
        optimizer.setMonotonicity(mono);
    }
    optimizer.setShrinkage(shrinkage);
    optimizer.initialize();
    // minimize the validation MSE, with a patience of 5 evaluations
    EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
    LSBoost bestModel = null;
    for (int i = 1; i <= maxIter; i++) {
        optimizer.iterate();
        // evaluate on the held-out validation set every 10 iterations and at the final iteration
        if (i % 10 == 0 || i == maxIter) {
            double mse = MSE.mse(lsBoost, validation);
            earlyStopper.add(i, mse);
            // keep a deep copy of the model from the best validation iteration seen so far
            if (earlyStopper.getBestIteration() == i) {
                try {
                    bestModel = (LSBoost) Serialization.deepCopy(lsBoost);
                } catch (IOException | ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
            if (earlyStopper.shouldStop()) {
                break;
            }
        }
    }
    // System.out.println("best iteration = "+earlyStopper.getBestIteration());
    return new Reranker(bestModel, classProbEstimator, numCandidates, predictionFeatureExtractor, labelCalibrator);
}
Also used: RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig), LSBoostOptimizer (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoostOptimizer), RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory), EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper), IOException (java.io.IOException), LSBoost (edu.neu.ccs.pyramid.regression.least_squares_boost.LSBoost)
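
Reduced to its core, the early-stopping pattern in Example 11 looks roughly like the sketch below. This is a minimal illustration only, using just the EarlyStopper calls visible above (the Goal enum, add, getBestIteration, shouldStop); trainOneIteration and evaluateOnValidation are hypothetical stand-ins for the model-specific steps, not pyramid APIs.

import edu.neu.ccs.pyramid.optimization.EarlyStopper;

public class EarlyStoppingSketch {

    // minimal sketch of the pattern in Example 11: minimize a validation metric,
    // check it periodically, and remember which iteration was best
    public static int trainWithEarlyStopping(int maxIter) {
        // minimize the validation loss, with a patience of 5 evaluations
        EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, 5);
        for (int i = 1; i <= maxIter; i++) {
            trainOneIteration();
            // evaluate every 10 iterations and at the last iteration
            if (i % 10 == 0 || i == maxIter) {
                double validationLoss = evaluateOnValidation();
                earlyStopper.add(i, validationLoss);
                if (earlyStopper.getBestIteration() == i) {
                    // snapshot the current model here (Example 11 deep-copies the LSBoost)
                }
                if (earlyStopper.shouldStop()) {
                    break;
                }
            }
        }
        return earlyStopper.getBestIteration();
    }

    // hypothetical stand-ins for the model-specific steps
    private static void trainOneIteration() {
    }

    private static double evaluateOnValidation() {
        return 0.0;
    }
}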

Example 12 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

The example is from the class BRGB, method train:

static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    int randomSeed = config.getInt("train.randomSeed");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet allTrainData = loadData(config, config.getString("input.trainData"));
    double[] instanceWeights = new double[allTrainData.getNumDataPoints()];
    Arrays.fill(instanceWeights, 1.0);
    if (config.getBoolean("train.useInstanceWeights")) {
        instanceWeights = loadInstanceWeights(config);
    }
    MultiLabelClfDataSet trainSetForEval = minibatch(allTrainData, instanceWeights, config.getInt("train.showProgress.sampleSize"), 0 + randomSeed).getFirst();
    MultiLabelClfDataSet validSet = loadData(config, config.getString("input.validData"));
    List<MultiLabel> support = DataSetUtil.gatherMultiLabels(allTrainData);
    Serialization.serialize(support, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "support"));
    int numClasses = allTrainData.getNumClasses();
    logger.info("number of class = " + numClasses);
    IMLGradientBoosting boosting;
    List<EarlyStopper> earlyStoppers;
    List<Terminator> terminators;
    boolean[] shouldStop;
    int numLabelsLeftToTrain;
    int startIter;
    List<Pair<Integer, Double>> trainingTime;
    List<Pair<Integer, Double>> accuracy;
    double startTime = 0;
    boolean earlyStop = config.getBoolean("train.earlyStop");
    CheckPoint checkPoint;
    if (config.getBoolean("train.warmStart")) {
        checkPoint = (CheckPoint) Serialization.deserialize(Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "checkpoint"));
        boosting = checkPoint.boosting;
        earlyStoppers = checkPoint.earlyStoppers;
        terminators = checkPoint.terminators;
        shouldStop = checkPoint.shouldStop;
        numLabelsLeftToTrain = checkPoint.numLabelsLeftToTrain;
        startIter = checkPoint.lastIter + 1;
        trainingTime = checkPoint.trainingTime;
        accuracy = checkPoint.accuracy;
        startTime = checkPoint.trainingTime.get(trainingTime.size() - 1).getSecond();
    } else {
        boosting = new IMLGradientBoosting(numClasses);
        earlyStoppers = new ArrayList<>();
        terminators = new ArrayList<>();
        trainingTime = new ArrayList<>();
        accuracy = new ArrayList<>();
        if (earlyStop) {
            // one EarlyStopper per label: each label's booster is monitored and stopped independently
            for (int l = 0; l < numClasses; l++) {
                EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
                earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
                earlyStoppers.add(earlyStopper);
            }
            // a Terminator per label adds a plateau-based stopping criterion on top of the EarlyStopper
            for (int l = 0; l < numClasses; l++) {
                Terminator terminator = new Terminator();
                terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience"))
                        .setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval"))
                        .setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange"))
                        .setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange"))
                        .setOperation(Terminator.Operation.OR);
                terminators.add(terminator);
            }
        }
        shouldStop = new boolean[allTrainData.getNumClasses()];
        numLabelsLeftToTrain = numClasses;
        checkPoint = new CheckPoint();
        checkPoint.boosting = boosting;
        checkPoint.earlyStoppers = earlyStoppers;
        checkPoint.terminators = terminators;
        checkPoint.shouldStop = shouldStop;
        // numLabelsLeftToTrain is a primitive copy, not a reference, so the checkpoint's field must be updated explicitly
        checkPoint.numLabelsLeftToTrain = numLabelsLeftToTrain;
        checkPoint.lastIter = 0;
        checkPoint.trainingTime = trainingTime;
        checkPoint.accuracy = accuracy;
        startIter = 1;
    }
    logger.info("During training, the performance is reported using Hamming loss optimal predictor. The performance is computed approximately with " + config.getInt("train.showProgress.sampleSize") + " instances.");
    int progressInterval = config.getInt("train.showProgress.interval");
    int interval = config.getInt("train.fullScanInterval");
    int minibatchLifeSpan = config.getInt("train.minibatchLifeSpan");
    int numActiveFeatures = config.getInt("train.numActiveFeatures");
    int numofLabels = allTrainData.getNumClasses();
    List<Integer>[] activeFeaturesLists = new ArrayList[numofLabels];
    for (int labelnum = 0; labelnum < numofLabels; labelnum++) {
        activeFeaturesLists[labelnum] = new ArrayList<>();
    }
    MultiLabelClfDataSet trainBatch = null;
    IMLGBTrainer trainer = null;
    StopWatch timeWatch = new StopWatch();
    timeWatch.start();
    for (int i = startIter; i <= numIterations; i++) {
        logger.info("iteration " + i);
        // resample a fresh minibatch every minibatchLifeSpan iterations and rebuild the trainer on it
        if (i % minibatchLifeSpan == 1 || i == startIter) {
            Pair<MultiLabelClfDataSet, double[]> sampled = minibatch(allTrainData, instanceWeights, config.getInt("train.batchSize"), i + randomSeed);
            trainBatch = sampled.getFirst();
            IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(trainBatch)
                    .learningRate(learningRate)
                    .minDataPerLeaf(minDataPerLeaf)
                    .numLeaves(numLeaves)
                    .numSplitIntervals(config.getInt("train.numSplitIntervals"))
                    .usePrior(config.getBoolean("train.usePrior"))
                    .numActiveFeatures(numActiveFeatures)
                    .build();
            trainer = new IMLGBTrainer(imlgbConfig, boosting, shouldStop);
            trainer.setInstanceWeights(sampled.getSecond());
        }
        if (i % interval == 1) {
            trainer.iterate(activeFeaturesLists, true);
        } else {
            trainer.iterate(activeFeaturesLists, false);
        }
        checkPoint.lastIter += 1;
        // periodically check each label's validation KL divergence; converged labels are rolled back and frozen
        if (earlyStop && (i % progressInterval == 0 || i == numIterations)) {
            for (int l = 0; l < numClasses; l++) {
                EarlyStopper earlyStopper = earlyStoppers.get(l);
                Terminator terminator = terminators.get(l);
                if (!shouldStop[l]) {
                    double kl = KL(boosting, validSet, l);
                    earlyStopper.add(i, kl);
                    terminator.add(kl);
                    if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                        logger.info("training for label " + l + " (" + allTrainData.getLabelTranslator().toExtLabel(l) + ") should stop now");
                        logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                        if (i != earlyStopper.getBestIteration()) {
                            boosting.cutTail(l, earlyStopper.getBestIteration());
                            logger.info("roll back the model for this label to iteration " + earlyStopper.getBestIteration());
                        }
                        shouldStop[l] = true;
                        numLabelsLeftToTrain -= 1;
                        checkPoint.numLabelsLeftToTrain = numLabelsLeftToTrain;
                        logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                    }
                }
            }
        }
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance (computed approximately with Hamming loss predictor on " + config.getInt("train.showProgress.sampleSize") + " instances).");
            logger.info(new MLMeasures(boosting, trainSetForEval).toString());
        }
        if (config.getBoolean("train.showValidProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("validation set performance (computed approximately with Hamming loss predictor)");
            MLMeasures validPerformance = new MLMeasures(boosting, validSet);
            logger.info(validPerformance.toString());
            accuracy.add(new Pair<>(i, validPerformance.getInstanceAverage().getF1()));
        }
        trainingTime.add(new Pair<>(i, startTime + timeWatch.getTime() / 1000.0));
        Serialization.serialize(checkPoint, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "checkpoint"));
        Serialization.serialize(boosting, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "classifier"));
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    logger.info(stopWatch.toString());
    File analysisFolder = Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "analysis").toFile();
    // export the per-label decision rules extracted from the boosting model
    if (true) {
        ObjectMapper objectMapper = new ObjectMapper();
        List<LabelModel> labelModels = IMLGBInspector.getAllRules(boosting);
        new File(analysisFolder, "decision_rules").mkdirs();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            objectMapper.writeValue(Paths.get(analysisFolder.toString(), "decision_rules", l + ".json").toFile(), labelModels.get(l));
        }
    }
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, Integer.MAX_VALUE)).collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(analysisFolder, file), topFeaturesList);
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(allTrainData.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        FileUtils.writeStringToFile(new File(analysisFolder, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
Also used: IntStream (java.util.stream.IntStream), Visualizer (edu.neu.ccs.pyramid.visualization.Visualizer), edu.neu.ccs.pyramid.util, java.util, JsonGenerator (com.fasterxml.jackson.core.JsonGenerator), SimpleFormatter (java.util.logging.SimpleFormatter), PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor), edu.neu.ccs.pyramid.multilabel_classification.imlgb, Terminator (edu.neu.ccs.pyramid.optimization.Terminator), FileHandler (java.util.logging.FileHandler), JsonEncoding (com.fasterxml.jackson.core.JsonEncoding), Config (edu.neu.ccs.pyramid.configuration.Config), MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis), EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper), ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper), TunedMarginalClassifier (edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier), FMeasure (edu.neu.ccs.pyramid.eval.FMeasure), FileUtils (org.apache.commons.io.FileUtils), StopWatch (org.apache.commons.lang3.time.StopWatch), Logger (java.util.logging.Logger), Collectors (java.util.stream.Collectors), KLDivergence (edu.neu.ccs.pyramid.eval.KLDivergence), JsonFactory (com.fasterxml.jackson.core.JsonFactory), MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures), Feature (edu.neu.ccs.pyramid.feature.Feature), MacroFMeasureTuner (edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner), java.io, Paths (java.nio.file.Paths), edu.neu.ccs.pyramid.dataset, Vector (org.apache.mahout.math.Vector), TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures), EnumeratedIntegerDistribution (org.apache.commons.math3.distribution.EnumeratedIntegerDistribution)
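
The per-label early stopping in Example 12 can be condensed into the sketch below. It is an illustration only, using just the EarlyStopper calls shown above (the constructor, setMinimumIterations, add, shouldStop); the checkLabels and newStoppers helpers, and the validationLoss and shouldStop arrays, are hypothetical placeholders for BRGB's own bookkeeping (which also involves Terminator and boosting.cutTail).

import java.util.ArrayList;
import java.util.List;
import edu.neu.ccs.pyramid.optimization.EarlyStopper;

public class PerLabelEarlyStoppingSketch {

    // sketch of the BRGB pattern: one independent EarlyStopper per label,
    // each minimizing that label's validation loss (KL divergence in Example 12)
    public static boolean[] checkLabels(List<EarlyStopper> earlyStoppers, double[] validationLoss,
                                        boolean[] shouldStop, int iteration) {
        for (int l = 0; l < earlyStoppers.size(); l++) {
            if (shouldStop[l]) {
                continue;  // this label's booster is already frozen
            }
            EarlyStopper earlyStopper = earlyStoppers.get(l);
            earlyStopper.add(iteration, validationLoss[l]);
            if (earlyStopper.shouldStop()) {
                // roll the label's model back to its best iteration here (Example 12 calls boosting.cutTail)
                shouldStop[l] = true;
            }
        }
        return shouldStop;
    }

    // one stopper per label, all configured with the same patience and minimum-iteration settings
    public static List<EarlyStopper> newStoppers(int numClasses, int patience, int minIterations) {
        List<EarlyStopper> earlyStoppers = new ArrayList<>();
        for (int l = 0; l < numClasses; l++) {
            EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, patience);
            earlyStopper.setMinimumIterations(minIterations);
            earlyStoppers.add(earlyStopper);
        }
        return earlyStoppers;
    }
}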

Example 13 with EarlyStopper

Use of edu.neu.ccs.pyramid.optimization.EarlyStopper in project pyramid by cheng-li.

The example is from the class CBMEN, method tune:

private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    // the early stopper (goal, patience, and related settings) is built from the tuning configuration
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    ENCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    if (config.getBoolean("train.randomInitialize")) {
        optimizer.randInitialize();
    } else {
        optimizer.initialize();
    }
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch(predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        case "label_map":
            AccPredictor accPredictor2 = new AccPredictor(cbm);
            accPredictor2.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor2;
            break;
        default:
            throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or label_map");
    }
    int interval = config.getInt("tune.monitorInterval");
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        // evaluate on the validation set every `interval` iterations
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            // feed the early stopper the validation metric that matches the chosen tuning target
            switch(predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                case "label_map":
                    List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
                    double map = MAP.mapBySupport(cbm, validSet, support);
                    earlyStopper.add(iter, map);
                    break;
                default:
                    throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or label_map");
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    // the best iteration found by the early stopper becomes the tuned number of training iterations
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
Also used: EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper), MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier), MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)
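
Example 13 builds its stopper via loadNewEarlyStopper(config) and then feeds it whichever validation metric matches tune.targetMetric. The sketch below illustrates that metric-dependent wiring; it is an assumption-laden illustration, not the project's loadNewEarlyStopper. In particular it assumes the Goal enum offers MAXIMIZE alongside the MINIMIZE used in the other examples, and newStopperFor and report are hypothetical helpers.

import edu.neu.ccs.pyramid.optimization.EarlyStopper;

public class TuneTargetSketch {

    // pick the optimization direction from the tuning target:
    // accuracy, F1 and MAP are maximized, Hamming loss is minimized
    public static EarlyStopper newStopperFor(String predictTarget, int patience) {
        switch (predictTarget) {
            case "instance_set_accuracy":
            case "instance_f1":
            case "label_map":
                // assumption: Goal.MAXIMIZE exists alongside the Goal.MINIMIZE used in the other examples
                return new EarlyStopper(EarlyStopper.Goal.MAXIMIZE, patience);
            case "instance_hamming_loss":
                return new EarlyStopper(EarlyStopper.Goal.MINIMIZE, patience);
            default:
                throw new IllegalArgumentException(
                        "predictTarget should be instance_set_accuracy, instance_f1, instance_hamming_loss or label_map");
        }
    }

    // after tuning, the best iteration and best value feed back into the hyper-parameters,
    // as Example 13 does with hyperParameters.iterations and tuneResult.performance
    public static void report(EarlyStopper earlyStopper) {
        System.out.println("best iteration = " + earlyStopper.getBestIteration());
        System.out.println("best validation value = " + earlyStopper.getBestValue());
    }
}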

Aggregations

EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper): 13 usages
MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures): 5 usages
IOException (java.io.IOException): 5 usages
MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier): 4 usages
ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 3 usages
Config (edu.neu.ccs.pyramid.configuration.Config): 3 usages
edu.neu.ccs.pyramid.dataset: 3 usages
Feature (edu.neu.ccs.pyramid.feature.Feature): 3 usages
TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures): 3 usages
Paths (java.nio.file.Paths): 3 usages
FileHandler (java.util.logging.FileHandler): 3 usages
Logger (java.util.logging.Logger): 3 usages
SimpleFormatter (java.util.logging.SimpleFormatter): 3 usages
Collectors (java.util.stream.Collectors): 3 usages
IntStream (java.util.stream.IntStream): 3 usages
FileUtils (org.apache.commons.io.FileUtils): 3 usages
StopWatch (org.apache.commons.lang3.time.StopWatch): 3 usages
JsonEncoding (com.fasterxml.jackson.core.JsonEncoding): 2 usages
JsonFactory (com.fasterxml.jackson.core.JsonFactory): 2 usages
JsonGenerator (com.fasterxml.jackson.core.JsonGenerator): 2 usages