Search in sources :

Example 6 with EvaluateTreebank

use of edu.stanford.nlp.parser.lexparser.EvaluateTreebank in project CoreNLP by stanfordnlp.

From the class PerceptronModel, method trainModel.

/**
 * Runs the perceptron training loop and leaves the resulting weights in this model.
 *
 * <p>Each iteration shuffles the training order, processes the trees in batches via
 * {@code trainBatch}, and applies the returned updates to {@code featureWeights}. When a dev
 * treebank is supplied, the model is scored on it after every iteration; that score drives
 * early stopping and (if enabled) selection of the models to average. After the loop the
 * best models are optionally averaged, rare features are optionally filtered out, and
 * {@code condenseFeatures()} is called.
 *
 * @param serializedPath base path used to name intermediate models; may be {@code null}, in
 *     which case no intermediate models are saved
 * @param tagger passed through to the dev-set evaluator
 * @param random source of randomness for shuffling the training order each iteration
 * @param binarizedTrees the training trees
 * @param transitionLists passed to {@code trainBatch} together with the trees (presumably the
 *     gold transition sequence for each tree - confirm against the caller)
 * @param devTreebank treebank to score against after each iteration; may be {@code null}, in
 *     which case no scoring, early stopping or model averaging takes place
 * @param nThreads when not equal to 1, batches are processed through a multithreaded wrapper.
 *     NOTE(review): the wrapper itself is sized from {@code op.trainOptions.trainingThreads},
 *     not from this argument - confirm the two always agree.
 * @param allowedFeatures if non-{@code null}, only features contained in this set are updated
 */
private void trainModel(String serializedPath, Tagger tagger, Random random, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, Treebank devTreebank, int nThreads, Set<String> allowedFeatures) {
    double bestScore = 0.0;
    int bestIteration = 0;
    // Holds at most averagedModels snapshots; the extra slot leaves room for the
    // add-then-poll eviction performed after each dev evaluation.
    PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
    if (op.trainOptions().averagedModels > 0) {
        bestModels = new PriorityQueue<>(op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
    }
    List<Integer> indices = Generics.newArrayList();
    for (int i = 0; i < binarizedTrees.size(); ++i) {
        indices.add(i);
    }
    Oracle oracle = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
        oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
    }
    List<Update> updates = Generics.newArrayList();
    MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper = null;
    if (nThreads != 1) {
        // The update list is shared with the worker processor, so it must be synchronized.
        updates = Collections.synchronizedList(updates);
        wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new TrainTreeProcessor(binarizedTrees, transitionLists, updates, oracle));
    }
    // Only track feature frequencies when a cutoff is actually going to be applied.
    IntCounter<String> featureFrequencies = null;
    if (op.trainOptions().featureFrequencyCutoff > 1) {
        featureFrequencies = new IntCounter<>();
    }
    for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration) {
        Timing trainingTimer = new Timing();
        int numCorrect = 0;
        int numWrong = 0;
        Collections.shuffle(indices, random);
        for (int start = 0; start < indices.size(); start += op.trainOptions.batchSize) {
            int end = Math.min(start + op.trainOptions.batchSize, indices.size());
            Triple<List<Update>, Integer, Integer> result = trainBatch(indices.subList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
            numCorrect += result.second;
            numWrong += result.third;
            for (Update update : result.first) {
                for (String feature : update.features) {
                    if (allowedFeatures != null && !allowedFeatures.contains(feature)) {
                        continue;
                    }
                    Weight weights = featureWeights.get(feature);
                    if (weights == null) {
                        weights = new Weight();
                        featureWeights.put(feature, weights);
                    }
                    // Reward the gold transition and penalize the predicted one.
                    weights.updateWeight(update.goldTransition, update.delta);
                    weights.updateWeight(update.predictedTransition, -update.delta);
                    if (featureFrequencies != null) {
                        // Count 2 when both transition ids are non-negative, otherwise 1.
                        featureFrequencies.incrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
                    }
                }
            }
            updates.clear();
        }
        trainingTimer.done("Iteration " + iteration);
        log.info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
        outputStats();
        double labelF1 = 0.0;
        if (devTreebank != null) {
            EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
            evaluator.testOnTreebank(devTreebank);
            labelF1 = evaluator.getLBScore();
            log.info("Label F1 after " + iteration + " iterations: " + labelF1);
            if (labelF1 > bestScore) {
                log.info("New best dev score (previous best " + bestScore + ")");
                bestScore = labelF1;
                bestIteration = iteration;
            } else {
                log.info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit)) {
                    log.info("Failed to improve for too long, stopping training");
                    break;
                }
            }
            log.info();
            if (bestModels != null) {
                // Snapshot the current weights, then evict the head of the queue if over capacity.
                bestModels.add(new ScoredObject<>(new PerceptronModel(this), labelF1));
                if (bestModels.size() > op.trainOptions().averagedModels) {
                    bestModels.poll();
                }
            }
        }
        if (op.trainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0) {
            // Insert the iteration number and score before the ".ser.gz" extension. If the
            // path does not carry that extension, append to the whole path rather than
            // blindly chopping characters (which could also throw for very short paths).
            String suffix = ".ser.gz";
            String baseName = serializedPath.endsWith(suffix) ? serializedPath.substring(0, serializedPath.length() - suffix.length()) : serializedPath;
            String tempName = baseName + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + suffix;
            ShiftReduceParser temp = new ShiftReduceParser(op, this);
            temp.saveModel(tempName);
        // TODO: we could save a cutoff version of the model,
        // especially if we also get a dev set number for it, but that
        // might be overkill
        }
    }
    if (wrapper != null) {
        wrapper.join();
    }
    if (bestModels != null) {
        if (bestModels.isEmpty()) {
            // Snapshots are only collected when a dev treebank is evaluated, so without one
            // there is nothing to average; keep the weights from the final iteration.
            log.info("No scored models were collected, skipping model averaging");
        } else if (op.trainOptions().cvAveragedModels && devTreebank != null) {
            // Drain the queue and reverse it so that prefixes of the list can be averaged
            // and compared (presumably best-scoring first, given the ascending comparator).
            List<ScoredObject<PerceptronModel>> models = Generics.newArrayList();
            while (bestModels.size() > 0) {
                models.add(bestModels.poll());
            }
            Collections.reverse(models);
            double bestF1 = 0.0;
            int bestSize = 0;
            for (int i = 1; i <= models.size(); ++i) {
                log.info("Testing with " + i + " models averaged together");
                // TODO: this is kind of ugly, would prefer a separate object
                averageScoredModels(models.subList(0, i));
                ShiftReduceParser temp = new ShiftReduceParser(op, this);
                EvaluateTreebank evaluator = new EvaluateTreebank(temp.getOp(), null, temp, tagger);
                evaluator.testOnTreebank(devTreebank);
                double labelF1 = evaluator.getLBScore();
                log.info("Label F1 for " + i + " models: " + labelF1);
                if (labelF1 > bestF1) {
                    bestF1 = labelF1;
                    bestSize = i;
                }
            }
            if (bestSize == 0) {
                // No prefix scored above zero; use the first model alone instead of
                // averaging an empty list.
                bestSize = 1;
            }
            averageScoredModels(models.subList(0, bestSize));
        } else {
            averageScoredModels(bestModels);
        }
    }
    // Frequency-based feature filtering is applied only after the models have been
    // averaged, so the dev scores used above were measured on unfiltered models.
    if (featureFrequencies != null) {
        filterFeatures(featureFrequencies.keysAbove(op.trainOptions().featureFrequencyCutoff));
    }
    condenseFeatures();
}
Also used : EvaluateTreebank(edu.stanford.nlp.parser.lexparser.EvaluateTreebank) ScoredObject(edu.stanford.nlp.util.ScoredObject) List(java.util.List) Pair(edu.stanford.nlp.util.Pair) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint) Timing(edu.stanford.nlp.util.Timing)

Aggregations

EvaluateTreebank (edu.stanford.nlp.parser.lexparser.EvaluateTreebank): 6, Pair (edu.stanford.nlp.util.Pair): 5, Treebank (edu.stanford.nlp.trees.Treebank): 4, FileFilter (java.io.FileFilter): 4, ArrayList (java.util.ArrayList): 4, LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser): 3, Tree (edu.stanford.nlp.trees.Tree): 2, Timing (edu.stanford.nlp.util.Timing): 2, ParserConstraint (edu.stanford.nlp.parser.common.ParserConstraint): 1, Options (edu.stanford.nlp.parser.lexparser.Options): 1, Reranker (edu.stanford.nlp.parser.lexparser.Reranker): 1, CompositeTreeTransformer (edu.stanford.nlp.trees.CompositeTreeTransformer): 1, TreeTransformer (edu.stanford.nlp.trees.TreeTransformer): 1, ScoredObject (edu.stanford.nlp.util.ScoredObject): 1, FileWriter (java.io.FileWriter): 1, List (java.util.List): 1