Example 6 with Treebank

Use of edu.stanford.nlp.trees.Treebank in project CoreNLP by stanfordnlp.

From the class CrossValidateTestOptions, method main: evaluates a LexicalizedParser on a test treebank across a sweep of base-parser weights, reporting labeled-bracket and tagging scores for each weight.

public static void main(String[] args) throws IOException, ClassNotFoundException {
    String dvmodelFile = null;
    String lexparserFile = null;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = new ArrayList<>();
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-lexparser")) {
            lexparserFile = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    log.info("Loading lexparser from: " + lexparserFile);
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser lexparser = LexicalizedParser.loadModel(lexparserFile, newArgs);
    log.info("... done");
    Treebank testTreebank = null;
    if (testTreebankPath != null) {
        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");
    }
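    // 'weights' is a static double[] field of CrossValidateTestOptions (the
    // candidate baseParserWeight values to sweep); its definition is not shown
    // in this excerpt.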
    double[] labelResults = new double[weights.length];
    double[] tagResults = new double[weights.length];
    for (int i = 0; i < weights.length; ++i) {
        lexparser.getOp().baseParserWeight = weights[i];
        EvaluateTreebank evaluator = new EvaluateTreebank(lexparser);
        evaluator.testOnTreebank(testTreebank);
        labelResults[i] = evaluator.getLBScore();
        tagResults[i] = evaluator.getTagScore();
    }
    for (int i = 0; i < weights.length; ++i) {
        log.info("LexicalizedParser weight " + weights[i] + ": labeled " + labelResults[i] + " tag " + tagResults[i]);
    }
}
Also used: EvaluateTreebank(edu.stanford.nlp.parser.lexparser.EvaluateTreebank) Treebank(edu.stanford.nlp.trees.Treebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) ArrayList(java.util.ArrayList) FileFilter(java.io.FileFilter) Pair(edu.stanford.nlp.util.Pair)
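
All of these examples share the same loading idiom: ask the language pack for a Treebank implementation via tlpParams.memoryTreebank(), then loadPath it, optionally restricted by a FileFilter. A minimal self-contained sketch of that idiom plus evaluation; the class name and both paths are placeholders, not taken from the examples:

import edu.stanford.nlp.io.NumberRangeFileFilter;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.trees.Treebank;

public class TreebankEvalSketch {
    public static void main(String[] args) {
        // Load a serialized parser model (placeholder path).
        LexicalizedParser parser = LexicalizedParser.loadModel("/path/to/englishPCFG.ser.gz");
        // The language pack supplies a treebank matched to the parser's language.
        Treebank treebank = parser.getOp().tlpParams.memoryTreebank();
        // Restrict loading to files numbered 2200-2219, as the WSJ examples do.
        treebank.loadPath("/path/to/treebank", new NumberRangeFileFilter(2200, 2219, true));
        // Score the parser against the gold trees.
        EvaluateTreebank evaluator = new EvaluateTreebank(parser);
        evaluator.testOnTreebank(treebank);
    }
}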

Example 7 with Treebank

Use of edu.stanford.nlp.trees.Treebank in project CoreNLP by stanfordnlp.

From the class FindNearestNeighbors, method main: parses a test treebank with a DVModel-reranked LexicalizedParser, records the node vectors of each parse, and finds the closest subtree pairs by vector distance.

public static void main(String[] args) throws Exception {
    String modelPath = null;
    String outputPath = null;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = new ArrayList<>();
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-output")) {
            outputPath = args[argIndex + 1];
            argIndex += 2;
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    if (modelPath == null) {
        throw new IllegalArgumentException("Need to specify -model");
    }
    if (testTreebankPath == null) {
        throw new IllegalArgumentException("Need to specify -testTreebank");
    }
    if (outputPath == null) {
        throw new IllegalArgumentException("Need to specify -output");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
    Treebank testTreebank = null;
    if (testTreebankPath != null) {
        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");
    }
    FileWriter out = new FileWriter(outputPath);
    BufferedWriter bout = new BufferedWriter(out);
    log.info("Parsing " + testTreebank.size() + " trees");
    int count = 0;
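    // ParseRecord is a nested helper class of FindNearestNeighbors holding the
    // sentence, gold tree, best parse, root vector, and per-node vectors (see
    // the constructor call below); its definition is not shown in this excerpt.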
    List<ParseRecord> records = Generics.newArrayList();
    for (Tree goldTree : testTreebank) {
        List<Word> tokens = goldTree.yieldWords();
        ParserQuery parserQuery = lexparser.parserQuery();
        if (!parserQuery.parse(tokens)) {
            throw new AssertionError("Could not parse: " + tokens);
        }
        if (!(parserQuery instanceof RerankingParserQuery)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
        }
        RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
        if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
        }
        DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
        SimpleMatrix rootVector = null;
        for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
            if (entry.getKey().label().value().equals("ROOT")) {
                rootVector = entry.getValue();
                break;
            }
        }
        if (rootVector == null) {
            throw new AssertionError("Could not find root nodevector");
        }
        out.write(tokens + "\n");
        out.write(tree.getTree() + "\n");
        for (int i = 0; i < rootVector.getNumElements(); ++i) {
            out.write("  " + rootVector.get(i));
        }
        out.write("\n\n\n");
        count++;
        if (count % 10 == 0) {
            log.info("  " + count);
        }
        records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
    }
    log.info("  done parsing");
    List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
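    // 'maxLength' is a static field of FindNearestNeighbors capping the number
    // of leaves a subtree may have to take part in matching; not shown in this
    // excerpt.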
    for (ParseRecord record : records) {
        for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
            if (entry.getKey().getLeaves().size() <= maxLength) {
                subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
            }
        }
    }
    log.info("There are " + subtrees.size() + " subtrees in the set of trees");
    PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
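    // The queue keeps at most the 100 closest pairs: under the descending
    // comparator its head is the largest distance, so poll() evicts the
    // current worst match.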
    for (int i = 0; i < subtrees.size(); ++i) {
        log.info(subtrees.get(i).first().yieldWords());
        log.info(subtrees.get(i).first());
        for (int j = 0; j < subtrees.size(); ++j) {
            if (i == j) {
                continue;
            }
            // TODO: look at basic category?
            double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
            bestmatches.add(new ScoredObject<>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
            if (bestmatches.size() > 100) {
                bestmatches.poll();
            }
        }
        List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
        while (bestmatches.size() > 0) {
            ordered.add(bestmatches.poll());
        }
        Collections.reverse(ordered);
        for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
            log.info(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
        }
        log.info();
        log.info();
        bestmatches.clear();
    }
    /*
    for (int i = 0; i < records.size(); ++i) {
      if (i % 10 == 0) {
        log.info("  " + i);
      }
      List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
      for (int j = 0; j < records.size(); ++j) {
        if (i == j) continue;

        double score = 0.0;
        int matches = 0;
        for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
          for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
            String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
            String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
            if (firstBasic.equals(secondBasic)) {
              ++matches;
              double normF = first.getValue().minus(second.getValue()).normF();
              score += normF * normF;
            }
          }
        }
        if (matches == 0) {
          score = Double.POSITIVE_INFINITY;
        } else {
          score = score / matches;
        }
        //double score = records.get(i).vector.minus(records.get(j).vector).normF();
        scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
      }
      Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);

      out.write(records.get(i).sentence.toString() + "\n");
      for (int j = 0; j < numNeighbors; ++j) {
        out.write("   " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
      }
      out.write("\n\n");
    }
    log.info();
    */
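    // Note: the writes above go directly to 'out'; 'bout' wraps it but is
    // never written to, so flushing both and closing 'out' suffices here.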
    bout.flush();
    out.flush();
    out.close();
}
Also used: Word(edu.stanford.nlp.ling.Word) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery) Treebank(edu.stanford.nlp.trees.Treebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) BufferedWriter(java.io.BufferedWriter) SimpleMatrix(org.ejml.simple.SimpleMatrix) ScoredObject(edu.stanford.nlp.util.ScoredObject) DeepTree(edu.stanford.nlp.trees.DeepTree) Tree(edu.stanford.nlp.trees.Tree) FileFilter(java.io.FileFilter) Pair(edu.stanford.nlp.util.Pair) PriorityQueue(java.util.PriorityQueue) IdentityHashMap(java.util.IdentityHashMap) Map(java.util.Map)
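
The matching score in this example is the Frobenius norm of the difference between two node vectors, i.e. Euclidean distance in the embedding space. A tiny sketch of that primitive with EJML; the class name and vector values are made up:

import org.ejml.simple.SimpleMatrix;

public class VectorDistanceSketch {
    public static void main(String[] args) {
        // Two toy column vectors standing in for node vectors.
        SimpleMatrix a = new SimpleMatrix(new double[][] { { 0.1 }, { 0.7 }, { -0.2 } });
        SimpleMatrix b = new SimpleMatrix(new double[][] { { 0.0 }, { 0.5 }, { 0.1 } });
        // normF() of the difference = sqrt of the sum of squared differences.
        double distance = a.minus(b).normF();
        System.out.println("distance = " + distance);
    }
}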

Example 8 with Treebank

Use of edu.stanford.nlp.trees.Treebank in project CoreNLP by stanfordnlp.

From the class TreebankFactoredLexiconStats, method main: loads an Arabic or French treebank and prints lexicon statistics over words, lemmas, POS tags, and rich/reduced morphological tags.

//  private static String stripTag(String tag) {
//    if (tag.startsWith("DT")) {
//      String newTag = tag.substring(2, tag.length());
//      return newTag.length() > 0 ? newTag : tag;
//    }
//    return tag;
//  }
/**
   * @param args
   */
public static void main(String[] args) {
    if (args.length != 3) {
        System.err.printf("Usage: java %s language filename features%n", TreebankFactoredLexiconStats.class.getName());
        System.exit(-1);
    }
    Language language = Language.valueOf(args[0]);
    TreebankLangParserParams tlpp = language.params;
    if (language.equals(Language.Arabic)) {
        String[] options = { "-arabicFactored" };
        tlpp.setOptionFlag(options, 0);
    } else {
        String[] options = { "-frenchFactored" };
        tlpp.setOptionFlag(options, 0);
    }
    Treebank tb = tlpp.diskTreebank();
    tb.loadPath(args[1]);
    MorphoFeatureSpecification morphoSpec = language.equals(Language.Arabic) ? new ArabicMorphoFeatureSpecification() : new FrenchMorphoFeatureSpecification();
    String[] features = args[2].trim().split(",");
    for (String feature : features) {
        morphoSpec.activate(MorphoFeatureType.valueOf(feature));
    }
    // Counters
    Counter<String> wordTagCounter = new ClassicCounter<>(30000);
    Counter<String> morphTagCounter = new ClassicCounter<>(500);
    //    Counter<String> signatureTagCounter = new ClassicCounter<String>();
    Counter<String> morphCounter = new ClassicCounter<>(500);
    Counter<String> wordCounter = new ClassicCounter<>(30000);
    Counter<String> tagCounter = new ClassicCounter<>(300);
    Counter<String> lemmaCounter = new ClassicCounter<>(25000);
    Counter<String> lemmaTagCounter = new ClassicCounter<>(25000);
    Counter<String> richTagCounter = new ClassicCounter<>(1000);
    Counter<String> reducedTagCounter = new ClassicCounter<>(500);
    Counter<String> reducedTagLemmaCounter = new ClassicCounter<>(500);
    Map<String, Set<String>> wordLemmaMap = Generics.newHashMap();
    TwoDimensionalIntCounter<String, String> lemmaReducedTagCounter = new TwoDimensionalIntCounter<>(30000);
    TwoDimensionalIntCounter<String, String> reducedTagTagCounter = new TwoDimensionalIntCounter<>(500);
    TwoDimensionalIntCounter<String, String> tagReducedTagCounter = new TwoDimensionalIntCounter<>(300);
    int numTrees = 0;
    for (Tree tree : tb) {
        for (Tree subTree : tree) {
            if (!subTree.isLeaf()) {
                tlpp.transformTree(subTree, tree);
            }
        }
        List<Label> pretermList = tree.preTerminalYield();
        List<Label> yield = tree.yield();
        assert yield.size() == pretermList.size();
        int yieldLen = yield.size();
        for (int i = 0; i < yieldLen; ++i) {
            String tag = pretermList.get(i).value();
            String word = yield.get(i).value();
            String morph = ((CoreLabel) yield.get(i)).originalText();
            // Note: if there is no lemma, then we use the surface form.
            Pair<String, String> lemmaTag = MorphoFeatureSpecification.splitMorphString(word, morph);
            String lemma = lemmaTag.first();
            String richTag = lemmaTag.second();
            // WSGDEBUG
            if (tag.contains("MW"))
                lemma += "-MWE";
            lemmaCounter.incrementCount(lemma);
            lemmaTagCounter.incrementCount(lemma + tag);
            richTagCounter.incrementCount(richTag);
            String reducedTag = morphoSpec.strToFeatures(richTag).toString();
            reducedTagCounter.incrementCount(reducedTag);
            reducedTagLemmaCounter.incrementCount(reducedTag + lemma);
            wordTagCounter.incrementCount(word + tag);
            morphTagCounter.incrementCount(morph + tag);
            morphCounter.incrementCount(morph);
            wordCounter.incrementCount(word);
            tagCounter.incrementCount(tag);
            reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;
            if (wordLemmaMap.containsKey(word)) {
                wordLemmaMap.get(word).add(lemma);
            } else {
                Set<String> lemmas = Generics.newHashSet(1);
                // Record the lemma for the first occurrence as well; otherwise
                // every word seen only once would be reported below as having
                // no lemmas.
                lemmas.add(lemma);
                wordLemmaMap.put(word, lemmas);
            }
            lemmaReducedTagCounter.incrementCount(lemma, reducedTag);
            reducedTagTagCounter.incrementCount(lemma + reducedTag, tag);
            tagReducedTagCounter.incrementCount(tag, reducedTag);
        }
        ++numTrees;
    }
    // Barf...
    System.out.println("Language: " + language.toString());
    System.out.printf("#trees:\t%d%n", numTrees);
    System.out.printf("#tokens:\t%d%n", (int) wordCounter.totalCount());
    System.out.printf("#words:\t%d%n", wordCounter.keySet().size());
    System.out.printf("#tags:\t%d%n", tagCounter.keySet().size());
    System.out.printf("#wordTagPairs:\t%d%n", wordTagCounter.keySet().size());
    System.out.printf("#lemmas:\t%d%n", lemmaCounter.keySet().size());
    System.out.printf("#lemmaTagPairs:\t%d%n", lemmaTagCounter.keySet().size());
    System.out.printf("#feattags:\t%d%n", reducedTagCounter.keySet().size());
    System.out.printf("#feattag+lemmas:\t%d%n", reducedTagLemmaCounter.keySet().size());
    System.out.printf("#richtags:\t%d%n", richTagCounter.keySet().size());
    System.out.printf("#richtag+lemma:\t%d%n", morphCounter.keySet().size());
    System.out.printf("#richtag+lemmaTagPairs:\t%d%n", morphTagCounter.keySet().size());
    // Extra
    System.out.println("==================");
    StringBuilder sbNoLemma = new StringBuilder();
    StringBuilder sbMultLemmas = new StringBuilder();
    for (Map.Entry<String, Set<String>> wordLemmas : wordLemmaMap.entrySet()) {
        String word = wordLemmas.getKey();
        Set<String> lemmas = wordLemmas.getValue();
        if (lemmas.size() == 0) {
            sbNoLemma.append("NO LEMMAS FOR WORD: " + word + "\n");
            continue;
        }
        if (lemmas.size() > 1) {
            sbMultLemmas.append("MULTIPLE LEMMAS: " + word + " " + setToString(lemmas) + "\n");
            continue;
        }
        String lemma = lemmas.iterator().next();
        Set<String> reducedTags = lemmaReducedTagCounter.getCounter(lemma).keySet();
        if (reducedTags.size() > 1) {
            System.out.printf("%s --> %s%n", word, lemma);
            for (String reducedTag : reducedTags) {
                int count = lemmaReducedTagCounter.getCount(lemma, reducedTag);
                String posTags = setToString(reducedTagTagCounter.getCounter(lemma + reducedTag).keySet());
                System.out.printf("\t%s\t%d\t%s%n", reducedTag, count, posTags);
            }
            System.out.println();
        }
    }
    System.out.println("==================");
    System.out.println(sbNoLemma.toString());
    System.out.println(sbMultLemmas.toString());
    System.out.println("==================");
    List<String> tags = new ArrayList<>(tagReducedTagCounter.firstKeySet());
    Collections.sort(tags);
    for (String tag : tags) {
        System.out.println(tag);
        Set<String> reducedTags = tagReducedTagCounter.getCounter(tag).keySet();
        for (String reducedTag : reducedTags) {
            int count = tagReducedTagCounter.getCount(tag, reducedTag);
            //        reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;
            System.out.printf("\t%s\t%d%n", reducedTag, count);
        }
        System.out.println();
    }
    System.out.println("==================");
}
Also used: FrenchMorphoFeatureSpecification(edu.stanford.nlp.international.french.FrenchMorphoFeatureSpecification) Set(java.util.Set) Treebank(edu.stanford.nlp.trees.Treebank) CoreLabel(edu.stanford.nlp.ling.CoreLabel) Label(edu.stanford.nlp.ling.Label) ArrayList(java.util.ArrayList) TreebankLangParserParams(edu.stanford.nlp.parser.lexparser.TreebankLangParserParams) Language(edu.stanford.nlp.international.Language) Tree(edu.stanford.nlp.trees.Tree) ArabicMorphoFeatureSpecification(edu.stanford.nlp.international.arabic.ArabicMorphoFeatureSpecification) MorphoFeatureSpecification(edu.stanford.nlp.international.morph.MorphoFeatureSpecification) TwoDimensionalIntCounter(edu.stanford.nlp.stats.TwoDimensionalIntCounter) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) Map(java.util.Map)
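
The bookkeeping in this example leans on CoreNLP's counters: ClassicCounter for flat counts and TwoDimensionalIntCounter for conditional counts such as reduced tag given lemma. A small sketch of both primitives; the class name and keys are made up:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalIntCounter;

public class CounterSketch {
    public static void main(String[] args) {
        Counter<String> tagCounter = new ClassicCounter<>();
        tagCounter.incrementCount("NN");
        tagCounter.incrementCount("NN");
        tagCounter.incrementCount("VB");
        System.out.println(tagCounter.getCount("NN"));     // 2.0
        System.out.println(tagCounter.keySet().size());    // 2 distinct keys
        System.out.println((int) tagCounter.totalCount()); // 3 observations
        // Conditional counts: reduced tag given lemma, as in lemmaReducedTagCounter.
        TwoDimensionalIntCounter<String, String> byLemma = new TwoDimensionalIntCounter<>(100);
        byLemma.incrementCount("be", "V.3sg");
        byLemma.incrementCount("be", "V.past");
        System.out.println(byLemma.getCounter("be").keySet());
    }
}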

Example 9 with Treebank

Use of edu.stanford.nlp.trees.Treebank in project CoreNLP by stanfordnlp.

From the class FactoredParser, method main: trains PCFG, dependency, and combined (factored) parsers on one range of treebank files and evaluates all three on another.

/* some documentation for Roger's convenience
 * {pcfg,dep,combo}{PE,DE,TE} are precision/dep/tagging evals for the models

 * parser is the PCFG parser
 * dparser is the dependency parser
 * bparser is the combining parser

 * during testing:
 * tree is the test tree (gold tree)
 * binaryTree is the gold tree binarized
 * tree2b is the best PCFG parse, binarized
 * tree2 is the best PCFG parse (debinarized)
 * tree3 is the dependency parse, binarized
 * tree3db is the dependency parse, debinarized
 * tree4 is the best combo parse, binarized and then debinarized
 * tree4b is the best combo parse, binarized
 */
public static void main(String[] args) {
    Options op = new Options(new EnglishTreebankParserParams());
    // op.tlpParams may be changed to something else later, so don't use it till
    // after options are parsed.
    StringUtils.logInvocationString(log, args);
    String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj";
    int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219;
    String serializeFile = null;
    int i = 0;
    while (i < args.length && args[i].startsWith("-")) {
        if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) {
            path = args[i + 1];
            i += 2;
        } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) {
            trainLow = Integer.parseInt(args[i + 1]);
            trainHigh = Integer.parseInt(args[i + 2]);
            i += 3;
        } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) {
            testLow = Integer.parseInt(args[i + 1]);
            testHigh = Integer.parseInt(args[i + 2]);
            i += 3;
        } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) {
            serializeFile = args[i + 1];
            i += 2;
        } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) {
            try {
                op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance();
            } catch (ClassNotFoundException e) {
                log.info("Class not found: " + args[i + 1]);
                throw new RuntimeException(e);
            } catch (InstantiationException e) {
                log.info("Couldn't instantiate: " + args[i + 1] + ": " + e.toString());
                throw new RuntimeException(e);
            } catch (IllegalAccessException e) {
                log.info("illegal access" + e);
                throw new RuntimeException(e);
            }
            i += 2;
        } else if (args[i].equals("-encoding")) {
            // sets encoding for TreebankLangParserParams
            op.tlpParams.setInputEncoding(args[i + 1]);
            op.tlpParams.setOutputEncoding(args[i + 1]);
            i += 2;
        } else {
            i = op.setOptionOrWarn(args, i);
        }
    }
    // System.out.println(tlpParams.getClass());
    TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack();
    op.trainOptions.sisterSplitters = Generics.newHashSet(Arrays.asList(op.tlpParams.sisterSplitters()));
    //    BinarizerFactory.TreeAnnotator.setTreebankLang(tlpParams);
    PrintWriter pw = op.tlpParams.pw();
    op.testOptions.display();
    op.trainOptions.display();
    op.display();
    op.tlpParams.display();
    // setup tree transforms
    Treebank trainTreebank = op.tlpParams.memoryTreebank();
    MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank();
    // Treebank blippTreebank = ((EnglishTreebankParserParams) tlpParams).diskTreebank();
    // String blippPath = "/afs/ir.stanford.edu/data/linguistic-data/BLLIP-WSJ/";
    // blippTreebank.loadPath(blippPath, "", true);
    Timing.startTime();
    log.info("Reading trees...");
    testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
    if (op.testOptions.increasingLength) {
        Collections.sort(testTreebank, new TreeLengthComparator());
    }
    trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true));
    Timing.tick("done.");
    log.info("Binarizing trees...");
    TreeAnnotatorAndBinarizer binarizer;
    if (!op.trainOptions.leftToRight) {
        binarizer = new TreeAnnotatorAndBinarizer(op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op);
    } else {
        binarizer = new TreeAnnotatorAndBinarizer(op.tlpParams.headFinder(), new LeftHeadFinder(), op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op);
    }
    CollinsPuncTransformer collinsPuncTransformer = null;
    if (op.trainOptions.collinsPunc) {
        collinsPuncTransformer = new CollinsPuncTransformer(tlp);
    }
    TreeTransformer debinarizer = new Debinarizer(op.forceCNF);
    List<Tree> binaryTrainTrees = new ArrayList<>();
    if (op.trainOptions.selectiveSplit) {
        op.trainOptions.splitters = ParentAnnotationStats.getSplitCategories(trainTreebank, op.trainOptions.tagSelectiveSplit, 0, op.trainOptions.selectiveSplitCutOff, op.trainOptions.tagSelectiveSplitCutOff, op.tlpParams.treebankLanguagePack());
        if (op.trainOptions.deleteSplitters != null) {
            List<String> deleted = new ArrayList<>();
            for (String del : op.trainOptions.deleteSplitters) {
                String baseDel = tlp.basicCategory(del);
                boolean checkBasic = del.equals(baseDel);
                for (Iterator<String> it = op.trainOptions.splitters.iterator(); it.hasNext(); ) {
                    String elem = it.next();
                    String baseElem = tlp.basicCategory(elem);
                    boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del);
                    if (delStr) {
                        it.remove();
                        deleted.add(elem);
                    }
                }
            }
            log.info("Removed from vertical splitters: " + deleted);
        }
    }
    if (op.trainOptions.selectivePostSplit) {
        TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams, op);
        Treebank annotatedTB = trainTreebank.transform(myTransformer);
        op.trainOptions.postSplitters = ParentAnnotationStats.getSplitCategories(annotatedTB, true, 0, op.trainOptions.selectivePostSplitCutOff, op.trainOptions.tagSelectivePostSplitCutOff, op.tlpParams.treebankLanguagePack());
    }
    if (op.trainOptions.hSelSplit) {
        binarizer.setDoSelectiveSplit(false);
        for (Tree tree : trainTreebank) {
            if (op.trainOptions.collinsPunc) {
                tree = collinsPuncTransformer.transformTree(tree);
            }
            //tree.pennPrint(tlpParams.pw());
            tree = binarizer.transformTree(tree);
        //binaryTrainTrees.add(tree);
        }
        binarizer.setDoSelectiveSplit(true);
    }
    for (Tree tree : trainTreebank) {
        if (op.trainOptions.collinsPunc) {
            tree = collinsPuncTransformer.transformTree(tree);
        }
        tree = binarizer.transformTree(tree);
        binaryTrainTrees.add(tree);
    }
    if (op.testOptions.verbose) {
        binarizer.dumpStats();
    }
    List<Tree> binaryTestTrees = new ArrayList<>();
    for (Tree tree : testTreebank) {
        if (op.trainOptions.collinsPunc) {
            tree = collinsPuncTransformer.transformTree(tree);
        }
        tree = binarizer.transformTree(tree);
        binaryTestTrees.add(tree);
    }
    // binarization
    Timing.tick("done.");
    BinaryGrammar bg = null;
    UnaryGrammar ug = null;
    DependencyGrammar dg = null;
    // DependencyGrammar dgBLIPP = null;
    Lexicon lex = null;
    Index<String> stateIndex = new HashIndex<>();
    // extract grammars
    Extractor<Pair<UnaryGrammar, BinaryGrammar>> bgExtractor = new BinaryGrammarExtractor(op, stateIndex);
    if (op.doPCFG) {
        log.info("Extracting PCFG...");
        Pair<UnaryGrammar, BinaryGrammar> bgug = null;
        if (op.trainOptions.cheatPCFG) {
            List<Tree> allTrees = new ArrayList<>(binaryTrainTrees);
            allTrees.addAll(binaryTestTrees);
            bgug = bgExtractor.extract(allTrees);
        } else {
            bgug = bgExtractor.extract(binaryTrainTrees);
        }
        bg = bgug.second;
        bg.splitRules();
        ug = bgug.first;
        ug.purgeRules();
        Timing.tick("done.");
    }
    log.info("Extracting Lexicon...");
    Index<String> wordIndex = new HashIndex<>();
    Index<String> tagIndex = new HashIndex<>();
    lex = op.tlpParams.lex(op, wordIndex, tagIndex);
    lex.initializeTraining(binaryTrainTrees.size());
    lex.train(binaryTrainTrees);
    lex.finishTraining();
    Timing.tick("done.");
    if (op.doDep) {
        log.info("Extracting Dependencies...");
        binaryTrainTrees.clear();
        Extractor<DependencyGrammar> dgExtractor = new MLEDependencyGrammarExtractor(op, wordIndex, tagIndex);
        // dgBLIPP = (DependencyGrammar) dgExtractor.extract(new ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new TransformTreeDependency(tlpParams,true));
        // DependencyGrammar dg1 = dgExtractor.extract(trainTreebank.iterator(), new TransformTreeDependency(op.tlpParams, true));
        //dgBLIPP=(DependencyGrammar)dgExtractor.extract(blippTreebank.iterator(),new TransformTreeDependency(tlpParams));
        //dg = (DependencyGrammar) dgExtractor.extract(new ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new TransformTreeDependency(tlpParams));
        // dg=new DependencyGrammarCombination(dg1,dgBLIPP,2);
        //uses information whether the words are known or not, discards unknown words
        dg = dgExtractor.extract(binaryTrainTrees);
        Timing.tick("done.");
        //System.out.print("Extracting Unknown Word Model...");
        //UnknownWordModel uwm = (UnknownWordModel)uwmExtractor.extract(binaryTrainTrees);
        //Timing.tick("done.");
        System.out.print("Tuning Dependency Model...");
        dg.tune(binaryTestTrees);
        //System.out.println("TUNE DEPS: "+tuneDeps);
        Timing.tick("done.");
    }
    BinaryGrammar boundBG = bg;
    UnaryGrammar boundUG = ug;
    GrammarProjection gp = new NullGrammarProjection(bg, ug);
    // serialization
    if (serializeFile != null) {
        log.info("Serializing parser...");
        LexicalizedParser parser = new LexicalizedParser(lex, bg, ug, dg, stateIndex, wordIndex, tagIndex, op);
        parser.saveParserToSerialized(serializeFile);
        Timing.tick("done.");
    }
    // test: pcfg-parse and output
    ExhaustivePCFGParser parser = null;
    if (op.doPCFG) {
        parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op, stateIndex, wordIndex, tagIndex);
    }
    ExhaustiveDependencyParser dparser = ((op.doDep && !op.testOptions.useFastFactored) ? new ExhaustiveDependencyParser(dg, lex, op, wordIndex, tagIndex) : null);
    Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp, op), dparser) : null);
    //Scorer scorer = parser;
    BiLexPCFGParser bparser = null;
    if (op.doPCFG && op.doDep) {
        bparser = (op.testOptions.useN5) ? new BiLexPCFGParser.N5BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex) : new BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex);
    }
    Evalb pcfgPE = new Evalb("pcfg  PE", true);
    Evalb comboPE = new Evalb("combo PE", true);
    AbstractEval pcfgCB = new Evalb.CBEval("pcfg  CB", true);
    AbstractEval pcfgTE = new TaggingEval("pcfg  TE");
    AbstractEval comboTE = new TaggingEval("combo TE");
    AbstractEval pcfgTEnoPunct = new TaggingEval("pcfg nopunct TE");
    AbstractEval comboTEnoPunct = new TaggingEval("combo nopunct TE");
    AbstractEval depTE = new TaggingEval("depnd TE");
    AbstractEval depDE = new UnlabeledAttachmentEval("depnd DE", true, null, tlp.punctuationWordRejectFilter());
    AbstractEval comboDE = new UnlabeledAttachmentEval("combo DE", true, null, tlp.punctuationWordRejectFilter());
    if (op.testOptions.evalb) {
        EvalbFormatWriter.initEVALBfiles(op.tlpParams);
    }
    // int[] countByLength = new int[op.testOptions.maxLength+1];
    // Use a reflection ruse, so one can run this without needing the
    // tagger.  Using a function rather than a MaxentTagger means we
    // can distribute a version of the parser that doesn't include the
    // entire tagger.
    Function<List<? extends HasWord>, ArrayList<TaggedWord>> tagger = null;
    if (op.testOptions.preTag) {
        try {
            Class[] argsClass = { String.class };
            Object[] arguments = new Object[] { op.testOptions.taggerSerializedFile };
            tagger = (Function<List<? extends HasWord>, ArrayList<TaggedWord>>) Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger").getConstructor(argsClass).newInstance(arguments);
        } catch (Exception e) {
            log.info(e);
            log.info("Warning: No pretagging of sentences will be done.");
        }
    }
    for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) {
        Tree tree = testTreebank.get(tNum);
        int testTreeLen = tree.yield().size();
        if (testTreeLen > op.testOptions.maxLength) {
            continue;
        }
        Tree binaryTree = binaryTestTrees.get(tNum);
        // countByLength[testTreeLen]++;
        System.out.println("-------------------------------------");
        System.out.println("Number: " + (tNum + 1));
        System.out.println("Length: " + testTreeLen);
        //tree.pennPrint(pw);
        // System.out.println("XXXX The binary tree is");
        // binaryTree.pennPrint(pw);
        //System.out.println("Here are the tags in the lexicon:");
        //System.out.println(lex.showTags());
        //System.out.println("Here's the tagnumberer:");
        //System.out.println(Numberer.getGlobalNumberer("tags").toString());
        long timeMil1 = System.currentTimeMillis();
        Timing.tick("Starting parse.");
        if (op.doPCFG) {
            //log.info(op.testOptions.forceTags);
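            // wordify/cutLast/addLast and cleanTags are private helpers of
            // FactoredParser (not shown in this excerpt).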
            if (op.testOptions.forceTags) {
                if (tagger != null) {
                    //System.out.println("Using a tagger to set tags");
                    //System.out.println("Tagged sentence as: " + tagger.processSentence(cutLast(wordify(binaryTree.yield()))).toString(false));
                    parser.parse(addLast(tagger.apply(cutLast(wordify(binaryTree.yield())))));
                } else {
                    //System.out.println("Forcing tags to match input.");
                    parser.parse(cleanTags(binaryTree.taggedYield(), tlp));
                }
            } else {
                // System.out.println("XXXX Parsing " + binaryTree.yield());
                parser.parse(binaryTree.yieldHasWord());
            }
        //Timing.tick("Done with pcfg phase.");
        }
        if (op.doDep) {
            dparser.parse(binaryTree.yieldHasWord());
        //Timing.tick("Done with dependency phase.");
        }
        boolean bothPassed = false;
        if (op.doPCFG && op.doDep) {
            bothPassed = bparser.parse(binaryTree.yieldHasWord());
        //Timing.tick("Done with combination phase.");
        }
        long timeMil2 = System.currentTimeMillis();
        long elapsed = timeMil2 - timeMil1;
        log.info("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec.");
        //System.out.println("PCFG Best Parse:");
        Tree tree2b = null;
        Tree tree2 = null;
        //System.out.println("Got full best parse...");
        if (op.doPCFG) {
            tree2b = parser.getBestParse();
            tree2 = debinarizer.transformTree(tree2b);
        }
        //System.out.println("Debinarized parse...");
        //tree2.pennPrint();
        //System.out.println("DepG Best Parse:");
        Tree tree3 = null;
        Tree tree3db = null;
        if (op.doDep) {
            tree3 = dparser.getBestParse();
            // was: but wrong Tree tree3db = debinarizer.transformTree(tree2);
            tree3db = debinarizer.transformTree(tree3);
            tree3.pennPrint(pw);
        }
        //tree.pennPrint();
        //((Tree)binaryTrainTrees.get(tNum)).pennPrint();
        //System.out.println("Combo Best Parse:");
        Tree tree4 = null;
        if (op.doPCFG && op.doDep) {
            try {
                tree4 = bparser.getBestParse();
                if (tree4 == null) {
                    tree4 = tree2b;
                }
            } catch (NullPointerException e) {
                log.info("Blocked, using PCFG parse!");
                tree4 = tree2b;
            }
        }
        if (op.doPCFG && !bothPassed) {
            tree4 = tree2b;
        }
        //tree4.pennPrint();
        if (op.doDep) {
            depDE.evaluate(tree3, binaryTree, pw);
            depTE.evaluate(tree3db, tree, pw);
        }
        TreeTransformer tc = op.tlpParams.collinizer();
        TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb();
        if (op.doPCFG) {
            // System.out.println("XXXX Best PCFG was: ");
            // tree2.pennPrint();
            // System.out.println("XXXX Transformed best PCFG is: ");
            // tc.transformTree(tree2).pennPrint();
            //System.out.println("True Best Parse:");
            //tree.pennPrint();
            //tc.transformTree(tree).pennPrint();
            pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
            pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
            Tree tree4b = null;
            if (op.doDep) {
                comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw);
                tree4b = tree4;
                tree4 = debinarizer.transformTree(tree4);
                if (op.nodePrune) {
                    NodePruner np = new NodePruner(parser, debinarizer);
                    tree4 = np.prune(tree4);
                }
                //tree4.pennPrint();
                comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
            }
            //pcfgTE.evaluate(tree2, tree);
            pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw);
            pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
            if (op.doDep) {
                comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw);
                comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
            }
            System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0));
            //tc.transformTree(tree2).pennPrint();
            tree2.pennPrint(pw);
            if (op.doDep) {
                System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0));
                // tc.transformTree(tree4).pennPrint(pw);
                tree4.pennPrint(pw);
            }
            System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0));
            /*
        if (parser.scoreBinarizedTree(tree2b,true) < parser.scoreBinarizedTree(binaryTree,true)) {
          System.out.println("SCORE INVERSION");
          parser.validateBinarizedTree(binaryTree,0);
        }
        */
            tree.pennPrint(pw);
        }
        if (op.testOptions.evalb) {
            if (op.doPCFG && op.doDep) {
                EvalbFormatWriter.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4));
            } else if (op.doPCFG) {
                EvalbFormatWriter.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2));
            } else if (op.doDep) {
                EvalbFormatWriter.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db));
            }
        }
    }
    if (op.testOptions.evalb) {
        EvalbFormatWriter.closeEVALBfiles();
    }
    // op.testOptions.display();
    if (op.doPCFG) {
        pcfgPE.display(false, pw);
        System.out.println("Grammar size: " + stateIndex.size());
        pcfgCB.display(false, pw);
        if (op.doDep) {
            comboPE.display(false, pw);
        }
        pcfgTE.display(false, pw);
        pcfgTEnoPunct.display(false, pw);
        if (op.doDep) {
            comboTE.display(false, pw);
            comboTEnoPunct.display(false, pw);
        }
    }
    if (op.doDep) {
        depTE.display(false, pw);
        depDE.display(false, pw);
    }
    if (op.doPCFG && op.doDep) {
        comboDE.display(false, pw);
    }
// pcfgPE.printGoodBad();
}
Also used: Treebank(edu.stanford.nlp.trees.Treebank) MemoryTreebank(edu.stanford.nlp.trees.MemoryTreebank) ArrayList(java.util.ArrayList) Tree(edu.stanford.nlp.trees.Tree) TreebankLanguagePack(edu.stanford.nlp.trees.TreebankLanguagePack) List(java.util.List) TaggingEval(edu.stanford.nlp.parser.metrics.TaggingEval) NumberRangeFileFilter(edu.stanford.nlp.io.NumberRangeFileFilter) Evalb(edu.stanford.nlp.parser.metrics.Evalb) TreeTransformer(edu.stanford.nlp.trees.TreeTransformer) UnlabeledAttachmentEval(edu.stanford.nlp.parser.metrics.UnlabeledAttachmentEval) PrintWriter(java.io.PrintWriter) Pair(edu.stanford.nlp.util.Pair) HasWord(edu.stanford.nlp.ling.HasWord) AbstractEval(edu.stanford.nlp.parser.metrics.AbstractEval) LeftHeadFinder(edu.stanford.nlp.trees.LeftHeadFinder) TreeLengthComparator(edu.stanford.nlp.trees.TreeLengthComparator) HashIndex(edu.stanford.nlp.util.HashIndex)
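
The evaluation pattern inside FactoredParser's test loop reduces to: collinize guess and gold trees with the language pack's transformer, feed each pair to an Evalb metric, and display the summary once at the end. A hedged sketch of just that pattern; the class name, method name, and parameters are illustrative:

import java.io.PrintWriter;
import java.util.List;
import edu.stanford.nlp.parser.metrics.Evalb;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;

public class EvalbSketch {
    // 'collinizer' would come from op.tlpParams.collinizer(), as in the example.
    public static void score(List<Tree> guesses, List<Tree> golds,
                             TreeTransformer collinizer, PrintWriter pw) {
        Evalb evalb = new Evalb("pcfg PE", true);
        for (int i = 0; i < guesses.size(); i++) {
            // Collinization strips punctuation and annotation so that
            // bracketing scores are comparable across parsers.
            evalb.evaluate(collinizer.transformTree(guesses.get(i)),
                           collinizer.transformTree(golds.get(i)), pw);
        }
        evalb.display(false, pw);
    }
}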

Example 10 with Treebank

Use of edu.stanford.nlp.trees.Treebank in project CoreNLP by stanfordnlp.

From the class DVParser, method main: the training and evaluation driver for the DVParser; it loads cached parses and/or a training treebank, optionally runs training or a gradient check, and evaluates on a test treebank.

/**
   * An example command line for training a new parser:
   * <br>
   *  nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank  /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories &gt; /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2&gt;&amp;1 &amp;
   */
public static void main(String[] args) throws IOException, ClassNotFoundException {
    if (args.length == 0) {
        help();
        System.exit(2);
    }
    log.info("Running DVParser with arguments:");
    for (String arg : args) {
        log.info("  " + arg);
    }
    log.info();
    String parserPath = null;
    String trainTreebankPath = null;
    FileFilter trainTreebankFilter = null;
    String cachedTrainTreesPath = null;
    boolean runGradientCheck = false;
    boolean runTraining = false;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    String initialModelPath = null;
    String modelPath = null;
    boolean filter = true;
    String resultsRecordPath = null;
    List<String> unusedArgs = new ArrayList<>();
    // These parameters can be null or 0 if the model was not
    // serialized with the new parameters.  Setting the options at the
    // command line will override these defaults.
    // TODO: if/when we integrate back into the main branch and
    // rebuild models, we can get rid of this
    List<String> argsWithDefaults = new ArrayList<>(Arrays.asList(
        "-wordVectorFile", Options.LexOptions.DEFAULT_WORD_VECTOR_FILE,
        "-dvKBest", Integer.toString(TrainOptions.DEFAULT_K_BEST),
        "-batchSize", Integer.toString(TrainOptions.DEFAULT_BATCH_SIZE),
        "-trainingIterations", Integer.toString(TrainOptions.DEFAULT_TRAINING_ITERATIONS),
        "-qnIterationsPerBatch", Integer.toString(TrainOptions.DEFAULT_QN_ITERATIONS_PER_BATCH),
        "-regCost", Double.toString(TrainOptions.DEFAULT_REGCOST),
        "-learningRate", Double.toString(TrainOptions.DEFAULT_LEARNING_RATE),
        "-deltaMargin", Double.toString(TrainOptions.DEFAULT_DELTA_MARGIN),
        "-unknownNumberVector",
        "-unknownDashedWordVectors",
        "-unknownCapsVector",
        "-unknownchinesepercentvector",
        "-unknownchinesenumbervector",
        "-unknownchineseyearvector",
        "-unkWord", "*UNK*",
        "-transformMatrixType", "DIAGONAL",
        "-scalingForInit", Double.toString(TrainOptions.DEFAULT_SCALING_FOR_INIT),
        "-trainWordVectors"));
    argsWithDefaults.addAll(Arrays.asList(args));
    args = argsWithDefaults.toArray(new String[argsWithDefaults.size()]);
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-parser")) {
            parserPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-treebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-treebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            trainTreebankPath = treebankDescription.first();
            trainTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-cachedTrees")) {
            cachedTrainTreesPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-runGradientCheck")) {
            runGradientCheck = true;
            argIndex++;
        } else if (args[argIndex].equalsIgnoreCase("-train")) {
            runTraining = true;
            argIndex++;
        } else if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-nofilter")) {
            filter = false;
            argIndex++;
        } else if (args[argIndex].equalsIgnoreCase("-continueTraining")) {
            runTraining = true;
            filter = false;
            initialModelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-resultsRecord")) {
            resultsRecordPath = args[argIndex + 1];
            argIndex += 2;
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    if (parserPath == null && modelPath == null) {
        throw new IllegalArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
    }
    if (!runTraining && modelPath == null && !runGradientCheck) {
        throw new IllegalArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    DVParser dvparser = null;
    LexicalizedParser lexparser = null;
    if (initialModelPath != null) {
        lexparser = LexicalizedParser.loadModel(initialModelPath, newArgs);
        DVModel model = getModelFromLexicalizedParser(lexparser);
        dvparser = new DVParser(model, lexparser);
    } else if (runTraining || runGradientCheck) {
        lexparser = LexicalizedParser.loadModel(parserPath, newArgs);
        dvparser = new DVParser(lexparser);
    } else if (modelPath != null) {
        lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
        DVModel model = getModelFromLexicalizedParser(lexparser);
        dvparser = new DVParser(model, lexparser);
    }
    List<Tree> trainSentences = new ArrayList<>();
    IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.newIdentityHashMap();
    if (cachedTrainTreesPath != null) {
        for (String path : cachedTrainTreesPath.split(",")) {
            List<Pair<Tree, byte[]>> cache = IOUtils.readObjectFromFile(path);
            for (Pair<Tree, byte[]> pair : cache) {
                trainSentences.add(pair.first());
                trainCompressedParses.put(pair.first(), pair.second());
            }
            log.info("Read in " + cache.size() + " trees from " + path);
        }
    }
    if (trainTreebankPath != null) {
        // TODO: make the transformer a member of the model?
        TreeTransformer transformer = buildTrainTransformer(dvparser.getOp());
        Treebank treebank = dvparser.getOp().tlpParams.memoryTreebank();
        treebank.loadPath(trainTreebankPath, trainTreebankFilter);
        treebank = treebank.transform(transformer);
        log.info("Read in " + treebank.size() + " trees from " + trainTreebankPath);
        CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
        CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
        for (Tree tree : treebank) {
            trainSentences.add(tree);
            trainCompressedParses.put(tree, processor.process(tree).second);
        //System.out.println(tree);
        }
        log.info("Finished parsing " + treebank.size() + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
    }
    if ((runTraining || runGradientCheck) && filter) {
        log.info("Filtering rules for the given training set");
        dvparser.dvModel.setRulesForTrainingSet(trainSentences, trainCompressedParses);
        log.info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.size() + " word vectors");
    }
    //dvparser.dvModel.printAllMatrices();
    Treebank testTreebank = null;
    if (testTreebankPath != null) {
        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        testTreebank = dvparser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");
    }
    //    runGradientCheck= true;
    if (runGradientCheck) {
        log.info("Running gradient check on " + trainSentences.size() + " trees");
        dvparser.runGradientCheck(trainSentences, trainCompressedParses);
    }
    if (runTraining) {
        log.info("Training the RNN parser");
        log.info("Current train options: " + dvparser.getOp().trainOptions);
        dvparser.train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
        if (modelPath != null) {
            dvparser.saveModel(modelPath);
        }
    }
    if (testTreebankPath != null) {
        EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.attachModelToLexicalizedParser());
        evaluator.testOnTreebank(testTreebank);
    }
    log.info("Successfully ran DVParser");
}
Also used: Treebank(edu.stanford.nlp.trees.Treebank) EvaluateTreebank(edu.stanford.nlp.parser.lexparser.EvaluateTreebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) ArrayList(java.util.ArrayList) Tree(edu.stanford.nlp.trees.Tree) FileFilter(java.io.FileFilter) TreeTransformer(edu.stanford.nlp.trees.TreeTransformer) CompositeTreeTransformer(edu.stanford.nlp.trees.CompositeTreeTransformer) Pair(edu.stanford.nlp.util.Pair)
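
Every main method above parses flags with the same hand-rolled loop; the one subtle case is -testTreebank/-treebank, which may carry a file-range sub-argument consumed through ArgUtils. A stripped-down sketch of the idiom; the class name is made up, and ArgUtils is assumed to live in edu.stanford.nlp.parser.common (alongside ParserQuery):

import java.io.FileFilter;
import java.util.ArrayList;
import java.util.List;
// Assumed package for ArgUtils; see the note above.
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.util.Pair;

public class ArgLoopSketch {
    public static void main(String[] args) {
        String testTreebankPath = null;
        FileFilter testTreebankFilter = null;
        List<String> unusedArgs = new ArrayList<>();
        for (int argIndex = 0; argIndex < args.length; ) {
            if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
                // Reads the path plus an optional numeric range into path + filter.
                Pair<String, FileFilter> desc =
                        ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
                // Skip the flag itself plus however many sub-args it consumed.
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                testTreebankPath = desc.first();
                testTreebankFilter = desc.second();
            } else {
                // Anything unrecognized is collected and later handed to the parser.
                unusedArgs.add(args[argIndex++]);
            }
        }
        System.out.println(testTreebankPath + " / " + testTreebankFilter);
    }
}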

Aggregations

Treebank (edu.stanford.nlp.trees.Treebank): 27
Tree (edu.stanford.nlp.trees.Tree): 16
TreeTransformer (edu.stanford.nlp.trees.TreeTransformer): 10
ArrayList (java.util.ArrayList): 8
Language (edu.stanford.nlp.international.Language): 7
EvaluateTreebank (edu.stanford.nlp.parser.lexparser.EvaluateTreebank): 7
TreebankLangParserParams (edu.stanford.nlp.parser.lexparser.TreebankLangParserParams): 7
Pair (edu.stanford.nlp.util.Pair): 7
PrintWriter (java.io.PrintWriter): 7
Label (edu.stanford.nlp.ling.Label): 6
LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser): 6
FileFilter (java.io.FileFilter): 6
Map (java.util.Map): 4
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 3
EnglishTreebankParserParams (edu.stanford.nlp.parser.lexparser.EnglishTreebankParserParams): 3
DiskTreebank (edu.stanford.nlp.trees.DiskTreebank): 3
MemoryTreebank (edu.stanford.nlp.trees.MemoryTreebank): 3
ArabicMorphoFeatureSpecification (edu.stanford.nlp.international.arabic.ArabicMorphoFeatureSpecification): 2
FrenchMorphoFeatureSpecification (edu.stanford.nlp.international.french.FrenchMorphoFeatureSpecification): 2
MorphoFeatureSpecification (edu.stanford.nlp.international.morph.MorphoFeatureSpecification): 2