Search in sources:

Example 26 with SimpleMatrix

Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.

Class FindNearestNeighbors, method main.

/**
 * Command-line entry point that dumps DVModel root vectors for a test treebank and logs
 * nearest-neighbor subtrees by node-vector distance.
 *
 * <p>Flags consumed here: {@code -model <path>}, {@code -testTreebank <path> [filter]} and
 * {@code -output <path>}. Every other argument is forwarded to {@code LexicalizedParser.loadModel}.
 *
 * <p>For each gold tree in the test treebank, its word yield is re-parsed. The parser query must be
 * a {@code RerankingParserQuery} whose reranker query is a {@code DVModelReranker.Query}. For the
 * first {@code DeepTree} that query returns, the output file receives:
 * <ul>
 *   <li>the token list,</li>
 *   <li>the tree,</li>
 *   <li>the entries of the vector attached to the node labelled {@code ROOT},</li>
 *   <li>a blank-line separator.</li>
 * </ul>
 *
 * <p>Afterwards, every collected subtree with at most {@code maxLength} leaves is compared with
 * every other such subtree by the Frobenius norm of the difference of their vectors. The up-to-100
 * closest matches for each subtree are logged.
 *
 * @param args command-line arguments as described above
 * @throws IllegalArgumentException if {@code -model}, {@code -testTreebank} or {@code -output} is
 *     missing, or the parser does not have a reranker / DVModel attached
 * @throws AssertionError if a sentence cannot be parsed or its tree has no ROOT vector
 * @throws Exception if loading the model or reading/writing files fails
 */
public static void main(String[] args) throws Exception {
    String modelPath = null;
    String outputPath = null;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = new ArrayList<>();
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-output")) {
            outputPath = args[argIndex + 1];
            argIndex += 2;
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    if (modelPath == null) {
        throw new IllegalArgumentException("Need to specify -model");
    }
    if (testTreebankPath == null) {
        throw new IllegalArgumentException("Need to specify -testTreebank");
    }
    if (outputPath == null) {
        throw new IllegalArgumentException("Need to specify -output");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);

    // testTreebankPath was validated above, so the treebank is always loaded.
    log.info("Reading in trees from " + testTreebankPath);
    if (testTreebankFilter != null) {
        log.info("Filtering on " + testTreebankFilter);
    }
    Treebank testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
    testTreebank.loadPath(testTreebankPath, testTreebankFilter);
    log.info("Read in " + testTreebank.size() + " trees for testing");

    // All output goes through the buffered writer. try-with-resources flushes and closes the file
    // even if parsing or the neighbor search below throws.
    try (BufferedWriter out = new BufferedWriter(new FileWriter(outputPath))) {
        log.info("Parsing " + testTreebank.size() + " trees");
        int count = 0;
        List<ParseRecord> records = Generics.newArrayList();
        for (Tree goldTree : testTreebank) {
            List<Word> tokens = goldTree.yieldWords();
            ParserQuery parserQuery = lexparser.parserQuery();
            if (!parserQuery.parse(tokens)) {
                throw new AssertionError("Could not parse: " + tokens);
            }
            if (!(parserQuery instanceof RerankingParserQuery)) {
                throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
            }
            RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
            if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
                throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
            }
            // Only the first DeepTree returned by the reranker query is used.
            DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
            SimpleMatrix rootVector = null;
            for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
                if (entry.getKey().label().value().equals("ROOT")) {
                    rootVector = entry.getValue();
                    break;
                }
            }
            if (rootVector == null) {
                throw new AssertionError("Could not find root nodevector");
            }
            out.write(tokens + "\n");
            out.write(tree.getTree() + "\n");
            for (int i = 0; i < rootVector.getNumElements(); ++i) {
                out.write("  " + rootVector.get(i));
            }
            out.write("\n\n\n");
            count++;
            if (count % 10 == 0) {
                log.info("  " + count);
            }
            records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
        }
        log.info("  done parsing");

        // Collect every node (with its vector) whose subtree is short enough to compare.
        List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
        for (ParseRecord record : records) {
            for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
                if (entry.getKey().getLeaves().size() <= maxLength) {
                    subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
                }
            }
        }
        log.info("There are " + subtrees.size() + " subtrees in the set of trees");

        // Bounded queue of the current best matches. Whenever it grows past 100, poll() evicts
        // the head, which under DESCENDING_COMPARATOR is presumably the largest distance.
        PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
        for (int i = 0; i < subtrees.size(); ++i) {
            log.info(subtrees.get(i).first().yieldWords());
            log.info(subtrees.get(i).first());
            for (int j = 0; j < subtrees.size(); ++j) {
                if (i == j) {
                    continue;
                }
                // TODO: look at basic category?
                double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
                bestmatches.add(new ScoredObject<>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
                if (bestmatches.size() > 100) {
                    bestmatches.poll();
                }
            }
            // Drain the queue, then reverse so that the list runs in the opposite order to polling.
            List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
            while (bestmatches.size() > 0) {
                ordered.add(bestmatches.poll());
            }
            Collections.reverse(ordered);
            for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
                log.info(" MATCHED " + pair.object().second().yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
            }
            log.info();
            log.info();
            bestmatches.clear();
        }
        /*
        for (int i = 0; i < records.size(); ++i) {
          if (i % 10 == 0) {
            log.info("  " + i);
          }
          List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
          for (int j = 0; j < records.size(); ++j) {
            if (i == j) continue;

            double score = 0.0;
            int matches = 0;
            for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
              for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
                String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
                String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
                if (firstBasic.equals(secondBasic)) {
                  ++matches;
                  double normF = first.getValue().minus(second.getValue()).normF();
                  score += normF * normF;
                }
              }
            }
            if (matches == 0) {
              score = Double.POSITIVE_INFINITY;
            } else {
              score = score / matches;
            }
            //double score = records.get(i).vector.minus(records.get(j).vector).normF();
            scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
          }
          Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);

          out.write(records.get(i).sentence.toString() + "\n");
          for (int j = 0; j < numNeighbors; ++j) {
            out.write("   " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
          }
          out.write("\n\n");
        }
        log.info();
        */
    }
}
Also used: Word(edu.stanford.nlp.ling.Word) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery) Treebank(edu.stanford.nlp.trees.Treebank) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) FileWriter(java.io.FileWriter) ArrayList(java.util.ArrayList) BufferedWriter(java.io.BufferedWriter) SimpleMatrix(org.ejml.simple.SimpleMatrix) ScoredObject(edu.stanford.nlp.util.ScoredObject) DeepTree(edu.stanford.nlp.trees.DeepTree) Tree(edu.stanford.nlp.trees.Tree) FileFilter(java.io.FileFilter) Pair(edu.stanford.nlp.util.Pair) PriorityQueue(java.util.PriorityQueue) IdentityHashMap(java.util.IdentityHashMap) Map(java.util.Map)

Example 27 with SimpleMatrix

Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.

Class ParseAndPrintMatrices, method main.

/**
 * Command-line entry point that parses each sentence of a plain-text input file and writes its
 * DVModel matrices to one file per sentence.
 *
 * <p>Flags consumed here: {@code -model <path>}, {@code -output <dir>}, {@code -input <file>} and
 * {@code -testTreebank <path> [filter]}. Every other argument is forwarded to
 * {@code LexicalizedParser.loadModel}.
 *
 * <p>The output directory is passed to {@code FileSystem.checkNotExistsOrFail} and then
 * {@code FileSystem.mkdirOrFail}, so presumably it must not exist beforehand. When {@code -input}
 * is given, each sentence produced by a {@code DocumentPreprocessor} over that file is parsed. The
 * file {@code <output>/sentence<N>.txt} (N counted from 1) receives:
 * <ul>
 *   <li>the sentence,</li>
 *   <li>the first {@code DeepTree}'s tree,</li>
 *   <li>the model's word vector for every word,</li>
 *   <li>the node matrices written by {@code outputTreeMatrices}.</li>
 * </ul>
 *
 * @param args command-line arguments as described above
 * @throws IllegalArgumentException if {@code -model} or {@code -output} is missing, or the parser
 *     does not produce a reranking query with a DVModel reranker
 * @throws RuntimeException if a sentence cannot be parsed
 * @throws IOException if reading the input or writing an output file fails
 */
public static void main(String[] args) throws IOException {
    String modelPath = null;
    String outputPath = null;
    String inputPath = null;
    // NOTE(review): -testTreebank is consumed here, so its arguments are not forwarded to the
    // parser, but this method never reads the treebank.
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = Generics.newArrayList();
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-output")) {
            outputPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-input")) {
            inputPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    // Fail early with a clear message instead of an NPE in new File(null) or inside loadModel.
    if (modelPath == null) {
        throw new IllegalArgumentException("Need to specify -model");
    }
    if (outputPath == null) {
        throw new IllegalArgumentException("Need to specify -output");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser parser = LexicalizedParser.loadModel(modelPath, newArgs);
    DVModel model = DVParser.getModelFromLexicalizedParser(parser);
    File outputFile = new File(outputPath);
    FileSystem.checkNotExistsOrFail(outputFile);
    FileSystem.mkdirOrFail(outputFile);
    if (inputPath == null) {
        return;
    }
    int count = 0;
    // try-with-resources closes the input reader and every per-sentence writer, even on failure.
    try (Reader input = new BufferedReader(new FileReader(inputPath))) {
        DocumentPreprocessor processor = new DocumentPreprocessor(input);
        for (List<HasWord> sentence : processor) {
            // index from 1
            count++;
            ParserQuery pq = parser.parserQuery();
            if (!(pq instanceof RerankingParserQuery)) {
                throw new IllegalArgumentException("Expected a RerankingParserQuery");
            }
            RerankingParserQuery rpq = (RerankingParserQuery) pq;
            if (!rpq.parse(sentence)) {
                throw new RuntimeException("Unparsable sentence: " + sentence);
            }
            RerankerQuery reranker = rpq.rerankerQuery();
            if (!(reranker instanceof DVModelReranker.Query)) {
                throw new IllegalArgumentException("Expected a DVModelReranker");
            }
            DeepTree deepTree = ((DVModelReranker.Query) reranker).getDeepTrees().get(0);
            IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();
            for (Map.Entry<Tree, SimpleMatrix> entry : vectors.entrySet()) {
                log.info(entry.getKey() + "   " + entry.getValue());
            }
            String sentenceFile = outputPath + File.separator + "sentence" + count + ".txt";
            try (BufferedWriter bout = new BufferedWriter(new FileWriter(sentenceFile))) {
                bout.write(SentenceUtils.listToString(sentence));
                bout.newLine();
                bout.write(deepTree.getTree().toString());
                bout.newLine();
                for (HasWord word : sentence) {
                    outputMatrix(bout, model.getWordVector(word.word()));
                }
                Tree rootTree = findRootTree(vectors);
                outputTreeMatrices(bout, rootTree, vectors);
            }
        }
    }
}
Also used: RerankerQuery(edu.stanford.nlp.parser.lexparser.RerankerQuery) RerankingParserQuery(edu.stanford.nlp.parser.lexparser.RerankingParserQuery) ParserQuery(edu.stanford.nlp.parser.common.ParserQuery) LexicalizedParser(edu.stanford.nlp.parser.lexparser.LexicalizedParser) FileWriter(java.io.FileWriter) Reader(java.io.Reader) BufferedReader(java.io.BufferedReader) FileReader(java.io.FileReader) BufferedWriter(java.io.BufferedWriter) SimpleMatrix(org.ejml.simple.SimpleMatrix) DeepTree(edu.stanford.nlp.trees.DeepTree) Tree(edu.stanford.nlp.trees.Tree) FileFilter(java.io.FileFilter) Pair(edu.stanford.nlp.util.Pair) HasWord(edu.stanford.nlp.ling.HasWord) DocumentPreprocessor(edu.stanford.nlp.process.DocumentPreprocessor) File(java.io.File) Map(java.util.Map) IdentityHashMap(java.util.IdentityHashMap)

Example 28 with SimpleMatrix

Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.

Class DVModel, method randomTransformMatrix.

/**
 * Builds a freshly initialized numRows x numCols transform matrix. The result can be used as
 * either a unary or a binary transform.
 *
 * <p>The initialization scheme is chosen by {@code op.trainOptions.transformMatrixType}:
 * <ul>
 *   <li>{@code RANDOM}: uniform noise in +/- 1/sqrt(numCols).</li>
 *   <li>{@code DIAGONAL}: the identity plus uniform noise in +/- 1/sqrt(100 * numCols).</li>
 *   <li>{@code OFF_DIAGONAL}: like DIAGONAL, then numCols randomly chosen cells are each shifted
 *       by -1, 0 or +1.</li>
 *   <li>{@code RANDOM_ZEROS}: like DIAGONAL, then numCols randomly chosen cells are set to 0.</li>
 * </ul>
 *
 * @return the new matrix
 * @throws IllegalArgumentException for any other transform matrix type
 */
private SimpleMatrix randomTransformMatrix() {
    // Half-widths of the uniform noise: a narrow band around the identity,
    // and a wide band for the purely random initialization.
    double narrow = 1.0 / Math.sqrt((double) numCols * 100.0);
    double wide = 1.0 / Math.sqrt((double) numCols);
    switch (op.trainOptions.transformMatrixType) {
        case DIAGONAL:
            return SimpleMatrix.random(numRows, numCols, -narrow, narrow, rand).plus(identity);
        case RANDOM:
            return SimpleMatrix.random(numRows, numCols, -wide, wide, rand);
        case OFF_DIAGONAL: {
            SimpleMatrix perturbed = SimpleMatrix.random(numRows, numCols, -narrow, narrow, rand).plus(identity);
            // Row and column indices are both drawn from numCols; the identity sum above
            // presumably implies numRows == numCols here.
            for (int k = 0; k < numCols; k++) {
                int row = rand.nextInt(numCols);
                int col = rand.nextInt(numCols);
                int delta = rand.nextInt(3) - 1; // one of -1, 0, +1
                perturbed.set(row, col, perturbed.get(row, col) + delta);
            }
            return perturbed;
        }
        case RANDOM_ZEROS: {
            SimpleMatrix sparse = SimpleMatrix.random(numRows, numCols, -narrow, narrow, rand).plus(identity);
            for (int k = 0; k < numCols; k++) {
                int row = rand.nextInt(numCols);
                int col = rand.nextInt(numCols);
                sparse.set(row, col, 0.0);
            }
            return sparse;
        }
        default:
            throw new IllegalArgumentException("Unexpected matrix initialization type " + op.trainOptions.transformMatrixType);
    }
}
Also used: SimpleMatrix(org.ejml.simple.SimpleMatrix)

Example 29 with SimpleMatrix

Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.

Class DVModel, method randomContextMatrix.

/**
 * Builds a random numRows x (2 * numCols) context matrix. It can be appended to the end of either
 * a unary or a binary transform matrix to obtain a transform that also uses context words.
 *
 * <p>The matrix is two side-by-side copies of the identity, each scaled by
 * {@code 0.1 * op.trainOptions.scalingForInit}. Uniform noise in +/- 1/sqrt(100 * numCols) is then
 * added to every entry.
 *
 * @return the new context matrix
 */
private SimpleMatrix randomContextMatrix() {
    SimpleMatrix blocks = new SimpleMatrix(numRows, numCols * 2);
    double identityWeight = op.trainOptions.scalingForInit * 0.1;
    // Left half, then right half.
    blocks.insertIntoThis(0, 0, identity.scale(identityWeight));
    blocks.insertIntoThis(0, numCols, identity.scale(identityWeight));
    double halfWidth = 1.0 / Math.sqrt((double) numCols * 100.0);
    SimpleMatrix noise = SimpleMatrix.random(numRows, numCols * 2, -halfWidth, halfWidth, rand);
    return blocks.plus(noise);
}
Also used: SimpleMatrix(org.ejml.simple.SimpleMatrix)

Example 30 with SimpleMatrix

Use of org.ejml.simple.SimpleMatrix in project CoreNLP by stanfordnlp.

Class DVModel, method filterRulesForBatch.

/**
 * Restricts the model's parameter maps to the rules and words in a batch.
 *
 * <p>The binary transform/score maps, the unary transform/score maps and the word-vector map are
 * each replaced by new sorted maps. Those maps contain only the requested keys that already have
 * an entry; requested keys without an entry are skipped. The binary and unary matrix counts are
 * recomputed from the filtered transform maps.
 *
 * @param binaryRules (left, right) child-label pairs of the binary rules to keep
 * @param unaryRules child labels of the unary rules to keep
 * @param words words whose vectors should be kept
 * @throws AssertionError if a requested rule has a transform but no score, or vice versa; the
 *     model's fields are left untouched in that case
 */
public void filterRulesForBatch(TwoDimensionalSet<String, String> binaryRules, Set<String> unaryRules, Set<String> words) {
    TwoDimensionalMap<String, String, SimpleMatrix> keptBinaryTransforms = TwoDimensionalMap.treeMap();
    TwoDimensionalMap<String, String, SimpleMatrix> keptBinaryScores = TwoDimensionalMap.treeMap();
    for (Pair<String, String> rule : binaryRules) {
        String left = rule.first();
        String right = rule.second();
        SimpleMatrix transform = binaryTransform.get(left, right);
        SimpleMatrix score = binaryScore.get(left, right);
        // A rule must have both a transform and a score, or neither.
        if ((transform == null) != (score == null)) {
            throw new AssertionError();
        }
        if (transform != null) {
            keptBinaryTransforms.put(left, right, transform);
            keptBinaryScores.put(left, right, score);
        }
    }
    binaryTransform = keptBinaryTransforms;
    binaryScore = keptBinaryScores;
    numBinaryMatrices = binaryTransform.size();

    Map<String, SimpleMatrix> keptUnaryTransforms = Generics.newTreeMap();
    Map<String, SimpleMatrix> keptUnaryScores = Generics.newTreeMap();
    for (String child : unaryRules) {
        SimpleMatrix transform = unaryTransform.get(child);
        SimpleMatrix score = unaryScore.get(child);
        if ((transform == null) != (score == null)) {
            throw new AssertionError();
        }
        if (transform != null) {
            keptUnaryTransforms.put(child, transform);
            keptUnaryScores.put(child, score);
        }
    }
    unaryTransform = keptUnaryTransforms;
    unaryScore = keptUnaryScores;
    numUnaryMatrices = unaryTransform.size();

    Map<String, SimpleMatrix> keptWordVectors = Generics.newTreeMap();
    for (String word : words) {
        SimpleMatrix vector = wordVectors.get(word);
        if (vector == null) {
            continue;
        }
        keptWordVectors.put(word, vector);
    }
    wordVectors = keptWordVectors;
}
Also used: SimpleMatrix(org.ejml.simple.SimpleMatrix)

Aggregations

SimpleMatrix (org.ejml.simple.SimpleMatrix): 52, Tree (edu.stanford.nlp.trees.Tree): 8, Map (java.util.Map): 7, DeepTree (edu.stanford.nlp.trees.DeepTree): 5, TwoDimensionalMap (edu.stanford.nlp.util.TwoDimensionalMap): 5, SimpleTensor (edu.stanford.nlp.neural.SimpleTensor): 4, LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser): 4, Pair (edu.stanford.nlp.util.Pair): 4, IdentityHashMap (java.util.IdentityHashMap): 4, Mention (edu.stanford.nlp.coref.data.Mention): 3, BufferedWriter (java.io.BufferedWriter): 3, File (java.io.File): 3, FileWriter (java.io.FileWriter): 3, ArrayList (java.util.ArrayList): 3, CoreLabel (edu.stanford.nlp.ling.CoreLabel): 2, Embedding (edu.stanford.nlp.neural.Embedding): 2, ParserQuery (edu.stanford.nlp.parser.common.ParserQuery): 2, RerankingParserQuery (edu.stanford.nlp.parser.lexparser.RerankingParserQuery): 2, FileFilter (java.io.FileFilter): 2, Bone (com.jme3.animation.Bone): 1