Search in sources :

Example 1 with Terminator

Use of edu.neu.ccs.pyramid.optimization.Terminator in project pyramid by cheng-li.

The following example is from the class App2, method train.

/**
 * Trains an IMLGradientBoosting multi-label classifier, optionally resuming
 * from a previously serialized model (warm start) and stopping individual
 * labels early, then serializes the model and writes per-label top-feature
 * reports (JSON and plain text).
 *
 * <p>Note: per-label early stopping is only evaluated inside the
 * "train.showTestProgress" branch, so early stopping is effectively disabled
 * unless test-progress reporting is turned on.
 *
 * @param config configuration source; reads keys under "input.*", "train.*",
 *               "output.*" and "report.*"
 * @param logger destination for progress and performance messages
 * @throws Exception on I/O or (de)serialization failure
 */
static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    // NOTE(review): hard-coded name says "model_app3" even though this train()
    // belongs to App2 — confirm the two apps intentionally share this file name.
    String modelName = "model_app3";
    //        double featureSamplingRate = config.getDouble("train.featureSamplingRate");
    //        double dataSamplingRate = config.getDouble("train.dataSamplingRate");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
    // testSet stays null unless test-progress reporting is enabled; every use
    // below is guarded by the same flag.
    MultiLabelClfDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = loadData(config, config.getString("input.testData"));
    }
    int numClasses = dataSet.getNumClasses();
    logger.info("number of class = " + numClasses);
    IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet).learningRate(learningRate).minDataPerLeaf(minDataPerLeaf).numLeaves(numLeaves).numSplitIntervals(config.getInt("train.numSplitIntervals")).usePrior(config.getBoolean("train.usePrior")).build();
    IMLGradientBoosting boosting;
    if (config.getBoolean("train.warmStart")) {
        // Resume training from the model previously serialized by this method.
        boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    } else {
        boosting = new IMLGradientBoosting(numClasses);
    }
    logger.info("During training, the performance is reported using Hamming loss optimal predictor");
    logger.info("initialing trainer");
    IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
    boolean earlyStop = config.getBoolean("train.earlyStop");
    // One EarlyStopper and one Terminator per label; a label stops when EITHER
    // of its two monitors fires (see the OR in the loop below).
    List<EarlyStopper> earlyStoppers = new ArrayList<>();
    List<Terminator> terminators = new ArrayList<>();
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
            earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
            earlyStoppers.add(earlyStopper);
        }
        for (int l = 0; l < numClasses; l++) {
            Terminator terminator = new Terminator();
            // minIterations is divided by the progress interval because the
            // Terminator only sees one value per progress check, not per iteration.
            terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience")).setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval")).setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange")).setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange")).setOperation(Terminator.Operation.OR);
            terminators.add(terminator);
        }
    }
    logger.info("trainer initialized");
    int numLabelsLeftToTrain = numClasses;
    int progressInterval = config.getInt("train.showProgress.interval");
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        trainer.iterate();
        // Progress (and early stopping) is checked every progressInterval
        // iterations and always on the final iteration.
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance");
            logger.info(new MLMeasures(boosting, dataSet).toString());
        }
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test set performance");
            logger.info(new MLMeasures(boosting, testSet).toString());
            if (earlyStop) {
                for (int l = 0; l < numClasses; l++) {
                    EarlyStopper earlyStopper = earlyStoppers.get(l);
                    Terminator terminator = terminators.get(l);
                    // Skip labels whose training has already been stopped.
                    if (!trainer.getShouldStop()[l]) {
                        // KL(...) is a helper defined elsewhere in this class —
                        // presumably the KL divergence on the test set for label l;
                        // lower is better (Goal.MINIMIZE above).
                        double kl = KL(boosting, testSet, l);
                        earlyStopper.add(i, kl);
                        terminator.add(kl);
                        if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                            logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
                            logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                            trainer.setShouldStop(l);
                            numLabelsLeftToTrain -= 1;
                            logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                        }
                    }
                }
            }
        }
        // All labels stopped early: no point running further iterations.
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    File serializedModel = new File(output, modelName);
    //todo pick best models
    // NOTE(review): unlike BRGB.train, this variant does NOT roll each label
    // back to its best iteration before serializing — see the todo above.
    boosting.serialize(serializedModel);
    logger.info(stopWatch.toString());
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            logger.info("----------------------------------------------------");
            logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
            // The "- 1" presumably excludes a prior/constant regressor from the
            // reported model size — TODO confirm against IMLGradientBoosting.
            logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
        }
    }
    // Always-on flag kept as a local so the report step is easy to disable.
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        int limit = config.getInt("report.topFeatures.limit");
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit)).collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(output, file), topFeaturesList);
        // Human-readable companion report: one section per label, features
        // comma-separated on one line.
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        // NOTE(review): charset-less writeStringToFile uses the platform default
        // encoding and is deprecated in commons-io 2.5+; consider passing UTF-8.
        FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
Also used : IntStream(java.util.stream.IntStream) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb) ArrayList(java.util.ArrayList) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) FileHandler(java.util.logging.FileHandler) FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) edu.neu.ccs.pyramid.eval(edu.neu.ccs.pyramid.eval) Collection(java.util.Collection) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FileWriter(java.io.FileWriter) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Progress(edu.neu.ccs.pyramid.util.Progress) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) Feature(edu.neu.ccs.pyramid.feature.Feature) MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) SetUtil(edu.neu.ccs.pyramid.util.SetUtil) ArrayList(java.util.ArrayList) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) 
Terminator(edu.neu.ccs.pyramid.optimization.Terminator) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) Feature(edu.neu.ccs.pyramid.feature.Feature) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 2 with Terminator

Use of edu.neu.ccs.pyramid.optimization.Terminator in project pyramid by cheng-li.

The following example is from the class BRGB, method train.

/**
 * Trains a binary-relevance gradient boosting (IMLGradientBoosting) model on
 * weighted minibatches, with per-label early stopping, rollback of each label
 * to its best iteration, and full checkpointing so that training can resume
 * via warm start. Also writes decision rules and top-feature reports.
 *
 * @param config configuration source; reads keys under "input.*", "train.*"
 *               and "output.*"
 * @param logger destination for progress and performance messages
 * @throws Exception on I/O or (de)serialization failure
 */
static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    int randomSeed = config.getInt("train.randomSeed");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet allTrainData = loadData(config, config.getString("input.trainData"));
    // Instance weights default to uniform 1.0 unless explicitly provided.
    double[] instanceWeights = new double[allTrainData.getNumDataPoints()];
    Arrays.fill(instanceWeights, 1.0);
    if (config.getBoolean("train.useInstanceWeights")) {
        instanceWeights = loadInstanceWeights(config);
    }
    // A fixed subsample of the training data used only for (approximate)
    // progress reporting; seeded with 0 + randomSeed so it is reproducible.
    MultiLabelClfDataSet trainSetForEval = minibatch(allTrainData, instanceWeights, config.getInt("train.showProgress.sampleSize"), 0 + randomSeed).getFirst();
    MultiLabelClfDataSet validSet = loadData(config, config.getString("input.validData"));
    // The label-combination support set is persisted alongside the model so
    // prediction can be restricted to observed combinations.
    List<MultiLabel> support = DataSetUtil.gatherMultiLabels(allTrainData);
    Serialization.serialize(support, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "support"));
    int numClasses = allTrainData.getNumClasses();
    logger.info("number of class = " + numClasses);
    IMLGradientBoosting boosting;
    List<EarlyStopper> earlyStoppers;
    List<Terminator> terminators;
    boolean[] shouldStop;
    int numLabelsLeftToTrain;
    int startIter;
    List<Pair<Integer, Double>> trainingTime;
    List<Pair<Integer, Double>> accuracy;
    double startTime = 0;
    boolean earlyStop = config.getBoolean("train.earlyStop");
    CheckPoint checkPoint;
    if (config.getBoolean("train.warmStart")) {
        // Warm start: restore the entire training state (model, early-stop
        // monitors, timing and accuracy history) from the last checkpoint.
        checkPoint = (CheckPoint) Serialization.deserialize(Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "checkpoint"));
        boosting = checkPoint.boosting;
        earlyStoppers = checkPoint.earlyStoppers;
        terminators = checkPoint.terminators;
        shouldStop = checkPoint.shouldStop;
        numLabelsLeftToTrain = checkPoint.numLabelsLeftToTrain;
        startIter = checkPoint.lastIter + 1;
        trainingTime = checkPoint.trainingTime;
        accuracy = checkPoint.accuracy;
        // Resume the wall-clock offset from where the previous run stopped.
        startTime = checkPoint.trainingTime.get(trainingTime.size() - 1).getSecond();
    } else {
        boosting = new IMLGradientBoosting(numClasses);
        earlyStoppers = new ArrayList<>();
        terminators = new ArrayList<>();
        trainingTime = new ArrayList<>();
        accuracy = new ArrayList<>();
        if (earlyStop) {
            // One EarlyStopper and one Terminator per label; a label stops when
            // EITHER monitor fires (see the OR in the training loop).
            for (int l = 0; l < numClasses; l++) {
                EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
                earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
                earlyStoppers.add(earlyStopper);
            }
            for (int l = 0; l < numClasses; l++) {
                Terminator terminator = new Terminator();
                // minIterations is divided by the progress interval because the
                // Terminator only sees one value per progress check.
                terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience")).setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval")).setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange")).setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange")).setOperation(Terminator.Operation.OR);
                terminators.add(terminator);
            }
        }
        shouldStop = new boolean[allTrainData.getNumClasses()];
        numLabelsLeftToTrain = numClasses;
        // The checkpoint shares references with the live objects above, so
        // mutating them keeps the checkpoint current automatically —
        // except for the primitive fields noted below.
        checkPoint = new CheckPoint();
        checkPoint.boosting = boosting;
        checkPoint.earlyStoppers = earlyStoppers;
        checkPoint.terminators = terminators;
        checkPoint.shouldStop = shouldStop;
        // this is not a pointer, has to be updated
        checkPoint.numLabelsLeftToTrain = numLabelsLeftToTrain;
        checkPoint.lastIter = 0;
        checkPoint.trainingTime = trainingTime;
        checkPoint.accuracy = accuracy;
        startIter = 1;
    }
    logger.info("During training, the performance is reported using Hamming loss optimal predictor. The performance is computed approximately with " + config.getInt("train.showProgress.sampleSize") + " instances.");
    int progressInterval = config.getInt("train.showProgress.interval");
    // Every "interval" iterations the trainer rescans all features; otherwise
    // it reuses the cached per-label active-feature lists.
    int interval = config.getInt("train.fullScanInterval");
    // A minibatch (and its trainer) is reused for this many iterations before
    // a new one is sampled.
    int minibatchLifeSpan = config.getInt("train.minibatchLifeSpan");
    int numActiveFeatures = config.getInt("train.numActiveFeatures");
    int numofLabels = allTrainData.getNumClasses();
    List<Integer>[] activeFeaturesLists = new ArrayList[numofLabels];
    for (int labelnum = 0; labelnum < numofLabels; labelnum++) {
        activeFeaturesLists[labelnum] = new ArrayList<>();
    }
    MultiLabelClfDataSet trainBatch = null;
    IMLGBTrainer trainer = null;
    StopWatch timeWatch = new StopWatch();
    timeWatch.start();
    for (int i = startIter; i <= numIterations; i++) {
        logger.info("iteration " + i);
        // Resample a minibatch at the start of each lifespan window (the
        // "i == startIter" clause also covers warm-started runs that resume
        // mid-window).
        if (i % minibatchLifeSpan == 1 || i == startIter) {
            Pair<MultiLabelClfDataSet, double[]> sampled = minibatch(allTrainData, instanceWeights, config.getInt("train.batchSize"), i + randomSeed);
            trainBatch = sampled.getFirst();
            IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(trainBatch).learningRate(learningRate).minDataPerLeaf(minDataPerLeaf).numLeaves(numLeaves).numSplitIntervals(config.getInt("train.numSplitIntervals")).usePrior(config.getBoolean("train.usePrior")).numActiveFeatures(numActiveFeatures).build();
            // shouldStop is passed in so already-stopped labels are skipped.
            trainer = new IMLGBTrainer(imlgbConfig, boosting, shouldStop);
            trainer.setInstanceWeights(sampled.getSecond());
        }
        // Second argument presumably toggles a full feature scan vs. reusing
        // the cached active-feature lists — TODO confirm against IMLGBTrainer.
        if (i % interval == 1) {
            trainer.iterate(activeFeaturesLists, true);
        } else {
            trainer.iterate(activeFeaturesLists, false);
        }
        checkPoint.lastIter += 1;
        // Early stopping is evaluated on the validation set every
        // progressInterval iterations and on the final iteration.
        if (earlyStop && (i % progressInterval == 0 || i == numIterations)) {
            for (int l = 0; l < numClasses; l++) {
                EarlyStopper earlyStopper = earlyStoppers.get(l);
                Terminator terminator = terminators.get(l);
                if (!shouldStop[l]) {
                    // KL(...) is a helper defined elsewhere in this class —
                    // presumably the KL divergence on the validation set for
                    // label l; lower is better (Goal.MINIMIZE above).
                    double kl = KL(boosting, validSet, l);
                    earlyStopper.add(i, kl);
                    terminator.add(kl);
                    if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                        logger.info("training for label " + l + " (" + allTrainData.getLabelTranslator().toExtLabel(l) + ") should stop now");
                        logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                        // Roll this label's ensemble back to its best iteration
                        // before freezing it.
                        if (i != earlyStopper.getBestIteration()) {
                            boosting.cutTail(l, earlyStopper.getBestIteration());
                            logger.info("roll back the model for this label to iteration " + earlyStopper.getBestIteration());
                        }
                        shouldStop[l] = true;
                        numLabelsLeftToTrain -= 1;
                        // Mirror the primitive counter into the checkpoint (see
                        // the "not a pointer" note above).
                        checkPoint.numLabelsLeftToTrain = numLabelsLeftToTrain;
                        logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                    }
                }
            }
        }
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance (computed approximately with Hamming loss predictor on " + config.getInt("train.showProgress.sampleSize") + " instances).");
            logger.info(new MLMeasures(boosting, trainSetForEval).toString());
        }
        if (config.getBoolean("train.showValidProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("validation set performance (computed approximately with Hamming loss predictor)");
            MLMeasures validPerformance = new MLMeasures(boosting, validSet);
            logger.info(validPerformance.toString());
            accuracy.add(new Pair<>(i, validPerformance.getInstanceAverage().getF1()));
        }
        // Cumulative wall-clock seconds, offset by any previous run's total.
        trainingTime.add(new Pair<>(i, startTime + timeWatch.getTime() / 1000.0));
        // Checkpoint and model are persisted EVERY iteration so a crash loses
        // at most one iteration of work.
        Serialization.serialize(checkPoint, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "checkpoint"));
        Serialization.serialize(boosting, Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "models", "classifier"));
        // All labels stopped early: no point running further iterations.
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    logger.info(stopWatch.toString());
    File analysisFolder = Paths.get(output, "model_predictions", config.getString("output.modelFolder"), "analysis").toFile();
    // Always-on block kept behind if(true) so the rule export is easy to disable.
    if (true) {
        ObjectMapper objectMapper = new ObjectMapper();
        List<LabelModel> labelModels = IMLGBInspector.getAllRules(boosting);
        new File(analysisFolder, "decision_rules").mkdirs();
        // One JSON file of decision rules per label.
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            objectMapper.writeValue(Paths.get(analysisFolder.toString(), "decision_rules", l + ".json").toFile(), labelModels.get(l));
        }
    }
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        // Integer.MAX_VALUE: export ALL features ranked, not just a top-k cut.
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, Integer.MAX_VALUE)).collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(analysisFolder, file), topFeaturesList);
        // Human-readable companion report: one section per label, features
        // comma-separated on one line.
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(allTrainData.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        // NOTE(review): charset-less writeStringToFile uses the platform default
        // encoding and is deprecated in commons-io 2.5+; consider passing UTF-8.
        FileUtils.writeStringToFile(new File(analysisFolder, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
Also used : IntStream(java.util.stream.IntStream) Visualizer(edu.neu.ccs.pyramid.visualization.Visualizer) edu.neu.ccs.pyramid.util(edu.neu.ccs.pyramid.util) java.util(java.util) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) FileHandler(java.util.logging.FileHandler) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FMeasure(edu.neu.ccs.pyramid.eval.FMeasure) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) KLDivergence(edu.neu.ccs.pyramid.eval.KLDivergence) JsonFactory(com.fasterxml.jackson.core.JsonFactory) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner) java.io(java.io) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) Feature(edu.neu.ccs.pyramid.feature.Feature) 
MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) StopWatch(org.apache.commons.lang3.time.StopWatch)

Aggregations

JsonEncoding (com.fasterxml.jackson.core.JsonEncoding)2 JsonFactory (com.fasterxml.jackson.core.JsonFactory)2 JsonGenerator (com.fasterxml.jackson.core.JsonGenerator)2 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)2 Config (edu.neu.ccs.pyramid.configuration.Config)2 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)2 Feature (edu.neu.ccs.pyramid.feature.Feature)2 TopFeatures (edu.neu.ccs.pyramid.feature.TopFeatures)2 MultiLabelPredictionAnalysis (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis)2 PluginPredictor (edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor)2 edu.neu.ccs.pyramid.multilabel_classification.imlgb (edu.neu.ccs.pyramid.multilabel_classification.imlgb)2 MacroFMeasureTuner (edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner)2 TunedMarginalClassifier (edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier)2 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)2 Terminator (edu.neu.ccs.pyramid.optimization.Terminator)2 Paths (java.nio.file.Paths)2 FileHandler (java.util.logging.FileHandler)2 Logger (java.util.logging.Logger)2 SimpleFormatter (java.util.logging.SimpleFormatter)2 Collectors (java.util.stream.Collectors)2