Example 11 with ScoredObject

use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

the class EvaluateTreebank method processResults.

public void processResults(ParserQuery pq, Tree goldTree, PrintWriter pwErr, PrintWriter pwOut, PrintWriter pwFileOut, PrintWriter pwStats, TreePrint treePrint) {
    if (pq.saidMemMessage()) {
        saidMemMessage = true;
    }
    Tree tree;
    List<? extends HasWord> sentence = pq.originalSentence();
    try {
        tree = pq.getBestParse();
    } catch (NoSuchParseException e) {
        tree = null;
    }
    List<ScoredObject<Tree>> kbestPCFGTrees = null;
    if (tree != null && kbestPCFG > 0) {
        kbestPCFGTrees = pq.getKBestPCFGParses(kbestPCFG);
    }
    //combo parse goes to pwOut (System.out)
    if (op.testOptions.verbose) {
        pwOut.println("ComboParser best");
        Tree ot = tree;
        if (ot != null && !op.tlpParams.treebankLanguagePack().isStartSymbol(ot.value())) {
            ot = ot.treeFactory().newTreeNode(op.tlpParams.treebankLanguagePack().startSymbol(), Collections.singletonList(ot));
        }
        treePrint.printTree(ot, pwOut);
    } else {
        treePrint.printTree(tree, pwOut);
    }
    // print various statistics
    if (tree != null) {
        if (op.testOptions.printAllBestParses) {
            List<ScoredObject<Tree>> parses = pq.getBestPCFGParses();
            int sz = parses.size();
            if (sz > 1) {
                pwOut.println("There were " + sz + " best PCFG parses with score " + parses.get(0).score() + '.');
                Tree transGoldTree = collinizer.transformTree(goldTree);
                int iii = 0;
                for (ScoredObject<Tree> sot : parses) {
                    iii++;
                    Tree tb = sot.object();
                    Tree tbd = debinarizer.transformTree(tb);
                    tbd = subcategoryStripper.transformTree(tbd);
                    pq.restoreOriginalWords(tbd);
                    pwOut.println("PCFG Parse #" + iii + " with score " + tbd.score());
                    tbd.pennPrint(pwOut);
                    Tree tbtr = collinizer.transformTree(tbd);
                    // pwOut.println("Tree size = " + tbtr.size() + "; depth = " + tbtr.depth());
                    kGoodLB.evaluate(tbtr, transGoldTree, pwErr);
                }
            }
        } else if (op.testOptions.printPCFGkBest > 0 && op.testOptions.outputkBestEquivocation == null) {
            // Huang and Chiang (2006) Algorithm 3 output from the PCFG parser
            List<ScoredObject<Tree>> trees = kbestPCFGTrees.subList(0, op.testOptions.printPCFGkBest);
            Tree transGoldTree = collinizer.transformTree(goldTree);
            int i = 0;
            for (ScoredObject<Tree> tp : trees) {
                i++;
                pwOut.println("PCFG Parse #" + i + " with score " + tp.score());
                Tree tbd = tp.object();
                tbd.pennPrint(pwOut);
                Tree tbtr = collinizer.transformTree(tbd);
                kGoodLB.evaluate(tbtr, transGoldTree, pwErr);
            }
        } else if (op.testOptions.printFactoredKGood > 0 && pq.hasFactoredParse()) {
            // Chart parser (factored) n-best list
            // DZ: debug n best trees
            List<ScoredObject<Tree>> trees = pq.getKGoodFactoredParses(op.testOptions.printFactoredKGood);
            Tree transGoldTree = collinizer.transformTree(goldTree);
            int ii = 0;
            for (ScoredObject<Tree> tp : trees) {
                ii++;
                pwOut.println("Factored Parse #" + ii + " with score " + tp.score());
                Tree tbd = tp.object();
                tbd.pennPrint(pwOut);
                Tree tbtr = collinizer.transformTree(tbd);
                kGoodLB.evaluate(tbtr, transGoldTree, pwOut);
            }
        } else if (pwFileOut != null) {
            // 1-best output
            pwFileOut.println(tree.toString());
        }
        //Print the derivational entropy
        if (op.testOptions.outputkBestEquivocation != null && op.testOptions.printPCFGkBest > 0) {
            List<ScoredObject<Tree>> trees = kbestPCFGTrees.subList(0, op.testOptions.printPCFGkBest);
            double[] logScores = new double[trees.size()];
            int treeId = 0;
            for (ScoredObject<Tree> kBestTree : trees) logScores[treeId++] = kBestTree.score();
            //Re-normalize
            double entropy = 0.0;
            double denom = ArrayMath.logSum(logScores);
            for (double logScore : logScores) {
                double logPr = logScore - denom;
                entropy += Math.exp(logPr) * (logPr / Math.log(2));
            }
            // Negate the accumulated sum (entropy = -sum p log p); the division by Math.log(2) above already converted to bits
            entropy *= -1;
            pwStats.printf("%f\t%d\t%d\n", entropy, trees.size(), sentence.size());
        }
    }
    // Perform various evaluations specified by the user
    if (tree != null) {
        //Strip subcategories and remove punctuation for evaluation
        tree = subcategoryStripper.transformTree(tree);
        Tree treeFact = collinizer.transformTree(tree);
        //Setup the gold tree
        if (op.testOptions.verbose) {
            pwOut.println("Correct parse");
            treePrint.printTree(goldTree, pwOut);
        }
        Tree transGoldTree = collinizer.transformTree(goldTree);
        if (transGoldTree != null)
            transGoldTree = subcategoryStripper.transformTree(transGoldTree);
        //Can't do evaluation in these two cases
        if (transGoldTree == null) {
            pwErr.println("Couldn't transform gold tree for evaluation, skipping eval. Gold tree was:");
            goldTree.pennPrint(pwErr);
            numSkippedEvals++;
            return;
        } else if (treeFact == null) {
            pwErr.println("Couldn't transform hypothesis tree for evaluation, skipping eval. Tree was:");
            tree.pennPrint(pwErr);
            numSkippedEvals++;
            return;
        } else if (treeFact.yield().size() != transGoldTree.yield().size()) {
            List<Label> fYield = treeFact.yield();
            List<Label> gYield = transGoldTree.yield();
            pwErr.println("WARNING: Evaluation could not be performed due to gold/parsed yield mismatch.");
            pwErr.printf("  sizes: gold: %d (transf) %d (orig); parsed: %d (transf) %d (orig).%n", gYield.size(), goldTree.yield().size(), fYield.size(), tree.yield().size());
            pwErr.println("  gold: " + SentenceUtils.listToString(gYield, true));
            pwErr.println("  pars: " + SentenceUtils.listToString(fYield, true));
            numSkippedEvals++;
            return;
        }
        if (topKEvals.size() > 0) {
            List<Tree> transGuesses = new ArrayList<>();
            int kbest = Math.min(op.testOptions.evalPCFGkBest, kbestPCFGTrees.size());
            for (ScoredObject<Tree> guess : kbestPCFGTrees.subList(0, kbest)) {
                transGuesses.add(collinizer.transformTree(guess.object()));
            }
            for (BestOfTopKEval eval : topKEvals) {
                eval.evaluate(transGuesses, transGoldTree, pwErr);
            }
        }
        //PCFG eval
        Tree treePCFG = pq.getBestPCFGParse();
        if (treePCFG != null) {
            Tree treePCFGeval = collinizer.transformTree(treePCFG);
            if (pcfgLB != null) {
                pcfgLB.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgChildSpecific != null) {
                pcfgChildSpecific.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgLA != null) {
                pcfgLA.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgCB != null) {
                pcfgCB.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgDA != null) {
                // Re-index the leaves after Collinization, stripping traces, etc.
                treePCFGeval.indexLeaves(true);
                transGoldTree.indexLeaves(true);
                pcfgDA.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgTA != null) {
                pcfgTA.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgLL != null && pq.getPCFGParser() != null) {
                pcfgLL.recordScore(pq.getPCFGParser(), pwErr);
            }
            if (pcfgRUO != null) {
                pcfgRUO.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgCUO != null) {
                pcfgCUO.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
            if (pcfgCatE != null) {
                pcfgCatE.evaluate(treePCFGeval, transGoldTree, pwErr);
            }
        }
        //Dependency eval
        // todo: is treeDep really useful here, or should we really use depDAEval tree (debinarized) throughout? We use it for parse, and it sure seems like we could use it for tag eval, but maybe not factDA?
        Tree treeDep = pq.getBestDependencyParse(false);
        if (treeDep != null) {
            Tree goldTreeB = binarizerOnly.transformTree(goldTree);
            Tree goldTreeEval = goldTree.deepCopy();
            goldTreeEval.indexLeaves(true);
            goldTreeEval.percolateHeads(op.langpack().headFinder());
            Tree depDAEval = pq.getBestDependencyParse(true);
            depDAEval.indexLeaves(true);
            depDAEval.percolateHeadIndices();
            if (depDA != null) {
                depDA.evaluate(depDAEval, goldTreeEval, pwErr);
            }
            if (depTA != null) {
                Tree undoneTree = debinarizer.transformTree(treeDep);
                undoneTree = subcategoryStripper.transformTree(undoneTree);
                pq.restoreOriginalWords(undoneTree);
                // pwErr.println("subcategoryStripped tree: " + undoneTree.toStructureDebugString());
                depTA.evaluate(undoneTree, goldTree, pwErr);
            }
            if (depLL != null && pq.getDependencyParser() != null) {
                depLL.recordScore(pq.getDependencyParser(), pwErr);
            }
            Tree factTreeB;
            if (pq.hasFactoredParse()) {
                factTreeB = pq.getBestFactoredParse();
            } else {
                factTreeB = treeDep;
            }
            if (factDA != null) {
                factDA.evaluate(factTreeB, goldTreeB, pwErr);
            }
        }
        //Factored parser (1best) eval
        if (factLB != null) {
            factLB.evaluate(treeFact, transGoldTree, pwErr);
        }
        if (factChildSpecific != null) {
            factChildSpecific.evaluate(treeFact, transGoldTree, pwErr);
        }
        if (factLA != null) {
            factLA.evaluate(treeFact, transGoldTree, pwErr);
        }
        if (factTA != null) {
            factTA.evaluate(tree, boundaryRemover.transformTree(goldTree), pwErr);
        }
        if (factLL != null && pq.getFactoredParser() != null) {
            factLL.recordScore(pq.getFactoredParser(), pwErr);
        }
        if (factCB != null) {
            factCB.evaluate(treeFact, transGoldTree, pwErr);
        }
        for (Eval eval : evals) {
            eval.evaluate(treeFact, transGoldTree, pwErr);
        }
        if (parserQueryEvals != null) {
            for (ParserQueryEval eval : parserQueryEvals) {
                eval.evaluate(pq, transGoldTree, pwErr);
            }
        }
        if (op.testOptions.evalb) {
            // empty out scores just in case
            nanScores(tree);
            EvalbFormatWriter.writeEVALBline(treeFact, transGoldTree);
        }
    }
    pwErr.println();
}
Also used : ArrayList(java.util.ArrayList) ParserQueryEval(edu.stanford.nlp.parser.metrics.ParserQueryEval) TreePrint(edu.stanford.nlp.trees.TreePrint) NoSuchParseException(edu.stanford.nlp.parser.common.NoSuchParseException) ScoredObject(edu.stanford.nlp.util.ScoredObject) Tree(edu.stanford.nlp.trees.Tree) LinkedList(java.util.LinkedList) List(java.util.List) LeafAncestorEval(edu.stanford.nlp.parser.metrics.LeafAncestorEval) AbstractEval(edu.stanford.nlp.parser.metrics.AbstractEval) TaggingEval(edu.stanford.nlp.parser.metrics.TaggingEval) TopMatchEval(edu.stanford.nlp.parser.metrics.TopMatchEval) FilteredEval(edu.stanford.nlp.parser.metrics.FilteredEval) Eval(edu.stanford.nlp.parser.metrics.Eval) UnlabeledAttachmentEval(edu.stanford.nlp.parser.metrics.UnlabeledAttachmentEval) BestOfTopKEval(edu.stanford.nlp.parser.metrics.BestOfTopKEval)
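
The k-best pattern in processResults is easy to exercise directly against a LexicalizedParser. Below is a minimal sketch, assuming the stock englishPCFG model is on the classpath; getKBestPCFGParses is the same ParserQuery call used above, and each ScoredObject pairs a Tree with its log score.

import java.util.List;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.ScoredObject;

public class KBestSketch {
    public static void main(String[] args) {
        // Assumed model path; any serialized PCFG model works the same way.
        LexicalizedParser parser = LexicalizedParser.loadModel("edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz");
        List<HasWord> sentence = Sentence.toWordList("The", "dog", "barked", ".");
        ParserQuery pq = parser.parserQuery();
        if (pq.parse(sentence)) {
            // Same call as in processResults: each entry is a tree plus its log score.
            for (ScoredObject<Tree> scored : pq.getKBestPCFGParses(5)) {
                System.out.println("score = " + scored.score());
                scored.object().pennPrint();
            }
        }
    }
}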

Example 12 with ScoredObject

use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

the class PerceptronModel method trainModel.

private void trainModel(String serializedPath, Tagger tagger, Random random, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, Treebank devTreebank, int nThreads, Set<String> allowedFeatures) {
    double bestScore = 0.0;
    int bestIteration = 0;
    PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
    if (op.trainOptions().averagedModels > 0) {
        bestModels = new PriorityQueue<>(op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
    }
    List<Integer> indices = Generics.newArrayList();
    for (int i = 0; i < binarizedTrees.size(); ++i) {
        indices.add(i);
    }
    Oracle oracle = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
        oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
    }
    List<Update> updates = Generics.newArrayList();
    MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper = null;
    if (nThreads != 1) {
        updates = Collections.synchronizedList(updates);
        wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new TrainTreeProcessor(binarizedTrees, transitionLists, updates, oracle));
    }
    IntCounter<String> featureFrequencies = null;
    if (op.trainOptions().featureFrequencyCutoff > 1) {
        featureFrequencies = new IntCounter<>();
    }
    for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration) {
        Timing trainingTimer = new Timing();
        int numCorrect = 0;
        int numWrong = 0;
        Collections.shuffle(indices, random);
        for (int start = 0; start < indices.size(); start += op.trainOptions.batchSize) {
            int end = Math.min(start + op.trainOptions.batchSize, indices.size());
            Triple<List<Update>, Integer, Integer> result = trainBatch(indices.subList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
            numCorrect += result.second;
            numWrong += result.third;
            for (Update update : result.first) {
                for (String feature : update.features) {
                    if (allowedFeatures != null && !allowedFeatures.contains(feature)) {
                        continue;
                    }
                    Weight weights = featureWeights.get(feature);
                    if (weights == null) {
                        weights = new Weight();
                        featureWeights.put(feature, weights);
                    }
                    weights.updateWeight(update.goldTransition, update.delta);
                    weights.updateWeight(update.predictedTransition, -update.delta);
                    if (featureFrequencies != null) {
                        featureFrequencies.incrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
                    }
                }
            }
            updates.clear();
        }
        trainingTimer.done("Iteration " + iteration);
        log.info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
        outputStats();
        double labelF1 = 0.0;
        if (devTreebank != null) {
            EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
            evaluator.testOnTreebank(devTreebank);
            labelF1 = evaluator.getLBScore();
            log.info("Label F1 after " + iteration + " iterations: " + labelF1);
            if (labelF1 > bestScore) {
                log.info("New best dev score (previous best " + bestScore + ")");
                bestScore = labelF1;
                bestIteration = iteration;
            } else {
                log.info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit)) {
                    log.info("Failed to improve for too long, stopping training");
                    break;
                }
            }
            log.info();
            if (bestModels != null) {
                bestModels.add(new ScoredObject<>(new PerceptronModel(this), labelF1));
                if (bestModels.size() > op.trainOptions().averagedModels) {
                    bestModels.poll();
                }
            }
        }
        if (op.trainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0) {
            String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + ".ser.gz";
            ShiftReduceParser temp = new ShiftReduceParser(op, this);
            temp.saveModel(tempName);
        // TODO: we could save a cutoff version of the model,
        // especially if we also get a dev set number for it, but that
        // might be overkill
        }
    }
    if (wrapper != null) {
        wrapper.join();
    }
    if (bestModels != null) {
        if (op.trainOptions().cvAveragedModels && devTreebank != null) {
            List<ScoredObject<PerceptronModel>> models = Generics.newArrayList();
            while (bestModels.size() > 0) {
                models.add(bestModels.poll());
            }
            Collections.reverse(models);
            double bestF1 = 0.0;
            int bestSize = 0;
            for (int i = 1; i <= models.size(); ++i) {
                log.info("Testing with " + i + " models averaged together");
                // TODO: this is kind of ugly, would prefer a separate object
                averageScoredModels(models.subList(0, i));
                ShiftReduceParser temp = new ShiftReduceParser(op, this);
                EvaluateTreebank evaluator = new EvaluateTreebank(temp.getOp(), null, temp, tagger);
                evaluator.testOnTreebank(devTreebank);
                double labelF1 = evaluator.getLBScore();
                log.info("Label F1 for " + i + " models: " + labelF1);
                if (labelF1 > bestF1) {
                    bestF1 = labelF1;
                    bestSize = i;
                }
            }
            averageScoredModels(models.subList(0, bestSize));
        } else {
            averageScoredModels(bestModels);
        }
    }
    // Filter out features that did not reach the frequency cutoff during training.
    if (featureFrequencies != null) {
        filterFeatures(featureFrequencies.keysAbove(op.trainOptions().featureFrequencyCutoff));
    }
    condenseFeatures();
}
Also used : EvaluateTreebank(edu.stanford.nlp.parser.lexparser.EvaluateTreebank) ScoredObject(edu.stanford.nlp.util.ScoredObject) List(java.util.List) Pair(edu.stanford.nlp.util.Pair) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint) Timing(edu.stanford.nlp.util.Timing)
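
The bestModels bookkeeping above is a bounded min-heap idiom: with ScoredComparator.ASCENDING_COMPARATOR the head of the queue is the worst-scoring entry, so poll() evicts the worst model once the queue grows past its cap. A self-contained sketch of the same idiom, with a String payload standing in for PerceptronModel:

import java.util.PriorityQueue;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;

public class TopKModelsSketch {
    public static void main(String[] args) {
        final int k = 3;
        // Ascending order puts the lowest score at the head of the queue.
        PriorityQueue<ScoredObject<String>> best = new PriorityQueue<>(k + 1, ScoredComparator.ASCENDING_COMPARATOR);
        double[] devScores = { 0.81, 0.86, 0.79, 0.88, 0.84 };
        for (int i = 0; i < devScores.length; i++) {
            best.add(new ScoredObject<>("model-" + i, devScores[i]));
            if (best.size() > k) {
                // Evict the current worst, exactly as trainModel does.
                best.poll();
            }
        }
        // Draining the heap yields the surviving models from worst to best.
        while (!best.isEmpty()) {
            ScoredObject<String> so = best.poll();
            System.out.println(so.object() + " " + so.score());
        }
    }
}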

Example 13 with ScoredObject

use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

the class PerceptronModel method trainModel.

private void trainModel(String serializedPath, Tagger tagger, Random random, List<TrainingExample> trainingData, Treebank devTreebank, int nThreads, Set<String> allowedFeatures) {
    double bestScore = 0.0;
    int bestIteration = 0;
    PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
    if (op.trainOptions().averagedModels > 0) {
        bestModels = new PriorityQueue<>(op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
    }
    MulticoreWrapper<TrainingExample, TrainingResult> wrapper = null;
    if (nThreads != 1) {
        wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new TrainTreeProcessor());
    }
    IntCounter<String> featureFrequencies = null;
    if (op.trainOptions().featureFrequencyCutoff > 1 && allowedFeatures == null) {
        // allowedFeatures != null means we already filtered rare
        // features once.  Sometimes the exact features found are
        // different depending on how the learning proceeds.  The second
        // time training, we only allow rare features to exist if they
        // met the threshold established the first time around
        featureFrequencies = new IntCounter<>();
    }
    for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration) {
        Timing trainingTimer = new Timing();
        List<TrainingResult> results = new ArrayList<>();
        List<TrainingExample> augmentedData = new ArrayList<TrainingExample>(trainingData);
        augmentSubsentences(augmentedData, trainingData, random, op.trainOptions().augmentSubsentences);
        Collections.shuffle(augmentedData, random);
        log.info("Original list " + trainingData.size() + "; augmented " + augmentedData.size());
        for (int start = 0; start < augmentedData.size(); start += op.trainOptions.batchSize) {
            int end = Math.min(start + op.trainOptions.batchSize, augmentedData.size());
            TrainingResult result = trainBatch(augmentedData.subList(start, end), wrapper);
            results.add(result);
            for (TrainingUpdate update : result.updates) {
                for (String feature : update.features) {
                    if (allowedFeatures != null && !allowedFeatures.contains(feature)) {
                        continue;
                    }
                    Weight weight = featureWeights.get(feature);
                    if (weight == null) {
                        weight = new Weight();
                        featureWeights.put(feature, weight);
                    }
                    weight.updateWeight(update.goldTransition, update.delta);
                    weight.updateWeight(update.predictedTransition, -update.delta);
                    if (featureFrequencies != null) {
                        featureFrequencies.incrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
                    }
                }
            }
        }
        float l2Reg = op.trainOptions().l2Reg;
        if (l2Reg > 0.0f) {
            for (Map.Entry<String, Weight> weight : featureWeights.entrySet()) {
                weight.getValue().l2Reg(l2Reg);
            }
        }
        float l1Reg = op.trainOptions().l1Reg;
        if (l1Reg > 0.0f) {
            for (Map.Entry<String, Weight> weight : featureWeights.entrySet()) {
                weight.getValue().l1Reg(l1Reg);
            }
        }
        trainingTimer.done("Iteration " + iteration);
        outputStats(new TrainingResult(results));
        double labelF1 = 0.0;
        if (devTreebank != null) {
            labelF1 = evaluate(tagger, devTreebank, "Label F1 for iteration " + iteration);
            if (labelF1 > bestScore) {
                log.info("New best dev score (previous best " + bestScore + ")");
                bestScore = labelF1;
                bestIteration = iteration;
            } else {
                log.info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit)) {
                    log.info("Failed to improve for too long, stopping training");
                    break;
                }
            }
            log.info("\n\n");
            if (bestModels != null) {
                PerceptronModel copy = new PerceptronModel(this);
                copy.condenseFeatures();
                bestModels.add(new ScoredObject<>(copy, labelF1));
                if (bestModels.size() > op.trainOptions().averagedModels) {
                    bestModels.poll();
                }
            }
        }
        if (op.trainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0) {
            String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + ".ser.gz";
            ShiftReduceParser temp = new ShiftReduceParser(op, this);
            temp.saveModel(tempName);
        // TODO: we could save a cutoff version of the model,
        // especially if we also get a dev set number for it, but that
        // might be overkill
        }
        if (iteration % 10 == 0 && op.trainOptions().decayLearningRate > 0.0) {
            learningRate *= op.trainOptions().decayLearningRate;
        }
    }
    if (wrapper != null) {
        wrapper.join();
    }
    if (bestModels != null) {
        if (op.trainOptions().cvAveragedModels && devTreebank != null) {
            List<ScoredObject<PerceptronModel>> models = Generics.newArrayList();
            while (bestModels.size() > 0) {
                models.add(bestModels.poll());
            }
            Collections.reverse(models);
            double bestF1 = 0.0;
            int bestSize = 0;
            for (int i = 1; i <= models.size(); ++i) {
                log.info("Testing with " + i + " models averaged together");
                // TODO: this is kind of ugly, would prefer a separate object
                averageScoredModels(models.subList(0, i));
                double labelF1 = evaluate(tagger, devTreebank, "Label F1 for " + i + " models");
                if (labelF1 > bestF1) {
                    bestF1 = labelF1;
                    bestSize = i;
                }
            }
            averageScoredModels(models.subList(0, bestSize));
            log.info("Label F1 for " + bestSize + " models: " + bestF1);
        } else {
            averageScoredModels(bestModels);
        }
    }
    // Filter out features that did not reach the frequency cutoff during training.
    if (featureFrequencies != null) {
        filterFeatures(featureFrequencies.keysAbove(op.trainOptions().featureFrequencyCutoff));
    }
    condenseFeatures();
}
Also used : ArrayList(java.util.ArrayList) ScoredObject(edu.stanford.nlp.util.ScoredObject) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint) Timing(edu.stanford.nlp.util.Timing) Map(java.util.Map)
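
The inner loop above is the standard structured-perceptron update: for every feature that fired in a mistaken decision, the gold transition's weight rises by delta and the predicted transition's weight falls by the same amount. A stripped-down sketch of that rule using a plain nested map; the Weight and TrainingUpdate classes here are simplified stand-ins, not the CoreNLP types:

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PerceptronUpdateSketch {
    // weights.get(feature).get(transition) -> score; a simplified stand-in for Weight
    static void update(Map<String, Map<Integer, Float>> weights, List<String> features,
                       int goldTransition, int predictedTransition, float delta) {
        for (String feature : features) {
            Map<Integer, Float> w = weights.computeIfAbsent(feature, f -> new HashMap<>());
            // Reward the gold transition, penalize the predicted one.
            w.merge(goldTransition, delta, Float::sum);
            w.merge(predictedTransition, -delta, Float::sum);
        }
    }

    public static void main(String[] args) {
        Map<String, Map<Integer, Float>> weights = new HashMap<>();
        // Hypothetical feature names and transition indices, for illustration only.
        update(weights, Arrays.asList("S0w-dog", "S1t-DT"), 2, 5, 1.0f);
        System.out.println(weights);
    }
}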

Example 14 with ScoredObject

use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

the class ShiftReduceParserQuery method parseInternal.

private boolean parseInternal() {
    final int maxBeamSize;
    if (parser.op.testOptions().beamSize == 0) {
        maxBeamSize = Math.max(parser.op.trainOptions().beamSize, 1);
    } else {
        maxBeamSize = parser.op.testOptions().beamSize;
    }
    success = true;
    unparsable = false;
    PriorityQueue<State> oldBeam = new PriorityQueue<>(maxBeamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
    PriorityQueue<State> beam = new PriorityQueue<>(maxBeamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
    // nextBeam will keep track of an unused PriorityQueue to cut down on the number of PriorityQueue objects created
    PriorityQueue<State> nextBeam = new PriorityQueue<>(maxBeamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
    beam.add(initialState);
    while (beam.size() > 0) {
        if (Thread.interrupted()) {
            // Allow interrupting the parser
            throw new RuntimeInterruptedException();
        }
        // log.info("================================================");
        // log.info("Current beam:");
        // log.info(beam);
        PriorityQueue<State> temp = oldBeam;
        oldBeam = beam;
        beam = nextBeam;
        beam.clear();
        nextBeam = temp;
        State bestState = null;
        for (State state : oldBeam) {
            if (Thread.interrupted()) {
                // Allow interrupting the parser
                throw new RuntimeInterruptedException();
            }
            Collection<ScoredObject<Integer>> predictedTransitions = parser.model.findHighestScoringTransitions(state, true, maxBeamSize, constraints);
            // log.info("Examining state: " + state);
            for (ScoredObject<Integer> predictedTransition : predictedTransitions) {
                Transition transition = parser.model.transitionIndex.get(predictedTransition.object());
                State newState = transition.apply(state, predictedTransition.score());
                // log.info("  Transition: " + transition + " (" + predictedTransition.score() + ")");
                if (bestState == null || bestState.score() < newState.score()) {
                    bestState = newState;
                }
                beam.add(newState);
                if (beam.size() > maxBeamSize) {
                    beam.poll();
                }
            }
        }
        if (beam.size() == 0) {
            // The beam emptied before reaching a finished state; fall back to
            // emergency transitions so the query still produces some sort of parse.
            for (State state : oldBeam) {
                Transition transition = parser.model.findEmergencyTransition(state, constraints);
                if (transition != null) {
                    State newState = transition.apply(state);
                    if (bestState == null || bestState.score() < newState.score()) {
                        bestState = newState;
                    }
                    beam.add(newState);
                }
            }
        }
        // If the bestState is finished, we are done
        if (bestState == null || bestState.isFinished()) {
            break;
        }
    }
    bestParses = beam.stream().filter((state) -> state.isFinished()).collect(Collectors.toList());
    if (bestParses.size() == 0) {
        success = false;
        unparsable = true;
        debinarized = null;
        finalState = null;
        bestParses = Collections.emptyList();
    } else {
        Collections.sort(bestParses, beam.comparator());
        Collections.reverse(bestParses);
        finalState = bestParses.get(0);
        debinarized = debinarizer.transformTree(finalState.stack.peek());
        debinarized = Tsurgeon.processPattern(rearrangeFinalPunctuationTregex, rearrangeFinalPunctuationTsurgeon, debinarized);
    }
    return success;
}
Also used : RuntimeInterruptedException(edu.stanford.nlp.util.RuntimeInterruptedException) ScoredObject(edu.stanford.nlp.util.ScoredObject) PriorityQueue(java.util.PriorityQueue) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint)
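
The beam here is the same bounded-heap trick used for bestModels during training: every expanded state goes into an ascending PriorityQueue, and the lowest-scoring state is polled off whenever the beam overflows, so only the maxBeamSize best hypotheses survive each step. A tiny sketch of one such expansion step, with ScoredObject<String> hypotheses standing in for State and a hypothetical expand function in place of findHighestScoringTransitions:

import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;

public class BeamStepSketch {
    // One beam-search step: expand every hypothesis, keep only the best maxBeamSize.
    static PriorityQueue<ScoredObject<String>> step(Iterable<ScoredObject<String>> oldBeam, int maxBeamSize) {
        PriorityQueue<ScoredObject<String>> beam = new PriorityQueue<>(maxBeamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
        for (ScoredObject<String> hyp : oldBeam) {
            for (ScoredObject<String> next : expand(hyp)) {
                beam.add(next);
                if (beam.size() > maxBeamSize) {
                    // Drop the worst hypothesis, as parseInternal does.
                    beam.poll();
                }
            }
        }
        return beam;
    }

    // Hypothetical successor function; the parser scores real Transitions here.
    static List<ScoredObject<String>> expand(ScoredObject<String> hyp) {
        return Arrays.asList(new ScoredObject<>(hyp.object() + "+a", hyp.score() + 0.5),
                             new ScoredObject<>(hyp.object() + "+b", hyp.score() - 0.1));
    }

    public static void main(String[] args) {
        PriorityQueue<ScoredObject<String>> beam = new PriorityQueue<>(4, ScoredComparator.ASCENDING_COMPARATOR);
        beam.add(new ScoredObject<>("start", 0.0));
        beam = step(beam, 3);
        while (!beam.isEmpty()) {
            ScoredObject<String> s = beam.poll();
            System.out.println(s.object() + " " + s.score());
        }
    }
}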

Example 15 with ScoredObject

use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

the class EvaluateExternalParser method convertDataset.

public List<Pair<ParserQuery, Tree>> convertDataset(List<Tree> goldTrees, List<List<Tree>> results) {
    List<Pair<ParserQuery, Tree>> dataset = new ArrayList<>();
    if (goldTrees.size() != results.size()) {
        throw new AssertionError("The lists should always be of the same length at this point");
    }
    for (int i = 0; i < goldTrees.size(); ++i) {
        Tree gold = goldTrees.get(i);
        List<CoreLabel> sentence = SentenceUtils.toCoreLabelList(gold.yieldWords());
        List<ScoredObject<Tree>> scoredResult = new ArrayList<>();
        for (Tree tree : results.get(i)) {
            double score = tree.score();
            scoredResult.add(new ScoredObject<>(tree, score));
        }
        ExternalParserQuery pq = new ExternalParserQuery(sentence, scoredResult);
        dataset.add(new Pair<>(pq, gold));
    }
    return dataset;
}
Also used : ArrayList(java.util.ArrayList) CoreLabel(edu.stanford.nlp.ling.CoreLabel) ScoredObject(edu.stanford.nlp.util.ScoredObject) Tree(edu.stanford.nlp.trees.Tree) Pair(edu.stanford.nlp.util.Pair)
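
Wrapping results in ScoredObject is all it takes to feed an external parser's output into CoreNLP's evaluation machinery. A minimal sketch of just the wrapping step; the bracketed tree and score are made-up values for illustration:

import java.util.ArrayList;
import java.util.List;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.ScoredObject;

public class WrapScoredTreesSketch {
    public static void main(String[] args) {
        Tree tree = Tree.valueOf("(ROOT (S (NP (DT The) (NN dog)) (VP (VBD barked))))");
        // An assumed score from the external model.
        tree.setScore(-12.5);
        List<ScoredObject<Tree>> scored = new ArrayList<>();
        // Same pairing as convertDataset: each tree together with its stored score.
        scored.add(new ScoredObject<>(tree, tree.score()));
        System.out.println(scored.get(0).score());
    }
}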

Aggregations

ScoredObject (edu.stanford.nlp.util.ScoredObject) 17
Tree (edu.stanford.nlp.trees.Tree) 11
ArrayList (java.util.ArrayList) 11
ParserConstraint (edu.stanford.nlp.parser.common.ParserConstraint) 7
List (java.util.List) 6
PriorityQueue (java.util.PriorityQueue) 5
Pair (edu.stanford.nlp.util.Pair) 4
NoSuchParseException (edu.stanford.nlp.parser.common.NoSuchParseException) 3
Word (edu.stanford.nlp.ling.Word) 2
ParserQuery (edu.stanford.nlp.parser.common.ParserQuery) 2
TreePrint (edu.stanford.nlp.trees.TreePrint) 2
Timing (edu.stanford.nlp.util.Timing) 2
IOException (java.io.IOException) 2
LinkedList (java.util.LinkedList) 2
Map (java.util.Map) 2
CoreLabel (edu.stanford.nlp.ling.CoreLabel) 1
EvaluateTreebank (edu.stanford.nlp.parser.lexparser.EvaluateTreebank) 1
LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser) 1
RerankingParserQuery (edu.stanford.nlp.parser.lexparser.RerankingParserQuery) 1
AbstractEval (edu.stanford.nlp.parser.metrics.AbstractEval) 1