Example 6 with ScoredObject

Use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

From the class FindNearestNeighbors, the method main:

public static void main(String[] args) throws Exception {
    String modelPath = null;
    String outputPath = null;
    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;
    List<String> unusedArgs = new ArrayList<>();
    for (int argIndex = 0; argIndex < args.length; ) {
        if (args[argIndex].equalsIgnoreCase("-model")) {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
            Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.first();
            testTreebankFilter = treebankDescription.second();
        } else if (args[argIndex].equalsIgnoreCase("-output")) {
            outputPath = args[argIndex + 1];
            argIndex += 2;
        } else {
            unusedArgs.add(args[argIndex++]);
        }
    }
    if (modelPath == null) {
        throw new IllegalArgumentException("Need to specify -model");
    }
    if (testTreebankPath == null) {
        throw new IllegalArgumentException("Need to specify -testTreebank");
    }
    if (outputPath == null) {
        throw new IllegalArgumentException("Need to specify -output");
    }
    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
    Treebank testTreebank = null;
    if (testTreebankPath != null) {
        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");
    }
    FileWriter out = new FileWriter(outputPath);
    BufferedWriter bout = new BufferedWriter(out);
    log.info("Parsing " + testTreebank.size() + " trees");
    int count = 0;
    List<ParseRecord> records = Generics.newArrayList();
    for (Tree goldTree : testTreebank) {
        List<Word> tokens = goldTree.yieldWords();
        ParserQuery parserQuery = lexparser.parserQuery();
        if (!parserQuery.parse(tokens)) {
            throw new AssertionError("Could not parse: " + tokens);
        }
        if (!(parserQuery instanceof RerankingParserQuery)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
        }
        RerankingParserQuery rpq = (RerankingParserQuery) parserQuery;
        if (!(rpq.rerankerQuery() instanceof DVModelReranker.Query)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
        }
        DeepTree tree = ((DVModelReranker.Query) rpq.rerankerQuery()).getDeepTrees().get(0);
        SimpleMatrix rootVector = null;
        for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
            if (entry.getKey().label().value().equals("ROOT")) {
                rootVector = entry.getValue();
                break;
            }
        }
        if (rootVector == null) {
            throw new AssertionError("Could not find root nodevector");
        }
        out.write(tokens + "\n");
        out.write(tree.getTree() + "\n");
        for (int i = 0; i < rootVector.getNumElements(); ++i) {
            out.write("  " + rootVector.get(i));
        }
        out.write("\n\n\n");
        count++;
        if (count % 10 == 0) {
            log.info("  " + count);
        }
        records.add(new ParseRecord(tokens, goldTree, tree.getTree(), rootVector, tree.getVectors()));
    }
    log.info("  done parsing");
    List<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
    for (ParseRecord record : records) {
        for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
            if (entry.getKey().getLeaves().size() <= maxLength) {
                subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
            }
        }
    }
    log.info("There are " + subtrees.size() + " subtrees in the set of trees");
    PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestmatches = new PriorityQueue<>(101, ScoredComparator.DESCENDING_COMPARATOR);
    for (int i = 0; i < subtrees.size(); ++i) {
        log.info(subtrees.get(i).first().yieldWords());
        log.info(subtrees.get(i).first());
        for (int j = 0; j < subtrees.size(); ++j) {
            if (i == j) {
                continue;
            }
            // TODO: look at basic category?
            double normF = subtrees.get(i).second().minus(subtrees.get(j).second()).normF();
            bestmatches.add(new ScoredObject<>(Pair.makePair(subtrees.get(i).first(), subtrees.get(j).first()), normF));
            if (bestmatches.size() > 100) {
                bestmatches.poll();
            }
        }
        List<ScoredObject<Pair<Tree, Tree>>> ordered = Generics.newArrayList();
        while (bestmatches.size() > 0) {
            ordered.add(bestmatches.poll());
        }
        Collections.reverse(ordered);
        for (ScoredObject<Pair<Tree, Tree>> pair : ordered) {
            log.info(" MATCHED " + pair.object().second.yieldWords() + " ... " + pair.object().second() + " with a score of " + pair.score());
        }
        log.info();
        log.info();
        bestmatches.clear();
    }
    /*
    for (int i = 0; i < records.size(); ++i) {
      if (i % 10 == 0) {
        log.info("  " + i);
      }
      List<ScoredObject<ParseRecord>> scored = Generics.newArrayList();
      for (int j = 0; j < records.size(); ++j) {
        if (i == j) continue;

        double score = 0.0;
        int matches = 0;
        for (Map.Entry<Tree, SimpleMatrix> first : records.get(i).nodeVectors.entrySet()) {
          for (Map.Entry<Tree, SimpleMatrix> second : records.get(j).nodeVectors.entrySet()) {
            String firstBasic = dvparser.dvModel.basicCategory(first.getKey().label().value());
            String secondBasic = dvparser.dvModel.basicCategory(second.getKey().label().value());
            if (firstBasic.equals(secondBasic)) {
              ++matches;
              double normF = first.getValue().minus(second.getValue()).normF();
              score += normF * normF;
            }
          }
        }
        if (matches == 0) {
          score = Double.POSITIVE_INFINITY;
        } else {
          score = score / matches;
        }
        //double score = records.get(i).vector.minus(records.get(j).vector).normF();
        scored.add(new ScoredObject<ParseRecord>(records.get(j), score));
      }
      Collections.sort(scored, ScoredComparator.ASCENDING_COMPARATOR);

      out.write(records.get(i).sentence.toString() + "\n");
      for (int j = 0; j < numNeighbors; ++j) {
        out.write("   " + scored.get(j).score() + ": " + scored.get(j).object().sentence + "\n");
      }
      out.write("\n\n");
    }
    log.info();
    */
    bout.flush();
    out.flush();
    out.close();
}
Also used : Word (edu.stanford.nlp.ling.Word) ParserQuery (edu.stanford.nlp.parser.common.ParserQuery) RerankingParserQuery (edu.stanford.nlp.parser.lexparser.RerankingParserQuery) Treebank (edu.stanford.nlp.trees.Treebank) LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser) FileWriter (java.io.FileWriter) BufferedWriter (java.io.BufferedWriter) ArrayList (java.util.ArrayList) SimpleMatrix (org.ejml.simple.SimpleMatrix) ScoredObject (edu.stanford.nlp.util.ScoredObject) Tree (edu.stanford.nlp.trees.Tree) DeepTree (edu.stanford.nlp.trees.DeepTree) FileFilter (java.io.FileFilter) Pair (edu.stanford.nlp.util.Pair) PriorityQueue (java.util.PriorityQueue) IdentityHashMap (java.util.IdentityHashMap) Map (java.util.Map)
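The distance used in the nearest-neighbor search above is the Frobenius norm of the difference between two node vectors. Below is a minimal sketch of that computation with two made-up 3-dimensional column vectors (in the real program the vectors come from DeepTree.getVectors()); only org.ejml.simple.SimpleMatrix is assumed.

// Minimal sketch with made-up vectors; assumes: import org.ejml.simple.SimpleMatrix;
SimpleMatrix a = new SimpleMatrix(new double[][] { { 0.1 }, { 0.7 }, { -0.2 } });
SimpleMatrix b = new SimpleMatrix(new double[][] { { 0.0 }, { 0.5 }, { -0.1 } });
// A smaller Frobenius norm of the difference means the two subtree vectors are closer
double distance = a.minus(b).normF();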

Example 7 with ScoredObject

Use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

From the class CharniakScoredParsesReaderWriter, the method stringToParses:

/**
 * Converts a string representing scored parses (in the Charniak parser output format)
 * to a list of scored parse trees.
 *
 * @param parseStr the string to convert
 * @return the list of scored parse trees, or null if the string contained none
 */
public static List<ScoredObject<Tree>> stringToParses(String parseStr) {
    try {
        BufferedReader br = new BufferedReader(new StringReader(parseStr));
        Iterable<List<ScoredObject<Tree>>> trees = readScoredTrees("", br);
        List<ScoredObject<Tree>> res = null;
        Iterator<List<ScoredObject<Tree>>> iter = trees.iterator();
        if (iter != null && iter.hasNext()) {
            res = iter.next();
        }
        br.close();
        return res;
    } catch (IOException ex) {
        throw new RuntimeException(ex);
    }
}
Also used : ScoredObject(edu.stanford.nlp.util.ScoredObject) BufferedReader(java.io.BufferedReader) StringReader(java.io.StringReader) Tree(edu.stanford.nlp.trees.Tree) ArrayList(java.util.ArrayList) List(java.util.List) IOException(java.io.IOException)
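A small usage sketch for the method above. The String charniakOutput is hypothetical and is assumed to hold one sentence's n-best parses in the Charniak output format.

// Hypothetical consumer of stringToParses; charniakOutput is assumed input.
// Assumes: import java.util.List; import edu.stanford.nlp.trees.Tree;
//          import edu.stanford.nlp.util.ScoredObject;
List<ScoredObject<Tree>> parses = CharniakScoredParsesReaderWriter.stringToParses(charniakOutput);
if (parses != null) {
    for (ScoredObject<Tree> scoredTree : parses) {
        // each entry pairs a parse tree (object) with its score
        System.out.println(scoredTree.score() + "\t" + scoredTree.object());
    }
}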

Example 8 with ScoredObject

Use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

From the class CharniakScoredParsesReaderWriter, the method printScoredTrees:

/**
 * Prints scored parse trees in the format used by the Charniak parser.
 *
 * @param trees trees to output
 * @param filename file to output to
 */
public static void printScoredTrees(Iterable<List<ScoredObject<Tree>>> trees, String filename) {
    try {
        PrintWriter pw = IOUtils.getPrintWriter(filename);
        int i = 0;
        for (List<ScoredObject<Tree>> treeList : trees) {
            printScoredTrees(pw, i, treeList);
            i++;
        }
        pw.close();
    } catch (IOException ex) {
        throw new RuntimeException(ex);
    }
}
Also used : ScoredObject(edu.stanford.nlp.util.ScoredObject) IOException(java.io.IOException) PrintWriter(java.io.PrintWriter)
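Combined with the previous example, a hedged roundtrip sketch: parse one n-best list from a string and write it back out in the same format. The variable charniakOutput and the filename nbest.out are hypothetical.

// Hypothetical roundtrip: printScoredTrees expects an Iterable of n-best lists,
// so a single sentence's list is wrapped with Collections.singletonList.
// Assumes: import java.util.Collections; import java.util.List;
//          import edu.stanford.nlp.trees.Tree; import edu.stanford.nlp.util.ScoredObject;
List<ScoredObject<Tree>> parses = CharniakScoredParsesReaderWriter.stringToParses(charniakOutput);
CharniakScoredParsesReaderWriter.printScoredTrees(Collections.singletonList(parses), "nbest.out");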

Example 9 with ScoredObject

Use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

From the class PerceptronModel, the method findHighestScoringTransitions:

private Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
    float[] scores = new float[transitionIndex.size()];
    for (String feature : features) {
        Weight weight = featureWeights.get(feature);
        if (weight == null) {
            // Features not in our index are ignored
            continue;
        }
        weight.score(scores);
    }
    PriorityQueue<ScoredObject<Integer>> queue = new PriorityQueue<>(numTransitions + 1, ScoredComparator.ASCENDING_COMPARATOR);
    for (int i = 0; i < scores.length; ++i) {
        if (!requireLegal || transitionIndex.get(i).isLegal(state, constraints)) {
            queue.add(new ScoredObject<>(i, scores[i]));
            if (queue.size() > numTransitions) {
                queue.poll();
            }
        }
    }
    return queue;
}
Also used : ScoredObject(edu.stanford.nlp.util.ScoredObject) PriorityQueue(java.util.PriorityQueue) ParserConstraint(edu.stanford.nlp.parser.common.ParserConstraint)
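The method above keeps only the numTransitions best-scored candidates by combining ScoredObject with a small PriorityQueue: with the ascending comparator the queue head is the lowest-scored element, so polling once the size exceeds the limit discards the current worst. A minimal, self-contained sketch of that pattern with made-up scores:

// Top-k pattern used above, shown with integer items and made-up scores.
// Assumes: import java.util.PriorityQueue;
//          import edu.stanford.nlp.util.ScoredObject;
//          import edu.stanford.nlp.util.ScoredComparator;
int numBest = 3;
double[] scores = { 0.2, 1.5, -0.3, 0.9, 2.1 };
PriorityQueue<ScoredObject<Integer>> queue =
        new PriorityQueue<>(numBest + 1, ScoredComparator.ASCENDING_COMPARATOR);
for (int i = 0; i < scores.length; ++i) {
    queue.add(new ScoredObject<>(i, scores[i]));
    if (queue.size() > numBest) {
        // the head of an ascending queue is the lowest-scored element; drop it
        queue.poll();
    }
}
// queue now holds indices 1, 3, and 4 (the three highest scores), in no particular order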

Example 10 with ScoredObject

Use of edu.stanford.nlp.util.ScoredObject in project CoreNLP by stanfordnlp.

From the class PerceptronModel, the method trainTree:

/**
 * Trains on a single example: the binarized tree to train on, together
 * with its pre-assembled list of gold transitions.
 */
private TrainingResult trainTree(TrainingExample example) {
    int numCorrect = 0;
    int numWrong = 0;
    Tree tree = example.binarizedTree;
    List<TrainingUpdate> updates = Generics.newArrayList();
    Pair<Integer, Integer> firstError = null;
    IntCounter<Class<? extends Transition>> correctTransitions = new IntCounter<>();
    TwoDimensionalIntCounter<Class<? extends Transition>, Class<? extends Transition>> wrongTransitions = new TwoDimensionalIntCounter<>();
    ReorderingOracle reorderer = null;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
        reorderer = new ReorderingOracle(op, rootOnlyStates);
    }
    int reorderSuccess = 0;
    int reorderFail = 0;
    if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
        if (op.trainOptions().beamSize <= 0) {
            throw new IllegalArgumentException("Illegal beam size " + op.trainOptions().beamSize);
        }
        PriorityQueue<State> agenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
        State goldState = example.initialStateFromGoldTagTree();
        List<Transition> transitions = example.trainTransitions();
        agenda.add(goldState);
        while (transitions.size() > 0) {
            Transition goldTransition = transitions.get(0);
            Transition highestScoringTransitionFromGoldState = null;
            double highestScoreFromGoldState = 0.0;
            PriorityQueue<State> newAgenda = new PriorityQueue<>(op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
            State highestScoringState = null;
            // keep track of the state in the current agenda which leads
            // to the highest score on the next agenda.  this will be
            // trained down assuming it is not the correct state
            State highestCurrentState = null;
            for (State currentState : agenda) {
                // TODO: can maybe speed this part up, although it doesn't seem like a critical part of the runtime
                boolean isGoldState = goldState.areTransitionsEqual(currentState);
                List<String> features = featureFactory.featurize(currentState);
                Collection<ScoredObject<Integer>> stateTransitions = findHighestScoringTransitions(currentState, features, true, op.trainOptions().beamSize, null);
                for (ScoredObject<Integer> transition : stateTransitions) {
                    State newState = transitionIndex.get(transition.object()).apply(currentState, transition.score());
                    newAgenda.add(newState);
                    if (newAgenda.size() > op.trainOptions().beamSize) {
                        newAgenda.poll();
                    }
                    if (highestScoringState == null || highestScoringState.score() < newState.score()) {
                        highestScoringState = newState;
                        highestCurrentState = currentState;
                    }
                    if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
                        highestScoringTransitionFromGoldState = transitionIndex.get(transition.object());
                        highestScoreFromGoldState = transition.score();
                    }
                }
            }
            // In REORDER_BEAM mode, stop if the gold state (eg one with ROOT) is no
            // longer on the agenda, so no transition from it could be scored.
            if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
                break;
            }
            if (highestScoringState == null) {
                System.err.println("Unable to find a best transition!");
                System.err.println("Previous agenda:");
                for (State state : agenda) {
                    System.err.println(state);
                }
                System.err.println("Gold transitions:");
                System.err.println(example.transitions);
                break;
            }
            State newGoldState = goldTransition.apply(goldState, 0.0);
            if (firstError == null && !highestScoringTransitionFromGoldState.equals(goldTransition)) {
                int predictedIndex = transitionIndex.indexOf(highestScoringTransitionFromGoldState);
                int goldIndex = transitionIndex.indexOf(goldTransition);
                if (predictedIndex < 0) {
                    throw new AssertionError("Predicted transition not in the index: " + highestScoringTransitionFromGoldState);
                }
                if (goldIndex < 0) {
                    throw new AssertionError("Gold transition not in the index: " + goldTransition);
                }
                firstError = new Pair<>(predictedIndex, goldIndex);
            }
            // If the highest-scoring state is not the gold state, weight down the
            // transition that produced it and weight up the correct (gold) transition.
            if (!newGoldState.areTransitionsEqual(highestScoringState)) {
                ++numWrong;
                wrongTransitions.incrementCount(goldTransition.getClass(), highestScoringTransitionFromGoldState.getClass());
                List<String> goldFeatures = featureFactory.featurize(goldState);
                int lastTransition = transitionIndex.indexOf(highestScoringState.transitions.peek());
                updates.add(new TrainingUpdate(featureFactory.featurize(highestCurrentState), -1, lastTransition, learningRate));
                updates.add(new TrainingUpdate(goldFeatures, transitionIndex.indexOf(goldTransition), -1, learningRate));
                if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
                    // If the correct state has fallen off the agenda, break
                    if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                        break;
                    } else {
                        transitions.remove(0);
                    }
                } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
                    if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                        if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
                            if (reorderSuccess == 0)
                                reorderFail = 1;
                            break;
                        }
                        newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
                        if (!ShiftReduceUtils.findStateOnAgenda(newAgenda, newGoldState)) {
                            if (reorderSuccess == 0)
                                reorderFail = 1;
                            break;
                        }
                        reorderSuccess = 1;
                    } else {
                        transitions.remove(0);
                    }
                }
            } else {
                ++numCorrect;
                correctTransitions.incrementCount(goldTransition.getClass());
                transitions.remove(0);
            }
            goldState = newGoldState;
            agenda = newAgenda;
        }
    } else if (op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
        State state = example.initialStateFromGoldTagTree();
        List<Transition> transitions = example.trainTransitions();
        boolean keepGoing = true;
        while (transitions.size() > 0 && keepGoing) {
            Transition gold = transitions.get(0);
            int goldNum = transitionIndex.indexOf(gold);
            List<String> features = featureFactory.featurize(state);
            int predictedNum = findHighestScoringTransition(state, features, false).object();
            Transition predicted = transitionIndex.get(predictedNum);
            if (goldNum == predictedNum) {
                transitions.remove(0);
                state = gold.apply(state);
                numCorrect++;
                correctTransitions.incrementCount(gold.getClass());
            } else {
                numWrong++;
                wrongTransitions.incrementCount(gold.getClass(), predicted.getClass());
                if (firstError == null) {
                    firstError = new Pair<>(predictedNum, goldNum);
                }
                // TODO: allow weighted features, weighted training, etc
                updates.add(new TrainingUpdate(features, goldNum, predictedNum, learningRate));
                switch(op.trainOptions().trainingMethod) {
                    case EARLY_TERMINATION:
                        keepGoing = false;
                        break;
                    case GOLD:
                        transitions.remove(0);
                        state = gold.apply(state);
                        break;
                    case REORDER_ORACLE:
                        keepGoing = reorderer.reorder(state, predicted, transitions);
                        if (keepGoing) {
                            state = predicted.apply(state);
                            reorderSuccess = 1;
                        } else if (reorderSuccess == 0) {
                            reorderFail = 1;
                        }
                        break;
                    default:
                        throw new IllegalArgumentException("Unexpected method " + op.trainOptions().trainingMethod);
                }
            }
        }
    }
    return new TrainingResult(updates, numCorrect, numWrong, firstError, correctTransitions, wrongTransitions, reorderSuccess, reorderFail);
}
Also used : ScoredObject (edu.stanford.nlp.util.ScoredObject) Tree (edu.stanford.nlp.trees.Tree) ArrayList (java.util.ArrayList) List (java.util.List) IntCounter (edu.stanford.nlp.stats.IntCounter) TwoDimensionalIntCounter (edu.stanford.nlp.stats.TwoDimensionalIntCounter) Pair (edu.stanford.nlp.util.Pair) PriorityQueue (java.util.PriorityQueue) ParserConstraint (edu.stanford.nlp.parser.common.ParserConstraint)

Aggregations

ScoredObject (edu.stanford.nlp.util.ScoredObject): 17 uses
Tree (edu.stanford.nlp.trees.Tree): 11 uses
ArrayList (java.util.ArrayList): 11 uses
ParserConstraint (edu.stanford.nlp.parser.common.ParserConstraint): 7 uses
List (java.util.List): 6 uses
PriorityQueue (java.util.PriorityQueue): 5 uses
Pair (edu.stanford.nlp.util.Pair): 4 uses
NoSuchParseException (edu.stanford.nlp.parser.common.NoSuchParseException): 3 uses
Word (edu.stanford.nlp.ling.Word): 2 uses
ParserQuery (edu.stanford.nlp.parser.common.ParserQuery): 2 uses
TreePrint (edu.stanford.nlp.trees.TreePrint): 2 uses
Timing (edu.stanford.nlp.util.Timing): 2 uses
IOException (java.io.IOException): 2 uses
LinkedList (java.util.LinkedList): 2 uses
Map (java.util.Map): 2 uses
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 1 use
EvaluateTreebank (edu.stanford.nlp.parser.lexparser.EvaluateTreebank): 1 use
LexicalizedParser (edu.stanford.nlp.parser.lexparser.LexicalizedParser): 1 use
RerankingParserQuery (edu.stanford.nlp.parser.lexparser.RerankingParserQuery): 1 use
AbstractEval (edu.stanford.nlp.parser.metrics.AbstractEval): 1 use