Example 1 with Lazy

Use of edu.stanford.nlp.util.Lazy in the CoreNLP project by stanfordnlp.
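Before the full method, here is a minimal sketch of the Lazy API it relies on. Only Lazy.of and get() are taken from the snippet below; the demo class, names, and printed strings are illustrative.

import edu.stanford.nlp.util.Lazy;

public class LazyDemo {
    public static void main(String[] args) {
        // Constructing the Lazy does not run the supplier...
        Lazy<String> expensive = Lazy.of(() -> {
            System.out.println("computing...");
            return "result";
        });
        // ...it runs on the first get(); later get() calls are expected to
        // reuse the computed value (the method below calls get() three
        // times on a single Lazy).
        System.out.println(expensive.get());
        System.out.println(expensive.get());
    }
}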

The snippet itself is from the class ForwardEntailerSearchProblem, method searchImplementation:

/**
   * The search algorithm, starting with a full sentence and iteratively shortening it to its entailed sentences.
   *
   * @return A list of search results, corresponding to shortenings of the sentence.
   */
private List<SearchResult> searchImplementation() {
    // Pre-process the tree
    SemanticGraph parseTree = new SemanticGraph(this.parseTree);
    assert Util.isTree(parseTree);
    // (remove common determiners)
    List<String> determinerRemovals = new ArrayList<>();
    parseTree.getLeafVertices().stream()
            .filter(vertex -> "the".equalsIgnoreCase(vertex.word()) ||
                              "a".equalsIgnoreCase(vertex.word()) ||
                              "an".equalsIgnoreCase(vertex.word()) ||
                              "this".equalsIgnoreCase(vertex.word()) ||
                              "that".equalsIgnoreCase(vertex.word()) ||
                              "those".equalsIgnoreCase(vertex.word()) ||
                              "these".equalsIgnoreCase(vertex.word()))
            .forEach(vertex -> {
                parseTree.removeVertex(vertex);
                assert Util.isTree(parseTree);
                determinerRemovals.add("det");
            });
    // (cut conj_and nodes)
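    // A vertex with more than one incoming edge breaks the tree invariant;
    // drop one incoming conj:and edge per such vertex and remember it in
    // andsToAdd so it can be restored in the result trees below.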
    Set<SemanticGraphEdge> andsToAdd = new HashSet<>();
    for (IndexedWord vertex : parseTree.vertexSet()) {
        if (parseTree.inDegree(vertex) > 1) {
            SemanticGraphEdge conjAnd = null;
            for (SemanticGraphEdge edge : parseTree.incomingEdgeIterable(vertex)) {
                if ("conj:and".equals(edge.getRelation().toString())) {
                    conjAnd = edge;
                }
            }
            if (conjAnd != null) {
                parseTree.removeEdge(conjAnd);
                assert Util.isTree(parseTree);
                andsToAdd.add(conjAnd);
            }
        }
    }
    // Clean the tree
    Util.cleanTree(parseTree);
    assert Util.isTree(parseTree);
    // Find the subject / object split
    // This takes max O(n^2) time, expected O(n*log(n)) time.
    // Optimal is O(n), but I'm too lazy to implement it.
    BitSet isSubject = new BitSet(256);
    for (IndexedWord vertex : parseTree.vertexSet()) {
        // Search up the tree for a subj node; if found, mark that vertex as a subject.
        Iterator<SemanticGraphEdge> incomingEdges = parseTree.incomingEdgeIterator(vertex);
        SemanticGraphEdge edge = null;
        if (incomingEdges.hasNext()) {
            edge = incomingEdges.next();
        }
        int numIters = 0;
        while (edge != null) {
            if (edge.getRelation().toString().endsWith("subj")) {
                assert vertex.index() > 0;
                isSubject.set(vertex.index() - 1);
                break;
            }
            incomingEdges = parseTree.incomingEdgeIterator(edge.getGovernor());
            if (incomingEdges.hasNext()) {
                edge = incomingEdges.next();
            } else {
                edge = null;
            }
            numIters += 1;
            if (numIters > 100) {
                // log.error("tree has apparent depth > 100");
                return Collections.emptyList();
            }
        }
    }
    // Outputs
    List<SearchResult> results = new ArrayList<>();
    if (!determinerRemovals.isEmpty()) {
        if (andsToAdd.isEmpty()) {
            double score = Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size());
            assert !Double.isNaN(score);
            assert !Double.isInfinite(score);
            results.add(new SearchResult(parseTree, determinerRemovals, score));
        } else {
            SemanticGraph treeWithAnds = new SemanticGraph(parseTree);
            assert Util.isTree(treeWithAnds);
            for (SemanticGraphEdge and : andsToAdd) {
                treeWithAnds.addEdge(and.getGovernor(), and.getDependent(), and.getRelation(), Double.NEGATIVE_INFINITY, false);
            }
            assert Util.isTree(treeWithAnds);
            results.add(new SearchResult(treeWithAnds, determinerRemovals, Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size())));
        }
    }
    // Initialize the search
    assert Util.isTree(parseTree);
    List<IndexedWord> topologicalVertices;
    try {
        topologicalVertices = parseTree.topologicalSort();
    } catch (IllegalStateException e) {
        // log.info("Could not topologically sort the vertices! Using left-to-right traversal.");
        topologicalVertices = parseTree.vertexListSorted();
    }
    if (topologicalVertices.isEmpty()) {
        return results;
    }
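    // Depth-first search over the vertices in topological order: the Stack is
    // the fringe, and each SearchState carries a BitSet mask of deleted tokens.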
    Stack<SearchState> fringe = new Stack<>();
    fringe.push(new SearchState(new BitSet(256), 0, parseTree, null, null, 1.0));
    // Start the search
    int numTicks = 0;
    while (!fringe.isEmpty()) {
        // Overhead with popping a node.
        if (numTicks >= maxTicks) {
            return results;
        }
        numTicks += 1;
        if (results.size() >= maxResults) {
            return results;
        }
        SearchState state = fringe.pop();
        assert state.score > 0.0;
        IndexedWord currentWord = topologicalVertices.get(state.currentIndex);
        // Push the case where we don't delete
        int nextIndex = state.currentIndex + 1;
        int numIters = 0;
        while (nextIndex < topologicalVertices.size()) {
            IndexedWord nextWord = topologicalVertices.get(nextIndex);
            assert nextWord.index() > 0;
            if (!state.deletionMask.get(nextWord.index() - 1)) {
                fringe.push(new SearchState(state.deletionMask, nextIndex, state.tree, null, state, state.score));
                break;
            } else {
                nextIndex += 1;
            }
            numIters += 1;
            if (numIters > 10000) {
                // log.error("logic error (apparent infinite loop); returning");
                return results;
            }
        }
        // Check if we can delete this subtree
        boolean canDelete = !state.tree.getFirstRoot().equals(currentWord);
        for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
            if ("CD".equals(edge.getGovernor().tag())) {
                canDelete = false;
            } else {
                // Get token information
                CoreLabel token = edge.getDependent().backingLabel();
                OperatorSpec operator;
                NaturalLogicRelation lexicalRelation;
                Polarity tokenPolarity = token.get(NaturalLogicAnnotations.PolarityAnnotation.class);
                if (tokenPolarity == null) {
                    tokenPolarity = Polarity.DEFAULT;
                }
                // Get the relation for this deletion
                if ((operator = token.get(NaturalLogicAnnotations.OperatorAnnotation.class)) != null) {
                    lexicalRelation = operator.instance.deleteRelation;
                } else {
                    assert edge.getDependent().index() > 0;
                    lexicalRelation = NaturalLogicRelation.forDependencyDeletion(edge.getRelation().toString(), isSubject.get(edge.getDependent().index() - 1));
                }
                NaturalLogicRelation projectedRelation = tokenPolarity.projectLexicalRelation(lexicalRelation);
                // Make sure this is a valid entailment
                if (!projectedRelation.applyToTruthValue(truthOfPremise).isTrue()) {
                    canDelete = false;
                }
            }
        }
        if (canDelete) {
            // Register the deletion
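            // Copying the SemanticGraph is expensive, so defer it with a Lazy:
            // the supplier only runs if get() is reached in the
            // newScore > 0.0 branch below.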
            Lazy<Pair<SemanticGraph, BitSet>> treeWithDeletionsAndNewMask = Lazy.of(() -> {
                SemanticGraph impl = new SemanticGraph(state.tree);
                BitSet newMask = state.deletionMask;
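                // n.b. newMask aliases state.deletionMask (no copy is made),
                // so these set() calls write into the popped state's mask.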
                for (IndexedWord vertex : state.tree.descendants(currentWord)) {
                    impl.removeVertex(vertex);
                    assert vertex.index() > 0;
                    newMask.set(vertex.index() - 1);
                    assert newMask.get(vertex.index() - 1);
                }
                return Pair.makePair(impl, newMask);
            });
            // Compute the score of the sentence
            double newScore = state.score;
            for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
                double multiplier = weights.deletionProbability(edge, state.tree.outgoingEdgeIterable(edge.getGovernor()));
                assert !Double.isNaN(multiplier);
                assert !Double.isInfinite(multiplier);
                newScore *= multiplier;
            }
            // Register the result
            if (newScore > 0.0) {
                SemanticGraph resultTree = new SemanticGraph(treeWithDeletionsAndNewMask.get().first);
                andsToAdd.stream()
                        .filter(edge -> resultTree.containsVertex(edge.getGovernor()) &&
                                        resultTree.containsVertex(edge.getDependent()))
                        .forEach(edge -> resultTree.addEdge(edge.getGovernor(), edge.getDependent(),
                                edge.getRelation(), Double.NEGATIVE_INFINITY, false));
                results.add(new SearchResult(resultTree, aggregateDeletedEdges(state, state.tree.incomingEdgeIterable(currentWord), determinerRemovals), newScore));
                // Push the state with this subtree deleted
                nextIndex = state.currentIndex + 1;
                numIters = 0;
                while (nextIndex < topologicalVertices.size()) {
                    IndexedWord nextWord = topologicalVertices.get(nextIndex);
                    BitSet newMask = treeWithDeletionsAndNewMask.get().second;
                    SemanticGraph treeWithDeletions = treeWithDeletionsAndNewMask.get().first;
                    if (!newMask.get(nextWord.index() - 1)) {
                        assert treeWithDeletions.containsVertex(topologicalVertices.get(nextIndex));
                        fringe.push(new SearchState(newMask, nextIndex, treeWithDeletions, null, state, newScore));
                        break;
                    } else {
                        nextIndex += 1;
                    }
                    numIters += 1;
                    if (numIters > 10000) {
                        // log.error("logic error (apparent infinite loop); returning");
                        return results;
                    }
                }
            }
        }
    }
    // Return
    return results;
}
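Note the design choice around Lazy in the method above: building treeWithDeletionsAndNewMask copies the whole SemanticGraph and updates the deletion mask, so that work is wrapped in Lazy.of and only forced via get() inside the newScore > 0.0 branch. Candidate deletions whose score multiplies out to zero never pay for the copy.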
Also used: CoreLabel (edu.stanford.nlp.ling.CoreLabel), IndexedWord (edu.stanford.nlp.ling.IndexedWord), SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph), SemanticGraphEdge (edu.stanford.nlp.semgraph.SemanticGraphEdge), Lazy (edu.stanford.nlp.util.Lazy), Pair (edu.stanford.nlp.util.Pair), StringUtils (edu.stanford.nlp.util.StringUtils), Redwood (edu.stanford.nlp.util.logging.Redwood), java.util (java.util), Collectors (java.util.stream.Collectors)
