Search in sources :

Example 1 with AbstractEval

use of edu.stanford.nlp.parser.metrics.AbstractEval in project CoreNLP by stanfordnlp.

The example below is the main method of the class FactoredParser.

/**
 * Trains PCFG, lexicon and dependency models from a treebank, optionally serializes the
 * resulting parser, then parses the test section and reports evaluation figures.
 *
 * <p>Leading command-line flags handled directly by this method (parsing stops at the first
 * argument that does not start with {@code -}):
 * <ul>
 *   <li>{@code -path dir} &ndash; treebank directory (a Stanford-internal default is used otherwise)</li>
 *   <li>{@code -train low high} &ndash; file-number range of the training section</li>
 *   <li>{@code -test low high} &ndash; file-number range of the test section</li>
 *   <li>{@code -serialize file} &ndash; save the trained parser to {@code file}</li>
 *   <li>{@code -tLPP class} &ndash; name of the {@code TreebankLangParserParams} class to instantiate</li>
 *   <li>{@code -encoding enc} &ndash; input and output encoding set on the language params</li>
 * </ul>
 * Every other flag is handed to {@code Options.setOptionOrWarn}.
 *
 * <p>Some documentation for Roger's convenience. Naming used below:
 * <ul>
 *   <li>pcfg/dep/combo + PE/DE/TE are the precision/dependency/tagging evals for the models</li>
 *   <li>{@code parser} is the PCFG parser</li>
 *   <li>{@code dparser} is the dependency parser</li>
 *   <li>{@code bparser} is the combining parser</li>
 * </ul>
 * During testing:
 * <ul>
 *   <li>{@code tree} is the test tree (gold tree)</li>
 *   <li>{@code binaryTree} is the gold tree binarized</li>
 *   <li>{@code tree2b} is the best PCFG parse, binarized</li>
 *   <li>{@code tree2} is the best PCFG parse (debinarized)</li>
 *   <li>{@code tree3} is the dependency parse, binarized</li>
 *   <li>{@code tree3db} is the dependency parse, debinarized</li>
 *   <li>{@code tree4} is the best combo parse, binarized and then debinarized</li>
 *   <li>{@code tree4b} is the best combo parse, binarized</li>
 * </ul>
 *
 * @param args command-line arguments as described above
 * @throws IllegalArgumentException if {@code -encoding} is the last argument (no value follows)
 * @throws RuntimeException if the class named by {@code -tLPP} cannot be loaded or instantiated
 */
public static void main(String[] args) {
    Options op = new Options(new EnglishTreebankParserParams());
    // op.tlpParams may be changed to something else later, so don't use it till
    // after options are parsed.
    StringUtils.logInvocationString(log, args);
    String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj";
    int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219;
    String serializeFile = null;

    // ---- command-line parsing ----
    int i = 0;
    while (i < args.length && args[i].startsWith("-")) {
        if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) {
            path = args[i + 1];
            i += 2;
        } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) {
            trainLow = Integer.parseInt(args[i + 1]);
            trainHigh = Integer.parseInt(args[i + 2]);
            i += 3;
        } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) {
            testLow = Integer.parseInt(args[i + 1]);
            testHigh = Integer.parseInt(args[i + 2]);
            i += 3;
        } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) {
            serializeFile = args[i + 1];
            i += 2;
        } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) {
            try {
                op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance();
            } catch (ClassNotFoundException e) {
                log.info("Class not found: " + args[i + 1]);
                throw new RuntimeException(e);
            } catch (InstantiationException e) {
                log.info("Couldn't instantiate: " + args[i + 1] + ": " + e.toString());
                throw new RuntimeException(e);
            } catch (IllegalAccessException e) {
                log.info("illegal access" + e);
                throw new RuntimeException(e);
            }
            i += 2;
        } else if (args[i].equals("-encoding")) {
            // Fail with a clear message instead of an ArrayIndexOutOfBoundsException
            // when the flag is given without a value.
            if (i + 1 >= args.length) {
                throw new IllegalArgumentException("-encoding requires an argument");
            }
            // sets encoding for TreebankLangParserParams
            op.tlpParams.setInputEncoding(args[i + 1]);
            op.tlpParams.setOutputEncoding(args[i + 1]);
            i += 2;
        } else {
            i = op.setOptionOrWarn(args, i);
        }
    }

    TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack();
    op.trainOptions.sisterSplitters = Generics.newHashSet(Arrays.asList(op.tlpParams.sisterSplitters()));
    PrintWriter pw = op.tlpParams.pw();
    op.testOptions.display();
    op.trainOptions.display();
    op.display();
    op.tlpParams.display();

    // ---- load treebanks ----
    Treebank trainTreebank = op.tlpParams.memoryTreebank();
    MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank();
    Timing.startTime();
    log.info("Reading trees...");
    testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
    if (op.testOptions.increasingLength) {
        // Sort before binarizing so that testTreebank and binaryTestTrees stay index-aligned.
        Collections.sort(testTreebank, new TreeLengthComparator());
    }
    trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true));
    Timing.tick("done.");

    // ---- set up tree transforms and binarize ----
    log.info("Binarizing trees...");
    TreeAnnotatorAndBinarizer binarizer;
    if (!op.trainOptions.leftToRight) {
        binarizer = new TreeAnnotatorAndBinarizer(op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op);
    } else {
        binarizer = new TreeAnnotatorAndBinarizer(op.tlpParams.headFinder(), new LeftHeadFinder(), op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op);
    }
    // Only non-null when op.trainOptions.collinsPunc is set; every use below is guarded by that flag.
    CollinsPuncTransformer collinsPuncTransformer = null;
    if (op.trainOptions.collinsPunc) {
        collinsPuncTransformer = new CollinsPuncTransformer(tlp);
    }
    TreeTransformer debinarizer = new Debinarizer(op.forceCNF);
    List<Tree> binaryTrainTrees = new ArrayList<>();
    if (op.trainOptions.selectiveSplit) {
        op.trainOptions.splitters = ParentAnnotationStats.getSplitCategories(trainTreebank, op.trainOptions.tagSelectiveSplit, 0, op.trainOptions.selectiveSplitCutOff, op.trainOptions.tagSelectiveSplitCutOff, op.tlpParams.treebankLanguagePack());
        if (op.trainOptions.deleteSplitters != null) {
            List<String> deleted = new ArrayList<>();
            for (String del : op.trainOptions.deleteSplitters) {
                String baseDel = tlp.basicCategory(del);
                // When the requested splitter is already a basic category, remove every
                // splitter sharing that basic category; otherwise remove exact matches only.
                boolean checkBasic = del.equals(baseDel);
                for (Iterator<String> it = op.trainOptions.splitters.iterator(); it.hasNext(); ) {
                    String elem = it.next();
                    String baseElem = tlp.basicCategory(elem);
                    boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del);
                    if (delStr) {
                        it.remove();
                        deleted.add(elem);
                    }
                }
            }
            log.info("Removed from vertical splitters: " + deleted);
        }
    }
    if (op.trainOptions.selectivePostSplit) {
        TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams, op);
        Treebank annotatedTB = trainTreebank.transform(myTransformer);
        op.trainOptions.postSplitters = ParentAnnotationStats.getSplitCategories(annotatedTB, true, 0, op.trainOptions.selectivePostSplitCutOff, op.trainOptions.tagSelectivePostSplitCutOff, op.tlpParams.treebankLanguagePack());
    }
    if (op.trainOptions.hSelSplit) {
        // First pass with selective splitting switched off; the transformed trees are
        // discarded (presumably the pass only lets the binarizer collect statistics -- confirm).
        binarizer.setDoSelectiveSplit(false);
        for (Tree tree : trainTreebank) {
            if (op.trainOptions.collinsPunc) {
                tree = collinsPuncTransformer.transformTree(tree);
            }
            binarizer.transformTree(tree);
        }
        binarizer.setDoSelectiveSplit(true);
    }
    for (Tree tree : trainTreebank) {
        if (op.trainOptions.collinsPunc) {
            tree = collinsPuncTransformer.transformTree(tree);
        }
        tree = binarizer.transformTree(tree);
        binaryTrainTrees.add(tree);
    }
    if (op.testOptions.verbose) {
        binarizer.dumpStats();
    }
    List<Tree> binaryTestTrees = new ArrayList<>();
    for (Tree tree : testTreebank) {
        if (op.trainOptions.collinsPunc) {
            tree = collinsPuncTransformer.transformTree(tree);
        }
        tree = binarizer.transformTree(tree);
        binaryTestTrees.add(tree);
    }
    // binarization
    Timing.tick("done.");

    // ---- extract grammars ----
    BinaryGrammar bg = null;
    UnaryGrammar ug = null;
    DependencyGrammar dg = null;
    Index<String> stateIndex = new HashIndex<>();
    Extractor<Pair<UnaryGrammar, BinaryGrammar>> bgExtractor = new BinaryGrammarExtractor(op, stateIndex);
    if (op.doPCFG) {
        log.info("Extracting PCFG...");
        Pair<UnaryGrammar, BinaryGrammar> bgug;
        if (op.trainOptions.cheatPCFG) {
            // "Cheating" mode: the grammar is also extracted from the test trees.
            List<Tree> allTrees = new ArrayList<>(binaryTrainTrees);
            allTrees.addAll(binaryTestTrees);
            bgug = bgExtractor.extract(allTrees);
        } else {
            bgug = bgExtractor.extract(binaryTrainTrees);
        }
        bg = bgug.second;
        bg.splitRules();
        ug = bgug.first;
        ug.purgeRules();
        Timing.tick("done.");
    }
    log.info("Extracting Lexicon...");
    Index<String> wordIndex = new HashIndex<>();
    Index<String> tagIndex = new HashIndex<>();
    Lexicon lex = op.tlpParams.lex(op, wordIndex, tagIndex);
    lex.initializeTraining(binaryTrainTrees.size());
    lex.train(binaryTrainTrees);
    lex.finishTraining();
    Timing.tick("done.");
    if (op.doDep) {
        log.info("Extracting Dependencies...");
        Extractor<DependencyGrammar> dgExtractor = new MLEDependencyGrammarExtractor(op, wordIndex, tagIndex);
        // uses information whether the words are known or not, discards unknown words
        dg = dgExtractor.extract(binaryTrainTrees);
        // Release the training trees only AFTER extraction: clearing the list first (as was
        // done previously) made the extractor see an empty collection. The list is not read
        // again below.
        binaryTrainTrees.clear();
        Timing.tick("done.");
        System.out.print("Tuning Dependency Model...");
        dg.tune(binaryTestTrees);
        Timing.tick("done.");
    }
    BinaryGrammar boundBG = bg;
    UnaryGrammar boundUG = ug;
    GrammarProjection gp = new NullGrammarProjection(bg, ug);

    // ---- serialization ----
    if (serializeFile != null) {
        log.info("Serializing parser...");
        LexicalizedParser parser = new LexicalizedParser(lex, bg, ug, dg, stateIndex, wordIndex, tagIndex, op);
        parser.saveParserToSerialized(serializeFile);
        Timing.tick("done.");
    }

    // ---- test: build the parsers ----
    ExhaustivePCFGParser parser = null;
    if (op.doPCFG) {
        parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op, stateIndex, wordIndex, tagIndex);
    }
    // NOTE(review): dparser is null when op.testOptions.useFastFactored is set, yet it is
    // dereferenced below whenever op.doDep is true -- that combination looks unsupported here.
    ExhaustiveDependencyParser dparser = ((op.doDep && !op.testOptions.useFastFactored) ? new ExhaustiveDependencyParser(dg, lex, op, wordIndex, tagIndex) : null);
    Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp, op), dparser) : null);
    BiLexPCFGParser bparser = null;
    if (op.doPCFG && op.doDep) {
        bparser = (op.testOptions.useN5) ? new BiLexPCFGParser.N5BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex) : new BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex);
    }

    // ---- evaluation objects ----
    Evalb pcfgPE = new Evalb("pcfg  PE", true);
    Evalb comboPE = new Evalb("combo PE", true);
    AbstractEval pcfgCB = new Evalb.CBEval("pcfg  CB", true);
    AbstractEval pcfgTE = new TaggingEval("pcfg  TE");
    AbstractEval comboTE = new TaggingEval("combo TE");
    AbstractEval pcfgTEnoPunct = new TaggingEval("pcfg nopunct TE");
    AbstractEval comboTEnoPunct = new TaggingEval("combo nopunct TE");
    AbstractEval depTE = new TaggingEval("depnd TE");
    AbstractEval depDE = new UnlabeledAttachmentEval("depnd DE", true, null, tlp.punctuationWordRejectFilter());
    AbstractEval comboDE = new UnlabeledAttachmentEval("combo DE", true, null, tlp.punctuationWordRejectFilter());
    if (op.testOptions.evalb) {
        EvalbFormatWriter.initEVALBfiles(op.tlpParams);
    }
    // Use a reflection ruse, so one can run this without needing the
    // tagger.  Using a function rather than a MaxentTagger means we
    // can distribute a version of the parser that doesn't include the
    // entire tagger.
    Function<List<? extends HasWord>, ArrayList<TaggedWord>> tagger = null;
    if (op.testOptions.preTag) {
        try {
            Class<?>[] argsClass = { String.class };
            Object[] arguments = new Object[] { op.testOptions.taggerSerializedFile };
            tagger = (Function<List<? extends HasWord>, ArrayList<TaggedWord>>) Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger").getConstructor(argsClass).newInstance(arguments);
        } catch (Exception e) {
            // Deliberately best-effort: without a tagger the gold tags are used when forcing tags.
            log.info(e);
            log.info("Warning: No pretagging of sentences will be done.");
        }
    }

    // ---- parse and score every test tree ----
    for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) {
        Tree tree = testTreebank.get(tNum);
        int testTreeLen = tree.yield().size();
        if (testTreeLen > op.testOptions.maxLength) {
            continue;
        }
        Tree binaryTree = binaryTestTrees.get(tNum);
        System.out.println("-------------------------------------");
        System.out.println("Number: " + (tNum + 1));
        System.out.println("Length: " + testTreeLen);
        long timeMil1 = System.currentTimeMillis();
        Timing.tick("Starting parse.");
        if (op.doPCFG) {
            if (op.testOptions.forceTags) {
                if (tagger != null) {
                    // Use the tagger to set tags.
                    parser.parse(addLast(tagger.apply(cutLast(wordify(binaryTree.yield())))));
                } else {
                    // Force tags to match the input.
                    parser.parse(cleanTags(binaryTree.taggedYield(), tlp));
                }
            } else {
                parser.parse(binaryTree.yieldHasWord());
            }
        }
        if (op.doDep) {
            dparser.parse(binaryTree.yieldHasWord());
        }
        boolean bothPassed = false;
        if (op.doPCFG && op.doDep) {
            bothPassed = bparser.parse(binaryTree.yieldHasWord());
        }
        long timeMil2 = System.currentTimeMillis();
        long elapsed = timeMil2 - timeMil1;
        // Truncates to tenths of a second.
        log.info("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec.");

        // Best PCFG parse.
        Tree tree2b = null;
        Tree tree2 = null;
        if (op.doPCFG) {
            tree2b = parser.getBestParse();
            tree2 = debinarizer.transformTree(tree2b);
        }

        // Best dependency parse.
        Tree tree3 = null;
        Tree tree3db = null;
        if (op.doDep) {
            tree3 = dparser.getBestParse();
            // was: but wrong Tree tree3db = debinarizer.transformTree(tree2);
            tree3db = debinarizer.transformTree(tree3);
            tree3.pennPrint(pw);
        }

        // Best combined parse, falling back to the PCFG parse when the combination fails.
        Tree tree4 = null;
        if (op.doPCFG && op.doDep) {
            try {
                tree4 = bparser.getBestParse();
                if (tree4 == null) {
                    tree4 = tree2b;
                }
            } catch (NullPointerException e) {
                // NOTE(review): exception used as control flow; kept because the combining
                // parser's failure mode is not visible from here.
                log.info("Blocked, using PCFG parse!");
                tree4 = tree2b;
            }
        }
        if (op.doPCFG && !bothPassed) {
            tree4 = tree2b;
        }

        if (op.doDep) {
            depDE.evaluate(tree3, binaryTree, pw);
            depTE.evaluate(tree3db, tree, pw);
        }
        TreeTransformer tc = op.tlpParams.collinizer();
        TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb();
        if (op.doPCFG) {
            pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
            pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
            Tree tree4b = null;
            if (op.doDep) {
                comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw);
                // Keep the binarized combo parse for scoring before debinarizing tree4 in place.
                tree4b = tree4;
                tree4 = debinarizer.transformTree(tree4);
                if (op.nodePrune) {
                    NodePruner np = new NodePruner(parser, debinarizer);
                    tree4 = np.prune(tree4);
                }
                comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
            }
            pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw);
            pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
            if (op.doDep) {
                comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw);
                comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
            }
            System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0));
            tree2.pennPrint(pw);
            if (op.doDep) {
                System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0));
                tree4.pennPrint(pw);
            }
            System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0));
            tree.pennPrint(pw);
        }
        if (op.testOptions.evalb) {
            // Write the gold tree next to the best available guess: combo, else PCFG, else dependency.
            if (op.doPCFG && op.doDep) {
                EvalbFormatWriter.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4));
            } else if (op.doPCFG) {
                EvalbFormatWriter.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2));
            } else if (op.doDep) {
                EvalbFormatWriter.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db));
            }
        }
    }
    if (op.testOptions.evalb) {
        EvalbFormatWriter.closeEVALBfiles();
    }

    // ---- summary ----
    if (op.doPCFG) {
        pcfgPE.display(false, pw);
        System.out.println("Grammar size: " + stateIndex.size());
        pcfgCB.display(false, pw);
        if (op.doDep) {
            comboPE.display(false, pw);
        }
        pcfgTE.display(false, pw);
        pcfgTEnoPunct.display(false, pw);
        if (op.doDep) {
            comboTE.display(false, pw);
            comboTEnoPunct.display(false, pw);
        }
    }
    if (op.doDep) {
        depTE.display(false, pw);
        depDE.display(false, pw);
    }
    if (op.doPCFG && op.doDep) {
        comboDE.display(false, pw);
    }
}
Also used : Treebank(edu.stanford.nlp.trees.Treebank) MemoryTreebank(edu.stanford.nlp.trees.MemoryTreebank) ArrayList(java.util.ArrayList) Tree(edu.stanford.nlp.trees.Tree) TreebankLanguagePack(edu.stanford.nlp.trees.TreebankLanguagePack) ArrayList(java.util.ArrayList) List(java.util.List) TaggingEval(edu.stanford.nlp.parser.metrics.TaggingEval) NumberRangeFileFilter(edu.stanford.nlp.io.NumberRangeFileFilter) Evalb(edu.stanford.nlp.parser.metrics.Evalb) TreeTransformer(edu.stanford.nlp.trees.TreeTransformer) UnlabeledAttachmentEval(edu.stanford.nlp.parser.metrics.UnlabeledAttachmentEval) MemoryTreebank(edu.stanford.nlp.trees.MemoryTreebank) PrintWriter(java.io.PrintWriter) Pair(edu.stanford.nlp.util.Pair) HasWord(edu.stanford.nlp.ling.HasWord) AbstractEval(edu.stanford.nlp.parser.metrics.AbstractEval) LeftHeadFinder(edu.stanford.nlp.trees.LeftHeadFinder) TreeLengthComparator(edu.stanford.nlp.trees.TreeLengthComparator) HashIndex(edu.stanford.nlp.util.HashIndex)

Aggregations

NumberRangeFileFilter (edu.stanford.nlp.io.NumberRangeFileFilter)1 HasWord (edu.stanford.nlp.ling.HasWord)1 AbstractEval (edu.stanford.nlp.parser.metrics.AbstractEval)1 Evalb (edu.stanford.nlp.parser.metrics.Evalb)1 TaggingEval (edu.stanford.nlp.parser.metrics.TaggingEval)1 UnlabeledAttachmentEval (edu.stanford.nlp.parser.metrics.UnlabeledAttachmentEval)1 LeftHeadFinder (edu.stanford.nlp.trees.LeftHeadFinder)1 MemoryTreebank (edu.stanford.nlp.trees.MemoryTreebank)1 Tree (edu.stanford.nlp.trees.Tree)1 TreeLengthComparator (edu.stanford.nlp.trees.TreeLengthComparator)1 TreeTransformer (edu.stanford.nlp.trees.TreeTransformer)1 Treebank (edu.stanford.nlp.trees.Treebank)1 TreebankLanguagePack (edu.stanford.nlp.trees.TreebankLanguagePack)1 HashIndex (edu.stanford.nlp.util.HashIndex)1 Pair (edu.stanford.nlp.util.Pair)1 PrintWriter (java.io.PrintWriter)1 ArrayList (java.util.ArrayList)1 List (java.util.List)1