Search in sources:

Example 1 with WordCatEqualityChecker

Use of edu.stanford.nlp.trees.WordCatEqualityChecker in the project CoreNLP by stanfordnlp.

Taken from the main method of the class ChineseCharacterBasedLexiconTraining. A sketch of a typical invocation follows; the fully qualified class name, paths, and file ranges are illustrative placeholders only (the class is assumed to live in the edu.stanford.nlp.parser.lexparser package), and the flags match the arities declared in flagsToNumArgs below:
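java edu.stanford.nlp.parser.lexparser.ChineseCharacterBasedLexiconTraining -parser ctbTrainDir 1-270 chineseParser.ser.gz -test ctbTestDir 271-300 -out output.txt -eval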

public static void main(String[] args) throws IOException {
    Map<String, Integer> flagsToNumArgs = Generics.newHashMap();
    flagsToNumArgs.put("-parser", Integer.valueOf(3));
    flagsToNumArgs.put("-lex", Integer.valueOf(3));
    flagsToNumArgs.put("-test", Integer.valueOf(2));
    flagsToNumArgs.put("-out", Integer.valueOf(1));
    flagsToNumArgs.put("-lengthPenalty", Integer.valueOf(1));
    flagsToNumArgs.put("-penaltyType", Integer.valueOf(1));
    flagsToNumArgs.put("-maxLength", Integer.valueOf(1));
    flagsToNumArgs.put("-stats", Integer.valueOf(2));
    Map<String, String[]> argMap = StringUtils.argsToMap(args, flagsToNumArgs);
    boolean eval = argMap.containsKey("-eval");
    PrintWriter pw = null;
    if (argMap.containsKey("-out")) {
        pw = new PrintWriter(new OutputStreamWriter(new FileOutputStream((argMap.get("-out"))[0]), "GB18030"), true);
    }
    log.info("ChineseCharacterBasedLexicon called with args:");
    ChineseTreebankParserParams ctpp = new ChineseTreebankParserParams();
    for (int i = 0; i < args.length; i++) {
        ctpp.setOptionFlag(args, i);
        log.info(" " + args[i]);
    }
    log.info();
    Options op = new Options(ctpp);
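    // -stats: read a training treebank range, optionally annotate it (-annotate), print corpus statistics, and exit.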
    if (argMap.containsKey("-stats")) {
        String[] statArgs = (argMap.get("-stats"));
        MemoryTreebank rawTrainTreebank = op.tlpParams.memoryTreebank();
        FileFilter trainFilt = new NumberRangesFileFilter(statArgs[1], false);
        rawTrainTreebank.loadPath(new File(statArgs[0]), trainFilt);
        log.info("Done reading trees.");
        MemoryTreebank trainTreebank;
        if (argMap.containsKey("-annotate")) {
            trainTreebank = new MemoryTreebank();
            TreeAnnotator annotator = new TreeAnnotator(ctpp.headFinder(), ctpp, op);
            for (Tree tree : rawTrainTreebank) {
                trainTreebank.add(annotator.transformTree(tree));
            }
            log.info("Done annotating trees.");
        } else {
            trainTreebank = rawTrainTreebank;
        }
        printStats(trainTreebank, pw);
        System.exit(0);
    }
    int maxLength = 1000000;
    //    Test.verbose = true;
    if (argMap.containsKey("-norm")) {
        op.testOptions.lengthNormalization = true;
    }
    if (argMap.containsKey("-maxLength")) {
        maxLength = Integer.parseInt((argMap.get("-maxLength"))[0]);
    }
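    // maxLength (from -maxLength) caps the sentences evaluated in the -test loop below; testOptions.maxLength itself is set to a fixed 120 regardless.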
    op.testOptions.maxLength = 120;
    boolean combo = argMap.containsKey("-combo");
    if (combo) {
        ctpp.useCharacterBasedLexicon = true;
        op.testOptions.maxSpanForTags = 10;
        op.doDep = false;
        op.dcTags = false;
    }
    LexicalizedParser lp = null;
    Lexicon lex = null;
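    // -parser: given a treebank path and a NumberRangesFileFilter range, train a LexicalizedParser (and serialize it if a third argument names an output file); given a single argument, load an already serialized model.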
    if (argMap.containsKey("-parser")) {
        String[] parserArgs = (argMap.get("-parser"));
        if (parserArgs.length > 1) {
            FileFilter trainFilt = new NumberRangesFileFilter(parserArgs[1], false);
            lp = LexicalizedParser.trainFromTreebank(parserArgs[0], trainFilt, op);
            if (parserArgs.length == 3) {
                String filename = parserArgs[2];
                log.info("Writing parser in serialized format to file " + filename + " ");
                System.err.flush();
                ObjectOutputStream out = IOUtils.writeStreamFromString(filename);
                out.writeObject(lp);
                out.close();
                log.info("done.");
            }
        } else {
            String parserFile = parserArgs[0];
            lp = LexicalizedParser.loadModel(parserFile, op);
        }
        lex = lp.getLexicon();
        op = lp.getOp();
        ctpp = (ChineseTreebankParserParams) op.tlpParams;
    }
    if (argMap.containsKey("-rad")) {
        ctpp.useUnknownCharacterModel = true;
    }
    if (argMap.containsKey("-lengthPenalty")) {
        ctpp.lengthPenalty = Double.parseDouble((argMap.get("-lengthPenalty"))[0]);
    }
    if (argMap.containsKey("-penaltyType")) {
        ctpp.penaltyType = Integer.parseInt((argMap.get("-penaltyType"))[0]);
    }
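    // -lex: given a treebank path and range, train the lexicon on (optionally annotated) trees and serialize it if a third argument is supplied; otherwise read a serialized lexicon from disk.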
    if (argMap.containsKey("-lex")) {
        String[] lexArgs = (argMap.get("-lex"));
        if (lexArgs.length > 1) {
            Index<String> wordIndex = new HashIndex<>();
            Index<String> tagIndex = new HashIndex<>();
            lex = ctpp.lex(op, wordIndex, tagIndex);
            MemoryTreebank rawTrainTreebank = op.tlpParams.memoryTreebank();
            FileFilter trainFilt = new NumberRangesFileFilter(lexArgs[1], false);
            rawTrainTreebank.loadPath(new File(lexArgs[0]), trainFilt);
            log.info("Done reading trees.");
            MemoryTreebank trainTreebank;
            if (argMap.containsKey("-annotate")) {
                trainTreebank = new MemoryTreebank();
                TreeAnnotator annotator = new TreeAnnotator(ctpp.headFinder(), ctpp, op);
                for (Tree tree : rawTrainTreebank) {
                    tree = annotator.transformTree(tree);
                    trainTreebank.add(tree);
                }
                log.info("Done annotating trees.");
            } else {
                trainTreebank = rawTrainTreebank;
            }
            lex.initializeTraining(trainTreebank.size());
            lex.train(trainTreebank);
            lex.finishTraining();
            log.info("Done training lexicon.");
            if (lexArgs.length == 3) {
                String filename = lexArgs.length == 3 ? lexArgs[2] : "parsers/chineseCharLex.ser.gz";
                log.info("Writing lexicon in serialized format to file " + filename + " ");
                System.err.flush();
                ObjectOutputStream out = IOUtils.writeStreamFromString(filename);
                out.writeObject(lex);
                out.close();
                log.info("done.");
            }
        } else {
            String lexFile = lexArgs.length == 1 ? lexArgs[0] : "parsers/chineseCharLex.ser.gz";
            log.info("Reading Lexicon from file " + lexFile);
            ObjectInputStream in = IOUtils.readStreamFromString(lexFile);
            try {
                lex = (Lexicon) in.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException("Bad serialized file: " + lexFile, e);
            }
            in.close();
        }
    }
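    // -test: segment and/or parse each gold sentence, optionally print the output (-out), and score word/tag/category brackets when -eval is given.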
    if (argMap.containsKey("-test")) {
        boolean segmentWords = ctpp.segment;
        boolean parse = lp != null;
        assert (parse || segmentWords);
        //      WordCatConstituent.collinizeWords = argMap.containsKey("-collinizeWords");
        //      WordCatConstituent.collinizeTags = argMap.containsKey("-collinizeTags");
        WordSegmenter seg = null;
        if (segmentWords) {
            seg = (WordSegmenter) lex;
        }
        String[] testArgs = (argMap.get("-test"));
        MemoryTreebank testTreebank = op.tlpParams.memoryTreebank();
        FileFilter testFilt = new NumberRangesFileFilter(testArgs[1], false);
        testTreebank.loadPath(new File(testArgs[0]), testFilt);
        TreeTransformer subcategoryStripper = op.tlpParams.subcategoryStripper();
        TreeTransformer collinizer = ctpp.collinizer();
        WordCatEquivalenceClasser eqclass = new WordCatEquivalenceClasser();
        WordCatEqualityChecker eqcheck = new WordCatEqualityChecker();
        EquivalenceClassEval basicEval = new EquivalenceClassEval(eqclass, eqcheck, "basic");
        EquivalenceClassEval collinsEval = new EquivalenceClassEval(eqclass, eqcheck, "collinized");
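        // Choose which bracket types (word, tag, category) to evaluate; goodPOS records whether predicted POS tags are comparable enough to score word+tag brackets against the gold tags.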
        List<String> evalTypes = new ArrayList<>(3);
        boolean goodPOS = false;
        if (segmentWords) {
            evalTypes.add(WordCatConstituent.wordType);
            if (ctpp.segmentMarkov && !parse) {
                evalTypes.add(WordCatConstituent.tagType);
                goodPOS = true;
            }
        }
        if (parse) {
            evalTypes.add(WordCatConstituent.tagType);
            evalTypes.add(WordCatConstituent.catType);
            if (combo) {
                evalTypes.add(WordCatConstituent.wordType);
                goodPOS = true;
            }
        }
        TreeToBracketProcessor proc = new TreeToBracketProcessor(evalTypes);
        log.info("Testing...");
        for (Tree goldTop : testTreebank) {
            Tree gold = goldTop.firstChild();
            List<HasWord> goldSentence = gold.yieldHasWord();
            if (goldSentence.size() > maxLength) {
                log.info("Skipping sentence; too long: " + goldSentence.size());
                continue;
            } else {
                log.info("Processing sentence; length: " + goldSentence.size());
            }
            List<HasWord> s;
            if (segmentWords) {
                StringBuilder goldCharBuf = new StringBuilder();
                for (HasWord aGoldSentence : goldSentence) {
                    StringLabel word = (StringLabel) aGoldSentence;
                    goldCharBuf.append(word.value());
                }
                String goldChars = goldCharBuf.toString();
                s = seg.segment(goldChars);
            } else {
                s = goldSentence;
            }
            Tree tree;
            if (parse) {
                tree = lp.parseTree(s);
                if (tree == null) {
                    throw new RuntimeException("PARSER RETURNED NULL!!!");
                }
            } else {
                tree = Trees.toFlatTree(s);
                tree = subcategoryStripper.transformTree(tree);
            }
            if (pw != null) {
                if (parse) {
                    tree.pennPrint(pw);
                } else {
                    Iterator sentIter = s.iterator();
                    for (; ; ) {
                        Word word = (Word) sentIter.next();
                        pw.print(word.word());
                        if (sentIter.hasNext()) {
                            pw.print(" ");
                        } else {
                            break;
                        }
                    }
                }
                pw.println();
            }
            if (eval) {
                Collection ourBrackets, goldBrackets;
                ourBrackets = proc.allBrackets(tree);
                goldBrackets = proc.allBrackets(gold);
                if (goodPOS) {
                    ourBrackets.addAll(proc.commonWordTagTypeBrackets(tree, gold));
                    goldBrackets.addAll(proc.commonWordTagTypeBrackets(gold, tree));
                }
                basicEval.eval(ourBrackets, goldBrackets);
                System.out.println("\nScores:");
                basicEval.displayLast();
                Tree collinsTree = collinizer.transformTree(tree);
                Tree collinsGold = collinizer.transformTree(gold);
                ourBrackets = proc.allBrackets(collinsTree);
                goldBrackets = proc.allBrackets(collinsGold);
                if (goodPOS) {
                    ourBrackets.addAll(proc.commonWordTagTypeBrackets(collinsTree, collinsGold));
                    goldBrackets.addAll(proc.commonWordTagTypeBrackets(collinsGold, collinsTree));
                }
                collinsEval.eval(ourBrackets, goldBrackets);
                System.out.println("\nCollinized scores:");
                collinsEval.displayLast();
                System.out.println();
            }
        }
        if (eval) {
            basicEval.display();
            System.out.println();
            collinsEval.display();
        }
    }
}
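
The main method above wires WordCatEqualityChecker into a full train/segment/parse/evaluate pipeline, which obscures the small evaluation pattern at its core. Below is a minimal sketch of just that pattern: WordCatEquivalenceClasser and WordCatEqualityChecker plugged into an EquivalenceClassEval, scoring brackets produced by TreeToBracketProcessor. The class name WordCatEvalSketch and the two tiny trees are made up for illustration, and EquivalenceClassEval is assumed to be edu.stanford.nlp.stats.EquivalenceClassEval; everything else follows the calls made in the source above.

import java.util.Arrays;
import java.util.Collection;
import java.util.List;

import edu.stanford.nlp.stats.EquivalenceClassEval;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeToBracketProcessor;
import edu.stanford.nlp.trees.WordCatConstituent;
import edu.stanford.nlp.trees.WordCatEqualityChecker;
import edu.stanford.nlp.trees.WordCatEquivalenceClasser;

public class WordCatEvalSketch { // hypothetical class name, not part of CoreNLP
    public static void main(String[] args) {
        // Evaluate word, tag, and category brackets, mirroring the evalTypes list in the example above.
        List<String> evalTypes = Arrays.asList(WordCatConstituent.wordType, WordCatConstituent.tagType, WordCatConstituent.catType);
        TreeToBracketProcessor proc = new TreeToBracketProcessor(evalTypes);
        // WordCatEquivalenceClasser groups brackets into equivalence classes; WordCatEqualityChecker decides when two brackets match.
        EquivalenceClassEval eval = new EquivalenceClassEval(new WordCatEquivalenceClasser(), new WordCatEqualityChecker(), "sketch");
        // Tiny made-up trees standing in for a parser's output and the corresponding gold tree.
        Tree guess = Tree.valueOf("(ROOT (NP (NN example)))");
        Tree gold = Tree.valueOf("(ROOT (NP (NN example)))");
        Collection ourBrackets = proc.allBrackets(guess);
        Collection goldBrackets = proc.allBrackets(gold);
        eval.eval(ourBrackets, goldBrackets);
        // Print the score for this pair, then the running totals.
        eval.displayLast();
        eval.display();
    }
}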
Also used: HasWord(edu.stanford.nlp.ling.HasWord) TaggedWord(edu.stanford.nlp.ling.TaggedWord) Word(edu.stanford.nlp.ling.Word) NumberRangesFileFilter(edu.stanford.nlp.io.NumberRangesFileFilter) ArrayList(java.util.ArrayList) ObjectOutputStream(java.io.ObjectOutputStream) StringLabel(edu.stanford.nlp.ling.StringLabel) TreeToBracketProcessor(edu.stanford.nlp.trees.TreeToBracketProcessor) WordSegmenter(edu.stanford.nlp.process.WordSegmenter) Iterator(java.util.Iterator) Tree(edu.stanford.nlp.trees.Tree) MemoryTreebank(edu.stanford.nlp.trees.MemoryTreebank) FileFilter(java.io.FileFilter) PrintWriter(java.io.PrintWriter) WordCatEqualityChecker(edu.stanford.nlp.trees.WordCatEqualityChecker) HashIndex(edu.stanford.nlp.util.HashIndex) WordCatEquivalenceClasser(edu.stanford.nlp.trees.WordCatEquivalenceClasser) FileOutputStream(java.io.FileOutputStream) Collection(java.util.Collection) OutputStreamWriter(java.io.OutputStreamWriter) File(java.io.File) TreeTransformer(edu.stanford.nlp.trees.TreeTransformer) ObjectInputStream(java.io.ObjectInputStream)

Aggregations

NumberRangesFileFilter (edu.stanford.nlp.io.NumberRangesFileFilter) 1
HasWord (edu.stanford.nlp.ling.HasWord) 1
StringLabel (edu.stanford.nlp.ling.StringLabel) 1
TaggedWord (edu.stanford.nlp.ling.TaggedWord) 1
Word (edu.stanford.nlp.ling.Word) 1
WordSegmenter (edu.stanford.nlp.process.WordSegmenter) 1
MemoryTreebank (edu.stanford.nlp.trees.MemoryTreebank) 1
Tree (edu.stanford.nlp.trees.Tree) 1
TreeToBracketProcessor (edu.stanford.nlp.trees.TreeToBracketProcessor) 1
TreeTransformer (edu.stanford.nlp.trees.TreeTransformer) 1
WordCatEqualityChecker (edu.stanford.nlp.trees.WordCatEqualityChecker) 1
WordCatEquivalenceClasser (edu.stanford.nlp.trees.WordCatEquivalenceClasser) 1
HashIndex (edu.stanford.nlp.util.HashIndex) 1
File (java.io.File) 1
FileFilter (java.io.FileFilter) 1
FileOutputStream (java.io.FileOutputStream) 1
ObjectInputStream (java.io.ObjectInputStream) 1
ObjectOutputStream (java.io.ObjectOutputStream) 1
OutputStreamWriter (java.io.OutputStreamWriter) 1
PrintWriter (java.io.PrintWriter) 1