
Example 1 with PriorityQueue

Use of edu.stanford.nlp.util.PriorityQueue in the project CoreNLP by stanfordnlp.

From the class ClauseSplitterSearchProblem, the method search:

/**
   * The core implementation of the search.
   *
   * @param root The root word to search from. Traditionally, this is the root of the sentence.
   * @param candidateFragments The callback for the resulting sentence fragments.
   *                           This is a predicate of a triple of values.
   *                           The return value of the predicate determines whether we should continue searching.
   *                           The triple consists of:
   *                           <ol>
   *                             <li>The log probability of the sentence fragment, according to the featurizer and the weights</li>
   *                             <li>The features along the path to this fragment. The last element of this is the features from the most recent step.</li>
   *                             <li>The sentence fragment. Because it is relatively expensive to compute the resulting tree, this is returned as a lazy {@link Supplier}.</li>
   *                           </ol>
   * @param classifier The classifier for whether an arc should be on the path to a clause split, a clause split itself, or neither.
   * @param featurizer The featurizer to use. Make sure this matches the weights!
   * @param actionSpace The action space we are allowed to take. Each action defines a means of splitting a clause on a dependency boundary.
   */
protected void search(
        // The root to search from
        IndexedWord root,
        // The output specs
        final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>> candidateFragments,
        // The learning specs
        final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
        Map<String, ? extends List<String>> hardCodedSplits,
        final Function<Triple<State, Action, State>, Counter<String>> featurizer,
        final Collection<Action> actionSpace,
        final int maxTicks) {
    // (the fringe)
    PriorityQueue<Pair<State, List<Counter<String>>>> fringe = new FixedPrioritiesPriorityQueue<>();
    // (avoid duplicate work)
    Set<IndexedWord> seenWords = new HashSet<>();
    // The first state is implicitly "done"
    State firstState = new State(null, null, -9000, null, x -> { }, true);
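    // Seed the fringe with the initial state; queue priorities are log probabilities, so 0.0 is the best possible score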
    fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
    int ticks = 0;
    while (!fringe.isEmpty()) {
        if (++ticks > maxTicks) {
            //        log.info("WARNING! Timed out on search with " + ticks + " ticks");
            return;
        }
        // Useful variables
        double logProbSoFar = fringe.getPriority();
        assert logProbSoFar <= 0.0;
        Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
        State lastState = lastStatePair.first;
        List<Counter<String>> featuresSoFar = lastStatePair.second;
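        // The word we expand from next: the dependent of the edge we just split off, or the sentence root for the initial state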
        IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();
        // Register thunk
        if (lastState.isDone) {
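            // Lazily build the fragment: copy the tree, apply the accumulated splitting thunk, and re-attach any extra edges;
            // if the callback returns false, abandon the search entirely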
            if (!candidateFragments.test(Triple.makeTriple(logProbSoFar, featuresSoFar, () -> {
                SemanticGraph copy = new SemanticGraph(tree);
                lastState.thunk.andThen(x -> {
                    for (IndexedWord newTreeRoot : x.getRoots()) {
                        if (newTreeRoot != null) {
                            for (SemanticGraphEdge extraEdge : extraEdgesByGovernor.get(newTreeRoot)) {
                                assert Util.isTree(x);
                                addSubtree(x, newTreeRoot, extraEdge.getRelation().toString(), tree, extraEdge.getDependent(), tree.getIncomingEdgesSorted(newTreeRoot));
                                assert Util.isTree(x);
                            }
                        }
                    }
                }).accept(copy);
                return new SentenceFragment(copy, assumedTruth, false);
            }))) {
                break;
            }
        }
        // Find relevant auxiliary terms
        SemanticGraphEdge subjOrNull = null;
        SemanticGraphEdge objOrNull = null;
        for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
            String relString = auxEdge.getRelation().toString();
            if (relString.contains("obj")) {
                objOrNull = auxEdge;
            } else if (relString.contains("subj")) {
                subjOrNull = auxEdge;
            }
        }
        // For each outgoing edge...
        for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
            // This fires if the governor is an indirect speech verb, and the outgoing edge is a ccomp
            if (outgoingEdge.getRelation().toString().equals("ccomp") &&
                    ((outgoingEdge.getGovernor().lemma() != null && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma())) ||
                     INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
                continue;
            }
            // Get some variables
            String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
            List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
            if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
                forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
            }
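            // If this relation has a hard-coded split, try the actions in the forced order and accept the first applicable one with probability 1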
            boolean doneForcedArc = false;
            // For each action...
            for (Action action : (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
                // Check the prerequisite
                if (!action.prerequisitesMet(tree, outgoingEdge)) {
                    continue;
                }
                if (forcedArcOrder != null && doneForcedArc) {
                    break;
                }
                // 1. Compute the child state
                Optional<State> candidate = action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
                if (candidate.isPresent()) {
                    double logProbability;
                    ClauseClassifierLabel bestLabel;
                    Counter<String> features = featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
                    if (forcedArcOrder != null && !doneForcedArc) {
                        logProbability = 0.0;
                        bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
                        doneForcedArc = true;
                    } else if (features.containsKey("__undocumented_junit_no_classifier")) {
                        logProbability = Double.NEGATIVE_INFINITY;
                        bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
                    } else {
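                        // Otherwise, score this transition with the classifier and log-normalize the label scores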
                        Counter<ClauseClassifierLabel> scores = classifier.scoresOf(new RVFDatum<>(features));
                        if (scores.size() > 0) {
                            Counters.logNormalizeInPlace(scores);
                        }
                        String rel = outgoingEdge.getRelation().toString();
                        if ("nsubj".equals(rel) || "dobj".equals(rel)) {
                            // Always at least yield on nsubj and dobj
                            scores.remove(ClauseClassifierLabel.NOT_A_CLAUSE);
                        }
                        logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
                        bestLabel = Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
                    }
                    if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
                        Pair<State, List<Counter<String>>> childState = Pair.makePair(
                                candidate.get().withIsDone(bestLabel),
                                new ArrayList<Counter<String>>(featuresSoFar) {
                                    {
                                        add(features);
                                    }
                                });
                        // 2. Register the child state
                        if (!seenWords.contains(childState.first.edge.getDependent())) {
                            //            log.info("  pushing " + action.signature() + " with " + argmax.first.edge);
                            fringe.add(childState, logProbability);
                        }
                    }
                }
            }
        }
        seenWords.add(rootWord);
    }
//    log.info("Search finished in " + ticks + " ticks and " + classifierEvals + " classifier evaluations.");
}
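As the Javadoc notes, candidateFragments is just a java.util.function.Predicate over a Triple of (log probability, features along the path, lazy fragment). The snippet below is a hypothetical sketch of what such a callback could look like, not code from CoreNLP: the printing and the choice to always return true (keep searching) are assumptions for illustration, and it would live inside whatever class drives the search.

import java.util.List;
import java.util.function.Predicate;
import java.util.function.Supplier;
import edu.stanford.nlp.naturalli.SentenceFragment;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Triple;

// Hypothetical callback: print every candidate fragment and keep searching.
Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>> candidateFragments =
        triple -> {
            double logProb = triple.first();                          // log probability of this fragment
            List<Counter<String>> featuresPerStep = triple.second();  // one Counter per step on the path
            SentenceFragment fragment = triple.third().get();         // materialize the lazy fragment
            System.out.println(logProb + "\t" + fragment);
            return true;                                              // returning false would stop the search
        };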
Also used:
edu.stanford.nlp.classify
edu.stanford.nlp.international.Language
edu.stanford.nlp.ling
edu.stanford.nlp.naturalli.ClauseSplitter.ClauseClassifierLabel
edu.stanford.nlp.semgraph.SemanticGraph
edu.stanford.nlp.semgraph.SemanticGraphEdge
edu.stanford.nlp.stats.ClassicCounter
edu.stanford.nlp.stats.Counter
edu.stanford.nlp.stats.Counters
edu.stanford.nlp.trees.GrammaticalRelation
edu.stanford.nlp.util
edu.stanford.nlp.util.PriorityQueue
edu.stanford.nlp.util.logging.Redwood
java.io
java.util
java.util.function.Consumer
java.util.function.Function
java.util.function.Predicate
java.util.function.Supplier
java.util.stream.Stream
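Since this example was indexed for edu.stanford.nlp.util.PriorityQueue, here is a minimal, self-contained sketch of the same fringe pattern in isolation. It uses only the calls that appear in search above (add with an explicit priority, getPriority, removeFirst, isEmpty); the string states and their scores are invented for illustration. In the method above the priorities are log probabilities, so removeFirst() always pops the most probable partial state.

import edu.stanford.nlp.util.FixedPrioritiesPriorityQueue;
import edu.stanford.nlp.util.PriorityQueue;

public class FringeSketch {
    public static void main(String[] args) {
        PriorityQueue<String> fringe = new FixedPrioritiesPriorityQueue<>();
        // Priorities here are log probabilities: 0.0 means probability 1.0
        fringe.add("root state", 0.0);
        fringe.add("likely split", -0.3);      // invented scores, for illustration only
        fringe.add("unlikely split", -4.2);

        while (!fringe.isEmpty()) {
            double logProb = fringe.getPriority();   // priority of the best element still on the queue
            String state = fringe.removeFirst();     // pop the highest-priority (most probable) state
            System.out.printf("%s (logP = %.2f)%n", state, logProb);
        }
        // Prints "root state", then "likely split", then "unlikely split"
    }
}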
