Search in sources:

Example 1 with RerankingParserQuery

Use of edu.stanford.nlp.parser.lexparser.RerankingParserQuery in project CoreNLP by stanfordnlp.

The main method of the class FindNearestNeighbors.

/**
 * Parses every tree of a test treebank with a reranking parser, writes each sentence,
 * its parse tree and its ROOT node vector to the output file, and then logs, for every
 * subtree with at most {@code maxLength} leaves, the other subtrees whose node vectors
 * are nearest to it (distance = {@code normF()} of the vector difference).
 *
 * <p>Recognized flags, all required: {@code -model <path>},
 * {@code -testTreebank <description>} and {@code -output <path>}. Every other argument
 * is passed through to {@code LexicalizedParser.loadModel}.
 *
 * @param args command line arguments
 * @throws IllegalArgumentException if a required flag is missing, or the loaded parser
 *     does not produce a {@code RerankingParserQuery} backed by a
 *     {@code DVModelReranker.Query}
 * @throws AssertionError if a sentence cannot be parsed or no ROOT vector is found
 * @throws Exception if loading the treebank or writing the output file fails
 */
public static void main(String[] args) throws Exception {
    String modelPath = null;
    String outputPath = null;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = new ArrayList<>();
    // The index is advanced inside each branch because flags consume a varying number of arguments.
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-output")) {
            outputPath = args[argIndex + 1];
            argIndex += 2;
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    if (modelPath == null) {
        throw new IllegalArgumentException("Need to specify -model");
    }
    if (testTreebankPath == null) {
        throw new IllegalArgumentException("Need to specify -testTreebank");
    }
    if (outputPath == null) {
        throw new IllegalArgumentException("Need to specify -output");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);

    // testTreebankPath is guaranteed non-null here, so the treebank is always loaded.
    log.info("Reading in trees from " + testTreebankPath);
    if (testTreebankFilter != null) {
        log.info("Filtering on " + testTreebankFilter);
    }
    Treebank testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
    testTreebank.loadPath(testTreebankPath, testTreebankFilter);
    log.info("Read in " + testTreebank.size() + " trees for testing");

    List<ParseRecord> records = Generics.newArrayList();
    // try-with-resources guarantees the output is flushed and closed even if parsing fails;
    // all writes go through the buffered writer rather than the raw FileWriter.
    try (FileWriter out = new FileWriter(outputPath);
         BufferedWriter bout = new BufferedWriter(out)) {
        log.info("Parsing " + testTreebank.size() + " trees");
        int count = 0;
        for (Tree goldTree : testTreebank) {
            List<Word> tokens = goldTree.yieldWords();
            ParserQuery parserQuery = lexparser.parserQuery();
            if (!parserQuery.parse(tokens)) {
                throw new AssertionError("Could not parse: " + tokens);
            }
            if (!(parserQuery instanceof RerankingParserQuery)) {
                throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
            }
            RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
            if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
                throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
            }
            // Only the first deep tree returned by the reranker query is used.
            DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
            SimpleMatrix rootVector = null;
            for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
                if (entry.getKey().label().value().equals("ROOT")) {
                    rootVector = entry.getValue();
                    break;
                }
            }
            if (rootVector == null) {
                throw new AssertionError("Could not find root nodevector");
            }
            bout.write(tokens + "\n");
            bout.write(tree.getTree() + "\n");
            for (int i = 0; i < rootVector.getNumElements(); ++i) {
                bout.write("  " + rootVector.get(i));
            }
            bout.write("\n\n\n");
            count++;
            if (count % 10 == 0) {
                log.info("  " + count);
            }
            records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
        }
    }
    log.info("  done parsing");

    // Collect every (subtree, vector) pair whose subtree has at most maxLength leaves.
    List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
    for (ParseRecord record : records) {
        for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
            if (entry.getKey().getLeaves().size() <= maxLength) {
                subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
            }
        }
    }
    log.info("There are " + subtrees.size() + " subtrees in the set of trees");

    // Bounded queue holding at most 100 candidates per subtree.
    // NOTE(review): with DESCENDING_COMPARATOR the queue head should be the largest distance,
    // so poll() evicts the worst match - confirm against ScoredComparator.
    PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
    for (int i = 0; i < subtrees.size(); ++i) {
        Pair<Tree, SimpleMatrix> current = subtrees.get(i);
        log.info(current.first().yieldWords());
        log.info(current.first());
        for (int j = 0; j < subtrees.size(); ++j) {
            if (i == j) {
                continue;
            }
            // TODO: look at basic category?
            Pair<Tree, SimpleMatrix> candidate = subtrees.get(j);
            double normF = current.second().minus(candidate.second()).normF();
            bestmatches.add(new ScoredObject<>(Pair.makePair(current.first(), candidate.first()), normF));
            if (bestmatches.size() > 100) {
                bestmatches.poll();
            }
        }
        // Drain the queue and reverse the drained order before reporting.
        List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
        while (bestmatches.size() > 0) {
            ordered.add(bestmatches.poll());
        }
        Collections.reverse(ordered);
        for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
            log.info(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
        }
        log.info();
        log.info();
        bestmatches.clear();
    }
    // Disabled alternative (sentence-level nearest neighbors), kept for reference only.
    // It writes to `out`, which is now closed after the parse loop, and it references
    // `dvparser` and `numNeighbors`, which are not defined in this method; re-enabling it
    // would require reopening the writer and supplying those names.
    /*
    for (int i = 0; i < records.size(); ++i) {
      if (i % 10 == 0) {
        log.info("  " + i);
      }
      List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
      for (int j = 0; j < records.size(); ++j) {
        if (i == j) continue;

        double score = 0.0;
        int matches = 0;
        for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
          for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
            String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
            String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
            if (firstBasic.equals(secondBasic)) {
              ++matches;
              double normF = first.getValue().minus(second.getValue()).normF();
              score += normF * normF;
            }
          }
        }
        if (matches == 0) {
          score = Double.POSITIVE_INFINITY;
        } else {
          score = score / matches;
        }
        //double score = records.get(i).vector.minus(records.get(j).vector).normF();
        scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
      }
      Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);

      out.write(records.get(i).sentence.toString() + "\n");
      for (int j = 0; j < numNeighbors; ++j) {
        out.write("   " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
      }
      out.write("\n\n");
    }
    log.info();
    */
}
Also used : Word(edu.stanford.nlp.ling.Word) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery) Treebank(edu.stanford.nlp.trees.Treebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) BufferedWriter(java.io.BufferedWriter) SimpleMatrix(org.ejml.simple.SimpleMatrix) ScoredObject(edu.stanford.nlp.util.ScoredObject) DeepTree(edu.stanford.nlp.trees.DeepTree) Tree(edu.stanford.nlp.trees.Tree) DeepTree(edu.stanford.nlp.trees.DeepTree) FileFilter(java.io.FileFilter) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) Pair(edu.stanford.nlp.util.Pair) PriorityQueue(java.util.PriorityQueue) IdentityHashMap(java.util.IdentityHashMap) Map(java.util.Map) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery)

Example 2 with RerankingParserQuery

Use of edu.stanford.nlp.parser.lexparser.RerankingParserQuery in project CoreNLP by stanfordnlp.

The main method of the class ParseAndPrintMatrices.

/**
 * Parses each sentence of an input text file with a DVModel-reranked parser and writes,
 * per sentence, a file {@code sentence<N>.txt} (N counted from 1) under the output
 * directory containing the sentence, its parse tree, the word vector of every token and
 * the node matrices of the tree.
 *
 * <p>Recognized flags: {@code -model <path>} and {@code -output <dir>} (both required),
 * {@code -input <path>} (when absent, only the output directory is created) and
 * {@code -testTreebank <description>} (consumed but not used by this method). Every
 * other argument is passed through to {@code LexicalizedParser.loadModel}.
 *
 * @param args command line arguments
 * @throws IllegalArgumentException if {@code -model} or {@code -output} is missing, or
 *     the parser does not yield a {@code RerankingParserQuery} backed by a
 *     {@code DVModelReranker.Query}
 * @throws RuntimeException if a sentence cannot be parsed
 * @throws IOException if reading the input or writing an output file fails
 */
public static void main(String[] args) throws IOException {
    String modelPath = null;
    String outputPath = null;
    String inputPath = null;
    // Parsed so that the flag and its sub-arguments are consumed; not used further below.
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = Generics.newArrayList();
    // The index is advanced inside each branch because flags consume a varying number of arguments.
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-output")) {
            outputPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-input")) {
            inputPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    // Fail with a clear message instead of a NullPointerException further down.
    if (modelPath == null) {
        throw new IllegalArgumentException("Need to specify -model");
    }
    if (outputPath == null) {
        throw new IllegalArgumentException("Need to specify -output");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser parser = LexicalizedParser.loadModel(modelPath, newArgs);
    DVModel model = DVParser.getModelFromLexicalizedParser(parser);
    File outputFile = new File(outputPath);
    // NOTE(review): presumably refuses to reuse an existing path and then creates the
    // output directory - confirm against FileSystem.
    FileSystem.checkNotExistsOrFail(outputFile);
    FileSystem.mkdirOrFail(outputFile);
    if (inputPath == null) {
        return;
    }
    int count = 0;
    // The reader is closed automatically once all sentences have been processed or on failure.
    try (Reader input = new BufferedReader(new FileReader(inputPath))) {
        DocumentPreprocessor processor = new DocumentPreprocessor(input);
        for (List<HasWord> sentence : processor) {
            // index from 1
            count++;
            ParserQuery pq = parser.parserQuery();
            if (!(pq instanceof RerankingParserQuery)) {
                throw new IllegalArgumentException("Expected a RerankingParserQuery");
            }
            RerankingParserQuery rpq = (RerankingParserQuery) pq;
            if (!rpq.parse(sentence)) {
                throw new RuntimeException("Unparsable sentence: " + sentence);
            }
            RerankerQuery reranker = rpq.rerankerQuery();
            if (!(reranker instanceof DVModelReranker.Query)) {
                throw new IllegalArgumentException("Expected a DVModelReranker");
            }
            // Only the first deep tree returned by the reranker query is used.
            DeepTree deepTree = ((DVModelReranker.Query) reranker).getDeepTrees().get(0);
            IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();
            for (Map.Entry<Tree, SimpleMatrix> entry : vectors.entrySet()) {
                log.info(entry.getKey() + "   " + entry.getValue());
            }
            String sentencePath = outputPath + File.separator + "sentence" + count + ".txt";
            // Each sentence file is flushed and closed even if writing a matrix fails.
            try (FileWriter fout = new FileWriter(sentencePath);
                 BufferedWriter bout = new BufferedWriter(fout)) {
                bout.write(SentenceUtils.listToString(sentence));
                bout.newLine();
                bout.write(deepTree.getTree().toString());
                bout.newLine();
                for (HasWord word : sentence) {
                    outputMatrix(bout, model.getWordVector(word.word()));
                }
                Tree rootTree = findRootTree(vectors);
                outputTreeMatrices(bout, rootTree, vectors);
            }
        }
    }
}
Also used : RerankerQuery(edu.stanford.nlp.parser.lexparser.RerankerQuery) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) FileWriter(java.io.FileWriter) Reader(java.io.Reader) BufferedReader(java.io.BufferedReader) FileReader(java.io.FileReader) BufferedWriter(java.io.BufferedWriter) SimpleMatrix(org.ejml.simple.SimpleMatrix) DeepTree(edu.stanford.nlp.trees.DeepTree) Tree(edu.stanford.nlp.trees.Tree) FileReader(java.io.FileReader) DeepTree(edu.stanford.nlp.trees.DeepTree) FileFilter(java.io.FileFilter) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) Pair(edu.stanford.nlp.util.Pair) HasWord(edu.stanford.nlp.ling.HasWord) RerankerQuery(edu.stanford.nlp.parser.lexparser.RerankerQuery) BufferedReader(java.io.BufferedReader) DocumentPreprocessor(edu.stanford.nlp.process.DocumentPreprocessor) File(java.io.File) Map(java.util.Map) IdentityHashMap(java.util.IdentityHashMap) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery)

Aggregations

ParserQuery (edu.stanford.nlp.parser.common.ParserQuery)2 LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser)2 RerankingParserQuery (edu.stanford.nlp.parser.lexparser.RerankingParserQuery)2 DeepTree (edu.stanford.nlp.trees.DeepTree)2 Tree (edu.stanford.nlp.trees.Tree)2 Pair (edu.stanford.nlp.util.Pair)2 BufferedWriter (java.io.BufferedWriter)2 FileFilter (java.io.FileFilter)2 FileWriter (java.io.FileWriter)2 IdentityHashMap (java.util.IdentityHashMap)2 Map (java.util.Map)2 SimpleMatrix (org.ejml.simple.SimpleMatrix)2 HasWord (edu.stanford.nlp.ling.HasWord)1 Word (edu.stanford.nlp.ling.Word)1 RerankerQuery (edu.stanford.nlp.parser.lexparser.RerankerQuery)1 DocumentPreprocessor (edu.stanford.nlp.process.DocumentPreprocessor)1 Treebank (edu.stanford.nlp.trees.Treebank)1 ScoredObject (edu.stanford.nlp.util.ScoredObject)1 BufferedReader (java.io.BufferedReader)1 File (java.io.File)1