Example 1 with EvaluateTreebank

Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.

Class StringBinaryRule, method testGrammarCompaction.

public void testGrammarCompaction() {
    // these are for testing against the third-order Markov baseline
    // use the parser constructor to extract the grammars from the treebank
    op = new Options();
    LexicalizedParser lp = LexicalizedParser.trainFromTreebank(path, new NumberRangeFileFilter(trainLow, trainHigh, true), op);
    // compact grammars
    if (compactor != null) {
        // extract a bunch of paths
        Timing.startTime();
        System.out.print("Extracting other paths...");
        allTrainPaths = extractPaths(path, trainLow, trainHigh, true);
        allTestPaths = extractPaths(path, testLow, testHigh, true);
        Timing.tick("done");
        // compact grammars
        Timing.startTime();
        System.out.print("Compacting grammars...");
        Pair<UnaryGrammar, BinaryGrammar> grammar = Generics.newPair(lp.ug, lp.bg);
        Triple<Index<String>, UnaryGrammar, BinaryGrammar> compactedGrammar = compactor.compactGrammar(grammar, allTrainPaths, allTestPaths, lp.stateIndex);
        lp.stateIndex = compactedGrammar.first();
        lp.ug = compactedGrammar.second();
        lp.bg = compactedGrammar.third();
        Timing.tick("done.");
    }
    if (asciiOutputPath != null) {
        lp.saveParserToTextFile(asciiOutputPath);
    }
    // test it
    Treebank testTreebank = op.tlpParams.testMemoryTreebank();
    testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
    System.out.println("Currently " + new Date());
    EvaluateTreebank evaluator = new EvaluateTreebank(lp);
    evaluator.testOnTreebank(testTreebank);
    System.out.println("Currently " + new Date());
}
Also used: EvaluateTreebank(edu.stanford.nlp.parser.metrics.EvaluateTreebank) NumberRangeFileFilter(edu.stanford.nlp.io.NumberRangeFileFilter)
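Example 1 selects its training and test files with `NumberRangeFileFilter(low, high, true)`. As a rough illustration of that kind of filter, here is a standalone sketch, not CoreNLP's implementation: it assumes the file number is the last run of digits in the file name (as in WSJ names such as `wsj_2205.mrg`), treats both bounds as inclusive, and ignores the directory-recursion flag. The class `RangeFilterSketch` is hypothetical.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical sketch of a number-range file-name filter (not CoreNLP's NumberRangeFileFilter). */
class RangeFilterSketch {
    // A run of digits; the last match in the name is taken as the file number.
    private static final Pattern DIGITS = Pattern.compile("\\d+");

    /** Returns true if the last number in fileName lies in [low, high] (inclusive, by assumption). */
    static boolean accepts(String fileName, int low, int high) {
        Matcher m = DIGITS.matcher(fileName);
        String last = null;
        while (m.find()) {
            last = m.group();
        }
        if (last == null) {
            return false; // no number in the name: reject the file
        }
        long n = Long.parseLong(last);
        return n >= low && n <= high;
    }

    public static void main(String[] args) {
        String[] names = {"wsj_0199.mrg", "wsj_0200.mrg", "wsj_2205.mrg", "readme.txt"};
        for (String name : names) {
            System.out.println(name + " in 200-2199? " + accepts(name, 200, 2199));
        }
    }
}
```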

Example 2 with EvaluateTreebank

Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.

Class DVParser, method train.

public void train(List<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses, Treebank testTreebank, String modelPath, String resultsRecordPath) throws IOException {
    // process:
    // we come up with a cost and a derivative for the model
    // we always use the gold tree as the example to train towards
    // every time through, we will look at the top N trees from
    // the LexicalizedParser and pick the best one according to
    // our model (at the start, this is essentially random)
    // we use QN to minimize the cost function for the model
    // to do this minimization, we turn all of the matrices in the
    // DVModel into one big Theta, which is the set of variables to
    // be optimized by the QN.
    Timing timing = new Timing();
    long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000;
    int batchCount = 0;
    int debugCycle = 0;
    double bestLabelF1 = 0.0;
    if (op.trainOptions.useContextWords) {
        for (Tree tree : sentences) {
            Trees.convertToCoreLabels(tree);
            tree.setSpans();
        }
    }
    // for AdaGrad
    double[] sumGradSquare = new double[dvModel.totalParamSize()];
    Arrays.fill(sumGradSquare, 1.0);
    int numBatches = sentences.size() / op.trainOptions.batchSize + 1;
    log.info("Training on " + sentences.size() + " trees in " + numBatches + " batches");
    log.info("Times through each training batch: " + op.trainOptions.trainingIterations);
    log.info("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);
    for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter) {
        List<Tree> shuffledSentences = new ArrayList<>(sentences);
        Collections.shuffle(shuffledSentences, dvModel.rand);
        for (int batch = 0; batch < numBatches; ++batch) {
            ++batchCount;
            // This did not help performance
            // log.info("Setting AdaGrad's sum of squares to 1...");
            // Arrays.fill(sumGradSquare, 1.0);
            log.info("======================================");
            log.info("Iteration " + iter + " batch " + batch);
            // Each batch will be of the specified batch size, except the
            // last batch will include any leftover trees at the end of
            // the list
            int startTree = batch * op.trainOptions.batchSize;
            int endTree = (batch + 1) * op.trainOptions.batchSize;
            if (endTree > shuffledSentences.size()) {
                endTree = shuffledSentences.size();
            }
            executeOneTrainingBatch(shuffledSentences.subList(startTree, endTree), compressedParses, sumGradSquare);
            long totalElapsed = timing.report();
            log.info("Finished iteration " + iter + " batch " + batch + "; total training time " + totalElapsed + " ms");
            if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
                // no need to debug output, we're done now
                break;
            }
            if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0) {
                log.info("Finished " + batchCount + " total batches, running evaluation cycle");
                // Time for debugging output!
                double tagF1 = 0.0;
                double labelF1 = 0.0;
                if (testTreebank != null) {
                    EvaluateTreebank evaluator = new EvaluateTreebank(attachModelToLexicalizedParser());
                    evaluator.testOnTreebank(testTreebank);
                    labelF1 = evaluator.getLBScore();
                    tagF1 = evaluator.getTagScore();
                    if (labelF1 > bestLabelF1) {
                        bestLabelF1 = labelF1;
                    }
                    log.info("Best label f1 on dev set so far: " + NF.format(bestLabelF1));
                }
                String tempName = null;
                if (modelPath != null) {
                    tempName = modelPath;
                    if (modelPath.endsWith(".ser.gz")) {
                        tempName = modelPath.substring(0, modelPath.length() - 7) + "-" + FILENAME.format(debugCycle) + "-" + NF.format(labelF1) + ".ser.gz";
                    }
                    saveModel(tempName);
                }
                String statusLine = ("CHECKPOINT:" + " iteration " + iter + " batch " + batch + " labelF1 " + NF.format(labelF1) + " tagF1 " + NF.format(tagF1) + " bestLabelF1 " + NF.format(bestLabelF1) + " model " + tempName + op.trainOptions + " word vectors: " + op.lexOptions.wordVectorFile + " numHid: " + op.lexOptions.numHid);
                log.info(statusLine);
                if (resultsRecordPath != null) {
                    // append; try-with-resources closes the writer even if a write fails
                    try (FileWriter fout = new FileWriter(resultsRecordPath, true)) {
                        fout.write(statusLine);
                        fout.write("\n");
                    }
                }
                ++debugCycle;
            }
        }
        long totalElapsed = timing.report();
        if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis) {
            // no need to debug output, we're done now
            log.info("Max training time exceeded, exiting");
            break;
        }
    }
}
Also used: EvaluateTreebank(edu.stanford.nlp.parser.metrics.EvaluateTreebank) FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) Tree(edu.stanford.nlp.trees.Tree) Timing(edu.stanford.nlp.util.Timing)
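`train` allocates a `sumGradSquare` array with one slot per model parameter, fills it with 1.0, and hands it to `executeOneTrainingBatch`, whose body is not part of this excerpt. The sketch below shows a textbook AdaGrad step, assuming the array is used in the usual way; it is not DVParser's code, and `AdaGradSketch.step` is hypothetical. Each parameter accumulates its squared gradients, and its step is the learning rate divided by the square root of that sum, so parameters with a large gradient history take smaller steps. Starting the sums at 1.0 rather than 0 keeps the first division finite.

```java
import java.util.Arrays;

/** Generic AdaGrad sketch; not DVParser's actual update code. */
class AdaGradSketch {
    /**
     * Applies one AdaGrad step in place.
     * theta: parameters, grad: gradient of the cost, sumGradSquare: running sum of squared gradients.
     */
    static void step(double[] theta, double[] grad, double[] sumGradSquare, double learningRate) {
        for (int i = 0; i < theta.length; i++) {
            sumGradSquare[i] += grad[i] * grad[i];                              // accumulate squared gradient
            theta[i] -= learningRate * grad[i] / Math.sqrt(sumGradSquare[i]);   // per-parameter scaled step
        }
    }

    public static void main(String[] args) {
        double[] theta = {0.0, 0.0};
        double[] sumGradSquare = new double[theta.length];
        Arrays.fill(sumGradSquare, 1.0); // same initialization as in train()
        step(theta, new double[] {3.0, 0.0}, sumGradSquare, 0.1);
        System.out.println(Arrays.toString(theta) + " " + Arrays.toString(sumGradSquare));
    }
}
```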

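The batch loop in `train` slices the shuffled list with `subList(startTree, endTree)` and clips `endTree` to the list size, so the last batch holds the leftover trees. Note that `numBatches = size / batchSize + 1` yields one extra, empty batch whenever the number of trees is an exact multiple of `batchSize`. The hypothetical helper below does the same slicing with ceiling division, one way to avoid that empty batch.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: split a list into consecutive batches of at most batchSize items. */
class BatchSketch {
    static <T> List<List<T>> batches(List<T> items, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        // Ceiling division: no empty trailing batch when items.size() % batchSize == 0.
        int numBatches = (items.size() + batchSize - 1) / batchSize;
        List<List<T>> result = new ArrayList<>(numBatches);
        for (int batch = 0; batch < numBatches; ++batch) {
            int start = batch * batchSize;
            int end = Math.min(start + batchSize, items.size()); // last batch takes the leftovers
            result.add(items.subList(start, end));
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> trees = List.of(1, 2, 3, 4, 5, 6, 7);
        System.out.println(batches(trees, 3)); // [[1, 2, 3], [4, 5, 6], [7]]
    }
}
```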
Example 3 with EvaluateTreebank

Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.

Class DVParser, method main.

/**
 * An example command line for training a new parser:
 * <br>
 *  nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank  /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories &gt; /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2&gt;&amp;1 &amp;
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
    if (args.length == 0) {
        help();
        System.exit(2);
    }
    log.info("Running DVParser with arguments:");
    for (String arg : args) {
        log.info("  " + arg);
    }
    log.info();
    String parserPath = null;
    String trainTreebankPath = null;
    FileFilter trainTreebankFilter = null;
    String cachedTrainTreesPath = null;
    boolean runGradientCheck = false;
    boolean runTraining = false;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    String initialModelPath = null;
    String modelPath = null;
    boolean filter = true;
    String resultsRecordPath = null;
    List<String> unusedArgs = new ArrayList<>();
    // These parameters can be null or 0 if the model was not
    // serialized with the new parameters.  Setting the options at the
    // command line will override these defaults.
    // TODO: if/when we integrate back into the main branch and
    // rebuild models, we can get rid of this
    List<String> argsWithDefaults = new ArrayList<>(Arrays.asList(new String[] {
        "-wordVectorFile", Options.LexOptions.DEFAULT_WORD_VECTOR_FILE,
        "-dvKBest", Integer.toString(TrainOptions.DEFAULT_K_BEST),
        "-batchSize", Integer.toString(TrainOptions.DEFAULT_BATCH_SIZE),
        "-trainingIterations", Integer.toString(TrainOptions.DEFAULT_TRAINING_ITERATIONS),
        "-qnIterationsPerBatch", Integer.toString(TrainOptions.DEFAULT_QN_ITERATIONS_PER_BATCH),
        "-regCost", Double.toString(TrainOptions.DEFAULT_REGCOST),
        "-learningRate", Double.toString(TrainOptions.DEFAULT_LEARNING_RATE),
        "-deltaMargin", Double.toString(TrainOptions.DEFAULT_DELTA_MARGIN),
        "-unknownNumberVector",
        "-unknownDashedWordVectors",
        "-unknownCapsVector",
        "-unknownchinesepercentvector",
        "-unknownchinesenumbervector",
        "-unknownchineseyearvector",
        "-unkWord", "*UNK*",
        "-transformMatrixType", "DIAGONAL",
        "-scalingForInit", Double.toString(TrainOptions.DEFAULT_SCALING_FOR_INIT),
        "-trainWordVectors" }));
    argsWithDefaults.addAll(Arrays.asList(args));
    args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]);
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-parser")) {
            parserPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-treebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            trainTreebankPath = treebankDescription.first();
            trainTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-cachedTrees")) {
            cachedTrainTreesPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-runGradientCheck")) {
            runGradientCheck = true;
            argIndex++;
        } else if (args[argIndex].equalsIgnoreCase("-train")) {
            runTraining = true;
            argIndex++;
        } else if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-nofilter")) {
            filter = false;
            argIndex++;
        } else if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
            runTraining = true;
            filter = false;
            initialModelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-resultsRecord")) {
            resultsRecordPath = args[argIndex + 1];
            argIndex += 2;
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    if (parserPath == null && modelPath == null) {
        throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
    }
    if (!runTraining && modelPath == null && !runGradientCheck) {
        throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    DVParser dvparser = null;
    LexicalizedParser lexparser = null;
    if (initialModelPath != null) {
        lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs);
        DVModel model = getModelFromLexicalizedParser(lexparser);
        dvparser = new DVParser(model, lexparser);
    } else if (runTraining || runGradientCheck) {
        lexparser = LexicalizedParser.loadModel(parserPath, newArgs);
        dvparser = new DVParser(lexparser);
    } else if (modelPath != null) {
        lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
        DVModel model = getModelFromLexicalizedParser(lexparser);
        dvparser = new DVParser(model, lexparser);
    }
    List<Tree> trainSentences = new ArrayList<>();
    IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.newIdentityHashMap();
    if (cachedTrainTreesPath != null) {
        for (String path : cachedTrainTreesPath.split(",")) {
            List<Pair<Tree, byte[]>> cache = IOUtils.readObjectFromFile(path);
            for (Pair<Tree, byte[]> pair : cache) {
                trainSentences.add(pair.first());
                trainCompressedParses.put(pair.first(), pair.second());
            }
            log.info("Read in " + cache.size() + " trees from " + path);
        }
    }
    if (trainTreebankPath != null) {
        // TODO: make the transformer a member of the model?
        TreeTransformer transformer = buildTrainTransformer(dvparser.getOp());
        Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();
        treebank.loadPath(trainTreebankPath, trainTreebankFilter);
        treebank = treebank.transform(transformer);
        log.info("Read in " + treebank.size() + " trees from " + trainTreebankPath);
        CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
        CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
        for (Tree tree : treebank) {
            trainSentences.add(tree);
            trainCompressedParses.put(tree, processor.process(tree).second);
            // System.out.println(tree);
        }
        log.info("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
    }
    if ((runTraining || runGradientCheck) && filter) {
        log.info("Filtering rules for the given training set");
        dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses);
        log.info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors");
    }
    // dvparser.dvModel.printAllMatrices();
    Treebank testTreebank = null;
    if (testTreebankPath != null) {
        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        testTreebank = dvparser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");
    }
    // runGradientCheck= true;
    if (runGradientCheck) {
        log.info("Running gradient check on " + trainSentences.size() + " trees");
        dvparser.runGradientCheck(trainSentences, trainCompressedParses);
    }
    if (runTraining) {
        log.info("Training the RNN parser");
        log.info("Current train options: " + dvparser.getOp().trainOptions);
        dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
        if (modelPath != null) {
            dvparser.saveModel(modelPath);
        }
    }
    if (testTreebankPath != null) {
        EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser());
        evaluator.testOnTreebank(testTreebank);
    }
    log.info("Successfully ran DVParser");
}
Also used: Treebank(edu.stanford.nlp.trees.Treebank) EvaluateTreebank(edu.stanford.nlp.parser.metrics.EvaluateTreebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) ArrayList(java.util.ArrayList) Tree(edu.stanford.nlp.trees.Tree) FileFilter(java.io.FileFilter) TreeTransformer(edu.stanford.nlp.trees.TreeTransformer) CompositeTreeTransformer(edu.stanford.nlp.trees.CompositeTreeTransformer) Pair(edu.stanford.nlp.util.Pair)
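In `main`, the defaults in `argsWithDefaults` are placed in front of the user's `args`. Assuming options are consumed left to right (as the parsing loop in `main` does) and a later value for a flag replaces an earlier one, anything given on the command line overrides the corresponding default, which is what the comment above that list promises. The sketch below isolates the pattern. It is not CoreNLP's option parser: `DefaultsSketch` is hypothetical, the flag values are illustrative, and it assumes a token is a flag's value unless it starts with `-`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of "defaults first, user arguments last" option handling. */
class DefaultsSketch {
    /**
     * Parses -flag [value] tokens left to right. Assumes a token is a value
     * unless it starts with '-', so negative numbers are not supported here.
     */
    static Map<String, String> parse(String[] defaults, String[] userArgs) {
        List<String> all = new ArrayList<>(Arrays.asList(defaults));
        all.addAll(Arrays.asList(userArgs)); // user arguments come after the defaults
        Map<String, String> options = new LinkedHashMap<>();
        int i = 0;
        while (i < all.size()) {
            String flag = all.get(i);
            boolean hasValue = i + 1 < all.size() && !all.get(i + 1).startsWith("-");
            // A later put() for the same flag overwrites the earlier (default) value.
            options.put(flag, hasValue ? all.get(i + 1) : "true");
            i += hasValue ? 2 : 1;
        }
        return options;
    }

    public static void main(String[] args) {
        String[] defaults = {"-batchSize", "25", "-trainWordVectors"};
        String[] user = {"-batchSize", "50"};
        System.out.println(parse(defaults, user)); // {-batchSize=50, -trainWordVectors=true}
    }
}
```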

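The training parses are kept in an `IdentityHashMap<Tree, byte[]>`, keyed by the tree object itself. A plausible reason (an assumption; the source does not say) is that two distinct training trees can be structurally equal, for example when the same sentence appears twice, and a `HashMap`, which relies on `equals`/`hashCode`, would then merge their entries if `Tree` compares by structure. In the sketch below a hypothetical `Node` class with value-based equality stands in for `Tree`.

```java
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;

class IdentityMapSketch {
    /** Hypothetical stand-in for a tree whose equals/hashCode compare structure, not identity. */
    static final class Node {
        final String bracketing;
        Node(String bracketing) { this.bracketing = bracketing; }
        @Override public boolean equals(Object o) {
            return o instanceof Node && ((Node) o).bracketing.equals(bracketing);
        }
        @Override public int hashCode() { return bracketing.hashCode(); }
    }

    /** Inserts two equal but distinct keys and returns the resulting map size. */
    static int sizeAfterTwoEqualKeys(Map<Node, byte[]> map) {
        Node first = new Node("(S (NP (PRP I)) (VP (VBD ran)))");
        Node second = new Node("(S (NP (PRP I)) (VP (VBD ran)))"); // equal to first, different object
        map.put(first, new byte[] {1});
        map.put(second, new byte[] {2});
        return map.size();
    }

    public static void main(String[] args) {
        System.out.println("HashMap entries: " + sizeAfterTwoEqualKeys(new HashMap<>()));
        System.out.println("IdentityHashMap entries: " + sizeAfterTwoEqualKeys(new IdentityHashMap<>()));
    }
}
```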
Example 4 with EvaluateTreebank

Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.

Class CrossValidateTestOptions, method main.

public static void main(String[] args) throws IOException, ClassNotFoundException {
    String dvmodelFile = null;
    String lexparserFile = null;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = new ArrayList<>();
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-lexparser")) {
            lexparserFile = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
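    // Both inputs are required: loadModel needs a path, and testOnTreebank(null) would fail.
    if (lexparserFile == null) {
        throw new IllegalArgumentException("Must supply a parser model with -lexparser");
    }
    if (testTreebankPath == null) {
        throw new IllegalArgumentException("Must supply a test treebank with -testTreebank");
    }
    log.info("Loading lexparser from: " + lexparserFile);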
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser lexparser = LexicalizedParser.loadModel(lexparserFile, newArgs);
    log.info("... done");
    Treebank testTreebank = null;
    if (testTreebankPath != null) {
        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");
    }
    double[] labelResults = new double[weights.length];
    double[] tagResults = new double[weights.length];
    for (int i = 0; i < weights.length; ++i) {
        lexparser.getOp().baseParserWeight = weights[i];
        EvaluateTreebank evaluator = new EvaluateTreebank(lexparser);
        evaluator.testOnTreebank(testTreebank);
        labelResults[i] = evaluator.getLBScore();
        tagResults[i] = evaluator.getTagScore();
    }
    for (int i = 0; i < weights.length; ++i) {
        log.info("LexicalizedParser weight " + weights[i] + ": labeled " + labelResults[i] + " tag " + tagResults[i]);
    }
}
Also used: EvaluateTreebank(edu.stanford.nlp.parser.metrics.EvaluateTreebank) Treebank(edu.stanford.nlp.trees.Treebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) ArrayList(java.util.ArrayList) FileFilter(java.io.FileFilter) Pair(edu.stanford.nlp.util.Pair)
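`CrossValidateTestOptions` sweeps `baseParserWeight` over a `weights` array (a field not shown in this excerpt) and records the labeled and tag score for each setting, but only logs them. Choosing the best setting is then an argmax over the sweep. The sketch below does that generically; `WeightSweepSketch` is hypothetical, and the scoring function stands in for running `EvaluateTreebank` and reading `getLBScore()`. Ties keep the earliest weight.

```java
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch: evaluate each candidate weight and return the best-scoring one. */
class WeightSweepSketch {
    static double bestWeight(double[] weights, DoubleUnaryOperator score) {
        if (weights.length == 0) {
            throw new IllegalArgumentException("no weights to try");
        }
        double best = weights[0];
        double bestScore = score.applyAsDouble(best);
        for (int i = 1; i < weights.length; ++i) {
            double s = score.applyAsDouble(weights[i]);
            if (s > bestScore) { // strict '>' keeps the earliest weight on ties
                bestScore = s;
                best = weights[i];
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] weights = {0.0, 0.2, 0.4, 0.6};
        // Toy score peaking at 0.4; a real sweep would evaluate the parser on a dev treebank.
        System.out.println(bestWeight(weights, w -> -(w - 0.4) * (w - 0.4)));
    }
}
```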

Example 5 with EvaluateTreebank

Use of edu.stanford.nlp.parser.metrics.EvaluateTreebank in project CoreNLP by stanfordnlp.

Class PerceptronModel, method evaluate.

private double evaluate(Tagger tagger, Treebank devTreebank, String message) {
    ShiftReduceParser temp = new ShiftReduceParser(op, this);
    EvaluateTreebank evaluator = new EvaluateTreebank(temp.getOp(), null, temp, tagger, temp.getExtraEvals(), temp.getParserQueryEvals());
    evaluator.testOnTreebank(devTreebank);
    double labelF1 = evaluator.getLBScore();
    log.info(message + ": " + labelF1);
    return labelF1;
}
Also used: EvaluateTreebank(edu.stanford.nlp.parser.metrics.EvaluateTreebank)
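Examples 2, 4, and 5 all read `getLBScore()` as the "label F1" used for model selection. Assuming this is the usual PARSEVAL-style labeled bracketing F1 (an assumption; the excerpt does not show how `EvaluateTreebank` computes it), a single-sentence version can be sketched as: collect the (label, start, end) constituents of the guessed and gold trees, count the matches, and take the harmonic mean of precision and recall. The `Bracket` type is hypothetical, and real evaluators also deal with details such as punctuation and excluded labels that this sketch ignores.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

class BracketF1Sketch {
    /** Hypothetical labeled span: constituent label plus [start, end) token offsets. */
    static final class Bracket {
        final String label;
        final int start;
        final int end;
        Bracket(String label, int start, int end) { this.label = label; this.start = start; this.end = end; }
        @Override public boolean equals(Object o) {
            if (!(o instanceof Bracket)) return false;
            Bracket b = (Bracket) o;
            return label.equals(b.label) && start == b.start && end == b.end;
        }
        @Override public int hashCode() { return Objects.hash(label, start, end); }
    }

    /** Labeled bracketing F1 for one sentence; duplicate brackets are collapsed by the sets. */
    static double f1(List<Bracket> guess, List<Bracket> gold) {
        Set<Bracket> guessSet = new HashSet<>(guess);
        Set<Bracket> goldSet = new HashSet<>(gold);
        if (guessSet.isEmpty() || goldSet.isEmpty()) {
            return 0.0; // nothing to compare
        }
        Set<Bracket> matched = new HashSet<>(guessSet);
        matched.retainAll(goldSet); // brackets present in both trees
        double precision = (double) matched.size() / guessSet.size();
        double recall = (double) matched.size() / goldSet.size();
        return matched.isEmpty() ? 0.0 : 2 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
        List<Bracket> gold = List.of(new Bracket("S", 0, 3), new Bracket("NP", 0, 1), new Bracket("VP", 1, 3));
        List<Bracket> guess = List.of(new Bracket("S", 0, 3), new Bracket("NP", 0, 2), new Bracket("VP", 1, 3));
        System.out.println(f1(guess, gold)); // precision = recall = 2/3, so F1 = 2/3
    }
}
```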

Aggregations

EvaluateTreebank (edu.stanford.nlp.parser.metrics.EvaluateTreebank): 8
Pair (edu.stanford.nlp.util.Pair): 5
Treebank (edu.stanford.nlp.trees.Treebank): 4
FileFilter (java.io.FileFilter): 4
ArrayList (java.util.ArrayList): 4
LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser): 3
Tree (edu.stanford.nlp.trees.Tree): 2
Timing (edu.stanford.nlp.util.Timing): 2
NumberRangeFileFilter (edu.stanford.nlp.io.NumberRangeFileFilter): 1
RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException): 1
HasWord (edu.stanford.nlp.ling.HasWord): 1
TaggedWord (edu.stanford.nlp.ling.TaggedWord): 1
ParserQuery (edu.stanford.nlp.parser.common.ParserQuery): 1
Options (edu.stanford.nlp.parser.lexparser.Options): 1
Reranker (edu.stanford.nlp.parser.lexparser.Reranker): 1
TokenizerFactory (edu.stanford.nlp.process.TokenizerFactory): 1
TaggedFileRecord (edu.stanford.nlp.tagger.io.TaggedFileRecord): 1
CompositeTreeTransformer (edu.stanford.nlp.trees.CompositeTreeTransformer): 1
TreeTransformer (edu.stanford.nlp.trees.TreeTransformer): 1
Triple (edu.stanford.nlp.util.Triple): 1