Example 76 with SemanticGraph

use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp.

the class ClauseSplitterSearchProblem method search.

/**
   * The core implementation of the search.
   *
   * @param root The root word to search from. Traditionally, this is the root of the sentence.
   * @param candidateFragments The callback for the resulting sentence fragments.
   *                           This is a predicate of a triple of values.
   *                           The return value of the predicate determines whether we should continue searching.
   *                           The triple is a triple of
   *                           <ol>
   *                             <li>The log probability of the sentence fragment, according to the featurizer and the weights</li>
   *                             <li>The features along the path to this fragment. The last element of this is the features from the most recent step.</li>
   *                             <li>The sentence fragment. Because it is relatively expensive to compute the resulting tree, this is returned as a lazy {@link Supplier}.</li>
   *                           </ol>
   * @param classifier The classifier for whether an arc should be on the path to a clause split, a clause split itself, or neither.
   * @param featurizer The featurizer to use. Make sure this matches the weights!
   * @param actionSpace The action space we are allowed to take. Each action defines a means of splitting a clause on a dependency boundary.
   */
protected void search(
        // The root to search from
        IndexedWord root,
        // The output specs
        final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>> candidateFragments,
        // The learning specs
        final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
        Map<String, ? extends List<String>> hardCodedSplits,
        final Function<Triple<State, Action, State>, Counter<String>> featurizer,
        final Collection<Action> actionSpace,
        final int maxTicks) {
    // (the fringe)
    PriorityQueue<Pair<State, List<Counter<String>>>> fringe = new FixedPrioritiesPriorityQueue<>();
    // (avoid duplicate work)
    Set<IndexedWord> seenWords = new HashSet<>();
    // The first state is implicitly "done"
    State firstState = new State(null, null, -9000, null, x -> {
    }, true);
    fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
    int ticks = 0;
    while (!fringe.isEmpty()) {
        if (++ticks > maxTicks) {
            //        log.info("WARNING! Timed out on search with " + ticks + " ticks");
            return;
        }
        // Useful variables
        double logProbSoFar = fringe.getPriority();
        assert logProbSoFar <= 0.0;
        Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
        State lastState = lastStatePair.first;
        List<Counter<String>> featuresSoFar = lastStatePair.second;
        IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();
        // Register thunk
        if (lastState.isDone) {
            if (!candidateFragments.test(Triple.makeTriple(logProbSoFar, featuresSoFar, () -> {
                SemanticGraph copy = new SemanticGraph(tree);
                lastState.thunk.andThen(x -> {
                    for (IndexedWord newTreeRoot : x.getRoots()) {
                        if (newTreeRoot != null) {
                            for (SemanticGraphEdge extraEdge : extraEdgesByGovernor.get(newTreeRoot)) {
                                assert Util.isTree(x);
                                addSubtree(x, newTreeRoot, extraEdge.getRelation().toString(), tree, extraEdge.getDependent(), tree.getIncomingEdgesSorted(newTreeRoot));
                                assert Util.isTree(x);
                            }
                        }
                    }
                }).accept(copy);
                return new SentenceFragment(copy, assumedTruth, false);
            }))) {
                break;
            }
        }
        // Find relevant auxiliary terms
        SemanticGraphEdge subjOrNull = null;
        SemanticGraphEdge objOrNull = null;
        for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
            String relString = auxEdge.getRelation().toString();
            if (relString.contains("obj")) {
                objOrNull = auxEdge;
            } else if (relString.contains("subj")) {
                subjOrNull = auxEdge;
            }
        }
        // For each outgoing edge...
        for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
            // Skip ccomp arcs whose governor is an indirect speech verb (e.g., "he said [that ...]")
            if (outgoingEdge.getRelation().toString().equals("ccomp") && ((outgoingEdge.getGovernor().lemma() != null && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma())) || INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
                continue;
            }
            // Get some variables
            String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
            List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
            if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
                forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
            }
            boolean doneForcedArc = false;
            // For each action...
            for (Action action : (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
                // Check the prerequisite
                if (!action.prerequisitesMet(tree, outgoingEdge)) {
                    continue;
                }
                if (forcedArcOrder != null && doneForcedArc) {
                    break;
                }
                // 1. Compute the child state
                Optional<State> candidate = action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
                if (candidate.isPresent()) {
                    double logProbability;
                    ClauseClassifierLabel bestLabel;
                    Counter<String> features = featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
                    if (forcedArcOrder != null && !doneForcedArc) {
                        logProbability = 0.0;
                        bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
                        doneForcedArc = true;
                    } else if (features.containsKey("__undocumented_junit_no_classifier")) {
                        logProbability = Double.NEGATIVE_INFINITY;
                        bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
                    } else {
                        Counter<ClauseClassifierLabel> scores = classifier.scoresOf(new RVFDatum<>(features));
                        if (scores.size() > 0) {
                            Counters.logNormalizeInPlace(scores);
                        }
                        String rel = outgoingEdge.getRelation().toString();
                        if ("nsubj".equals(rel) || "dobj".equals(rel)) {
                            // Always at least yield on nsubj and dobj
                            scores.remove(ClauseClassifierLabel.NOT_A_CLAUSE);
                        }
                        logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
                        bestLabel = Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
                    }
                    if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
                        Pair<State, List<Counter<String>>> childState = Pair.makePair(
                                candidate.get().withIsDone(bestLabel),
                                new ArrayList<Counter<String>>(featuresSoFar) {{
                                    add(features);  // append the features from this step
                                }});
                        // 2. Register the child state
                        if (!seenWords.contains(childState.first.edge.getDependent())) {
                            //            log.info("  pushing " + action.signature() + " with " + argmax.first.edge);
                            fringe.add(childState, logProbability);
                        }
                    }
                }
            }
        }
        seenWords.add(rootWord);
    }
//    log.info("Search finished in " + ticks + " ticks and " + classifierEvals + " classifier evaluations.");
}
Also used : java.util(java.util) SemanticGraphEdge(edu.stanford.nlp.semgraph.SemanticGraphEdge) GrammaticalRelation(edu.stanford.nlp.trees.GrammaticalRelation) Counters(edu.stanford.nlp.stats.Counters) Predicate(java.util.function.Predicate) Redwood(edu.stanford.nlp.util.logging.Redwood) edu.stanford.nlp.util(edu.stanford.nlp.util) ClauseClassifierLabel(edu.stanford.nlp.naturalli.ClauseSplitter.ClauseClassifierLabel) Function(java.util.function.Function) Supplier(java.util.function.Supplier) Consumer(java.util.function.Consumer) Counter(edu.stanford.nlp.stats.Counter) Stream(java.util.stream.Stream) java.io(java.io) edu.stanford.nlp.classify(edu.stanford.nlp.classify) Language(edu.stanford.nlp.international.Language) edu.stanford.nlp.ling(edu.stanford.nlp.ling) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) PriorityQueue(edu.stanford.nlp.util.PriorityQueue)
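
For orientation, here is a minimal sketch of driving this search, mirroring the five-argument overload used in Example 80 below (which supplies the root and action space itself). A CoreMap sentence, a classifier, and a featurizer are assumed to already be in scope; the callback simply collects every finished fragment.

SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
List<SentenceFragment> fragments = new ArrayList<>();
problem.search(triple -> {
    // triple.first is the log probability, triple.second the features along the path,
    // and triple.third a lazy Supplier; force it only for fragments we keep
    fragments.add(triple.third.get());
    // return true to keep searching, false to stop early
    return true;
}, classifier, Collections.emptyMap(), featurizer, 10000);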

Example 77 with SemanticGraph

use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp.

the class CreateClauseDataset method process.

@Override
public void process(long id, Annotation doc) {
    CoreMap sentence = doc.get(CoreAnnotations.SentencesAnnotation.class).get(0);
    SemanticGraph depparse = sentence.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
    log.info("| " + sentence.get(CoreAnnotations.TextAnnotation.class));
    // Get all valid subject spans
    BitSet consumedAsSubjects = new BitSet();
    @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") List<Span> subjectSpans = new ArrayList<>();
    NEXTNODE: for (IndexedWord head : depparse.topologicalSort()) {
        // Check if the node is a noun/pronoun
        if (head.tag().startsWith("N") || head.tag().equals("PRP")) {
            // Try to get the NP chunk
            Optional<List<IndexedWord>> subjectChunk = segmenter.getValidChunk(depparse, head, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
            if (subjectChunk.isPresent()) {
                // Make sure it's not already a member of a larger NP
                for (IndexedWord tok : subjectChunk.get()) {
                    if (consumedAsSubjects.get(tok.index())) {
                        // Already considered. Continue to the next node.
                        continue NEXTNODE;
                    }
                }
                // Register it as an NP
                for (IndexedWord tok : subjectChunk.get()) {
                    consumedAsSubjects.set(tok.index());
                }
                // Add it as a subject
                subjectSpans.add(toSpan(subjectChunk.get()));
            }
        }
    }
}
Also used : SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) IndexedWord(edu.stanford.nlp.ling.IndexedWord) Span(edu.stanford.nlp.ie.machinereading.structure.Span)
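
The toSpan helper is not shown in this excerpt. A plausible sketch, assuming it collapses a chunk of IndexedWords into a zero-indexed, end-exclusive token span (consistent with the Span.fromValues calls in Example 80):

private static Span toSpan(List<IndexedWord> chunk) {
    int min = Integer.MAX_VALUE;
    int max = -1;
    for (IndexedWord word : chunk) {
        // IndexedWord token indices are 1-based; spans here are 0-based and end-exclusive
        min = Math.min(min, word.index() - 1);
        max = Math.max(max, word.index());
    }
    return Span.fromValues(min, max);
}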

Example 78 with SemanticGraph

use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp.

the class ForwardEntailerSearchProblem method searchImplementation.

/**
   * The search algorithm, starting with a full sentence and iteratively shortening it to its entailed sentences.
   *
   * @return A list of search results, corresponding to shortenings of the sentence.
   */
@SuppressWarnings("unchecked")
private List<SearchResult> searchImplementation() {
    // Pre-process the tree
    SemanticGraph parseTree = new SemanticGraph(this.parseTree);
    assert Util.isTree(parseTree);
    // (remove common determiners)
    List<String> determinerRemovals = new ArrayList<>();
    parseTree.getLeafVertices().stream().filter(vertex -> "the".equalsIgnoreCase(vertex.word()) || "a".equalsIgnoreCase(vertex.word()) || "an".equalsIgnoreCase(vertex.word()) || "this".equalsIgnoreCase(vertex.word()) || "that".equalsIgnoreCase(vertex.word()) || "those".equalsIgnoreCase(vertex.word()) || "these".equalsIgnoreCase(vertex.word())).forEach(vertex -> {
        parseTree.removeVertex(vertex);
        assert Util.isTree(parseTree);
        determinerRemovals.add("det");
    });
    // (cut conj_and nodes)
    Set<SemanticGraphEdge> andsToAdd = new HashSet<>();
    for (IndexedWord vertex : parseTree.vertexSet()) {
        if (parseTree.inDegree(vertex) > 1) {
            SemanticGraphEdge conjAnd = null;
            for (SemanticGraphEdge edge : parseTree.incomingEdgeIterable(vertex)) {
                if ("conj:and".equals(edge.getRelation().toString())) {
                    conjAnd = edge;
                }
            }
            if (conjAnd != null) {
                parseTree.removeEdge(conjAnd);
                assert Util.isTree(parseTree);
                andsToAdd.add(conjAnd);
            }
        }
    }
    // Clean the tree
    Util.cleanTree(parseTree);
    assert Util.isTree(parseTree);
    // Find the subject / object split
    // This takes max O(n^2) time, expected O(n*log(n)) time.
    // Optimal is O(n), but I'm too lazy to implement it.
    BitSet isSubject = new BitSet(256);
    for (IndexedWord vertex : parseTree.vertexSet()) {
        // Search up the tree for a subj node; if found, mark that vertex as a subject.
        Iterator<SemanticGraphEdge> incomingEdges = parseTree.incomingEdgeIterator(vertex);
        SemanticGraphEdge edge = null;
        if (incomingEdges.hasNext()) {
            edge = incomingEdges.next();
        }
        int numIters = 0;
        while (edge != null) {
            if (edge.getRelation().toString().endsWith("subj")) {
                assert vertex.index() > 0;
                isSubject.set(vertex.index() - 1);
                break;
            }
            incomingEdges = parseTree.incomingEdgeIterator(edge.getGovernor());
            if (incomingEdges.hasNext()) {
                edge = incomingEdges.next();
            } else {
                edge = null;
            }
            numIters += 1;
            if (numIters > 100) {
                //          log.error("tree has apparent depth > 100");
                return Collections.EMPTY_LIST;
            }
        }
    }
    // Outputs
    List<SearchResult> results = new ArrayList<>();
    if (!determinerRemovals.isEmpty()) {
        if (andsToAdd.isEmpty()) {
            double score = Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size());
            assert !Double.isNaN(score);
            assert !Double.isInfinite(score);
            results.add(new SearchResult(parseTree, determinerRemovals, score));
        } else {
            SemanticGraph treeWithAnds = new SemanticGraph(parseTree);
            assert Util.isTree(treeWithAnds);
            for (SemanticGraphEdge and : andsToAdd) {
                treeWithAnds.addEdge(and.getGovernor(), and.getDependent(), and.getRelation(), Double.NEGATIVE_INFINITY, false);
            }
            assert Util.isTree(treeWithAnds);
            results.add(new SearchResult(treeWithAnds, determinerRemovals, Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size())));
        }
    }
    // Initialize the search
    assert Util.isTree(parseTree);
    List<IndexedWord> topologicalVertices;
    try {
        topologicalVertices = parseTree.topologicalSort();
    } catch (IllegalStateException e) {
        //      log.info("Could not topologically sort the vertices! Using left-to-right traversal.");
        topologicalVertices = parseTree.vertexListSorted();
    }
    if (topologicalVertices.isEmpty()) {
        return results;
    }
    Stack<SearchState> fringe = new Stack<>();
    fringe.push(new SearchState(new BitSet(256), 0, parseTree, null, null, 1.0));
    // Start the search
    int numTicks = 0;
    while (!fringe.isEmpty()) {
        // Enforce the tick and result budgets before popping a node
        if (numTicks >= maxTicks) {
            return results;
        }
        numTicks += 1;
        if (results.size() >= maxResults) {
            return results;
        }
        SearchState state = fringe.pop();
        assert state.score > 0.0;
        IndexedWord currentWord = topologicalVertices.get(state.currentIndex);
        // Push the case where we don't delete
        int nextIndex = state.currentIndex + 1;
        int numIters = 0;
        while (nextIndex < topologicalVertices.size()) {
            IndexedWord nextWord = topologicalVertices.get(nextIndex);
            assert nextWord.index() > 0;
            if (!state.deletionMask.get(nextWord.index() - 1)) {
                fringe.push(new SearchState(state.deletionMask, nextIndex, state.tree, null, state, state.score));
                break;
            } else {
                nextIndex += 1;
            }
            numIters += 1;
            if (numIters > 10000) {
                //          log.error("logic error (apparent infinite loop); returning");
                return results;
            }
        }
        // Check if we can delete this subtree
        boolean canDelete = !state.tree.getFirstRoot().equals(currentWord);
        for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
            if ("CD".equals(edge.getGovernor().tag())) {
                canDelete = false;
            } else {
                // Get token information
                CoreLabel token = edge.getDependent().backingLabel();
                OperatorSpec operator;
                NaturalLogicRelation lexicalRelation;
                Polarity tokenPolarity = token.get(NaturalLogicAnnotations.PolarityAnnotation.class);
                if (tokenPolarity == null) {
                    tokenPolarity = Polarity.DEFAULT;
                }
                // Get the relation for this deletion
                if ((operator = token.get(NaturalLogicAnnotations.OperatorAnnotation.class)) != null) {
                    lexicalRelation = operator.instance.deleteRelation;
                } else {
                    assert edge.getDependent().index() > 0;
                    lexicalRelation = NaturalLogicRelation.forDependencyDeletion(edge.getRelation().toString(), isSubject.get(edge.getDependent().index() - 1));
                }
                NaturalLogicRelation projectedRelation = tokenPolarity.projectLexicalRelation(lexicalRelation);
                // Make sure this is a valid entailment
                if (!projectedRelation.applyToTruthValue(truthOfPremise).isTrue()) {
                    canDelete = false;
                }
            }
        }
        if (canDelete) {
            // Register the deletion
            Lazy<Pair<SemanticGraph, BitSet>> treeWithDeletionsAndNewMask = Lazy.of(() -> {
                SemanticGraph impl = new SemanticGraph(state.tree);
                BitSet newMask = state.deletionMask;
                for (IndexedWord vertex : state.tree.descendants(currentWord)) {
                    impl.removeVertex(vertex);
                    assert vertex.index() > 0;
                    newMask.set(vertex.index() - 1);
                    assert newMask.get(vertex.index() - 1);
                }
                return Pair.makePair(impl, newMask);
            });
            // Compute the score of the sentence
            double newScore = state.score;
            for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
                double multiplier = weights.deletionProbability(edge, state.tree.outgoingEdgeIterable(edge.getGovernor()));
                assert !Double.isNaN(multiplier);
                assert !Double.isInfinite(multiplier);
                newScore *= multiplier;
            }
            // Register the result
            if (newScore > 0.0) {
                SemanticGraph resultTree = new SemanticGraph(treeWithDeletionsAndNewMask.get().first);
                andsToAdd.stream().filter(edge -> resultTree.containsVertex(edge.getGovernor()) && resultTree.containsVertex(edge.getDependent())).forEach(edge -> resultTree.addEdge(edge.getGovernor(), edge.getDependent(), edge.getRelation(), Double.NEGATIVE_INFINITY, false));
                results.add(new SearchResult(resultTree, aggregateDeletedEdges(state, state.tree.incomingEdgeIterable(currentWord), determinerRemovals), newScore));
                // Push the state with this subtree deleted
                nextIndex = state.currentIndex + 1;
                numIters = 0;
                while (nextIndex < topologicalVertices.size()) {
                    IndexedWord nextWord = topologicalVertices.get(nextIndex);
                    BitSet newMask = treeWithDeletionsAndNewMask.get().second;
                    SemanticGraph treeWithDeletions = treeWithDeletionsAndNewMask.get().first;
                    if (!newMask.get(nextWord.index() - 1)) {
                        assert treeWithDeletions.containsVertex(topologicalVertices.get(nextIndex));
                        fringe.push(new SearchState(newMask, nextIndex, treeWithDeletions, null, state, newScore));
                        break;
                    } else {
                        nextIndex += 1;
                    }
                    numIters += 1;
                    if (numIters > 10000) {
                        //              log.error("logic error (apparent infinite loop); returning");
                        return results;
                    }
                }
            }
        }
    }
    // Return
    return results;
}
Also used : Lazy(edu.stanford.nlp.util.Lazy) CoreLabel(edu.stanford.nlp.ling.CoreLabel) java.util(java.util) SemanticGraphEdge(edu.stanford.nlp.semgraph.SemanticGraphEdge) StringUtils(edu.stanford.nlp.util.StringUtils) Redwood(edu.stanford.nlp.util.logging.Redwood) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) Pair(edu.stanford.nlp.util.Pair) Collectors(java.util.stream.Collectors) IndexedWord(edu.stanford.nlp.ling.IndexedWord)
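
The scoring above is purely multiplicative: each determiner removal contributes one factor of weights.deletionProbability("det"), and each deleted subtree contributes one factor per incoming edge of its head. A toy arithmetic check, with hypothetical probabilities:

// Hypothetical deletion probabilities, for illustration only
double pDet = 0.9;   // dropping a determiner
double pAmod = 0.5;  // dropping an adjectival-modifier subtree
double score = 1.0;
score *= Math.pow(pDet, 2);  // two determiner removals
score *= pAmod;              // one subtree deletion
// 0.9^2 * 0.5 = 0.405; scores only ever shrink, hence the newScore > 0.0 pruning guard
assert Math.abs(score - 0.405) < 1e-9;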

Example 79 with SemanticGraph

use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp.

the class NaturalLogicAnnotator method getGeneralizedSubtreeSpan.

/** A helper method for
   * {@link NaturalLogicAnnotator#getModifierSubtreeSpan(edu.stanford.nlp.semgraph.SemanticGraph, edu.stanford.nlp.ling.IndexedWord)} and
   * {@link NaturalLogicAnnotator#getSubtreeSpan(edu.stanford.nlp.semgraph.SemanticGraph, edu.stanford.nlp.ling.IndexedWord)}.
   */
private static Pair<Integer, Integer> getGeneralizedSubtreeSpan(SemanticGraph tree, IndexedWord root, Set<String> validArcs) {
    int min = root.index();
    int max = root.index();
    Queue<IndexedWord> fringe = new LinkedList<>();
    for (SemanticGraphEdge edge : tree.outgoingEdgeIterable(root)) {
        String edgeLabel = edge.getRelation().getShortName();
        if ((validArcs == null || validArcs.contains(edgeLabel)) && !"punct".equals(edgeLabel)) {
            fringe.add(edge.getDependent());
        }
    }
    while (!fringe.isEmpty()) {
        IndexedWord node = fringe.poll();
        min = Math.min(node.index(), min);
        max = Math.max(node.index(), max);
        // ignore punctuation
        fringe.addAll(tree.getOutEdgesSorted(node).stream().filter(edge -> edge.getGovernor().equals(node) && !(edge.getGovernor().equals(edge.getDependent())) && !"punct".equals(edge.getRelation().getShortName())).map(SemanticGraphEdge::getDependent).collect(Collectors.toList()));
    }
    return Pair.makePair(min, max + 1);
}
Also used : CoreLabel(edu.stanford.nlp.ling.CoreLabel) java.util(java.util) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) SemanticGraphEdge(edu.stanford.nlp.semgraph.SemanticGraphEdge) Redwood(edu.stanford.nlp.util.logging.Redwood) edu.stanford.nlp.util(edu.stanford.nlp.util) SentenceAnnotator(edu.stanford.nlp.pipeline.SentenceAnnotator) SemgrexMatcher(edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher) NaturalLogicAnnotations(edu.stanford.nlp.naturalli.NaturalLogicAnnotations) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) Span(edu.stanford.nlp.ie.machinereading.structure.Span) CoreAnnotation(edu.stanford.nlp.ling.CoreAnnotation) Annotation(edu.stanford.nlp.pipeline.Annotation) SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) TokenSequenceMatcher(edu.stanford.nlp.ling.tokensregex.TokenSequenceMatcher) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) SemgrexPattern(edu.stanford.nlp.semgraph.semgrex.SemgrexPattern) IndexedWord(edu.stanford.nlp.ling.IndexedWord) TokenSequencePattern(edu.stanford.nlp.ling.tokensregex.TokenSequencePattern)
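
A worked toy example from inside the class, assuming 1-based token indices: for "the big dog barked" with root dog (index 3) governing the (1) and big (2) via det and amod arcs, the BFS visits both dependents, so min becomes 1 and max becomes 3:

// Hypothetical tree and root; passing validArcs == null accepts every non-punct arc
Pair<Integer, Integer> span = getGeneralizedSubtreeSpan(tree, dog, null);
// End-exclusive because of the max + 1 above: the span covers "the big dog"
assert span.first == 1 && span.second == 4;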

Example 80 with SemanticGraph

use of edu.stanford.nlp.semgraph.SemanticGraph in project CoreNLP by stanfordnlp.

the class ClauseSplitter method train.

/**
   * Train a clause searcher factory. That is, train a classifier that decides which arcs should
   * start new clauses.
   *
   * @param trainingData The training data. This is a stream of triples of:
   *                     <ol>
   *                       <li>The sentence containing a known extraction.</li>
   *                       <li>The span of the subject in the sentence, as a token span.</li>
   *                       <li>The span of the object in the sentence, as a token span.</li>
   *                     </ol>
   * @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
   * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
   * @param featurizer The featurizer to use for this classifier.
   *
   * @return A factory for creating searchers from a given dependency tree.
   */
static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, Featurizer featurizer) {
    // Parse options
    LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
    // Generally useful objects
    OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
    WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
    AtomicInteger numExamplesProcessed = new AtomicInteger(0);
    final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
        try {
            return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(trainingDataDump.get()))));
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    });
    // Step 1: Loop over data
    forceTrack("Training inference");
    trainingData.forEach(rawExample -> {
        CoreMap sentence = rawExample.first;
        Collection<Pair<Span, Span>> spans = rawExample.second;
        List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
        SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
        ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
        problem.search(fragmentAndScore -> {
            List<Counter<String>> features = fragmentAndScore.second;
            SentenceFragment fragment = fragmentAndScore.third.get();
            Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
            Trilean correct = Trilean.FALSE;
            RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
                Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
                Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
                for (Pair<Span, Span> candidateGold : spans) {
                    Span subjectSpan = candidateGold.first;
                    Span objectSpan = candidateGold.second;
                    if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
                        correct = Trilean.TRUE;
                        break RELATION_TRIPLE_LOOP;
                    } else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
                        if (!correct.isTrue()) {
                            correct = Trilean.TRUE;
                            break RELATION_TRIPLE_LOOP;
                        }
                    } else {
                        if (!correct.isTrue()) {
                            correct = Trilean.UNKNOWN;
                            break RELATION_TRIPLE_LOOP;
                        }
                    }
                }
            }
            if (!features.isEmpty()) {
                List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
                if (correct.isTrue()) {
                    for (int i = 0; i < features.size(); ++i) {
                        if (i == features.size() - 1) {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                        } else {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                        }
                    }
                } else if (correct.isFalse()) {
                    decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
                } else if (correct.isUnknown()) {
                    boolean isSimpleSplit = false;
                    for (Counter<String> feats : features) {
                        if (featurizer.isSimpleSplit(feats)) {
                            isSimpleSplit = true;
                            break;
                        }
                    }
                    if (isSimpleSplit) {
                        for (int i = 0; i < features.size(); ++i) {
                            if (i == features.size() - 1) {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                            } else {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                            }
                        }
                    }
                }
                for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
                    RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
                    datum.setLabel(decision.second);
                    if (datasetDumpWriter.isPresent()) {
                        datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
                    }
                    dataset.add(datum);
                }
            }
            return true;
        }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
        if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
            log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
        }
    });
    endTrack("Training inference");
    // Close the file
    if (datasetDumpWriter.isPresent()) {
        datasetDumpWriter.get().close();
    }
    // Step 2: Train classifier
    forceTrack("Training");
    Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
    endTrack("Training");
    if (modelPath.isPresent()) {
        Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
        try {
            IOUtils.writeObjectToFile(toSave, modelPath.get());
            log("SUCCESS: wrote model to " + modelPath.get().getPath());
        } catch (IOException e) {
            log("ERROR: failed to save model to path: " + modelPath.get().getPath());
            err(e);
        }
    }
    // Step 3: Check accuracy of classifier
    forceTrack("Training accuracy");
    dataset.randomize(42L);
    Util.dumpAccuracy(fullClassifier, dataset);
    endTrack("Training accuracy");
    int numFolds = 5;
    forceTrack(numFolds + " fold cross-validation");
    for (int fold = 0; fold < numFolds; ++fold) {
        forceTrack("Fold " + (fold + 1));
        forceTrack("Training");
        Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
        Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
        endTrack("Training");
        forceTrack("Test");
        Util.dumpAccuracy(classifier, foldData.second);
        endTrack("Test");
        endTrack("Fold " + (fold + 1));
    }
    endTrack(numFolds + " fold cross-validation");
    // Step 4: Return the factory
    return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
Also used : CoreLabel(edu.stanford.nlp.ling.CoreLabel) java.util(java.util) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) IOUtils(edu.stanford.nlp.io.IOUtils) edu.stanford.nlp.util(edu.stanford.nlp.util) BiFunction(java.util.function.BiFunction) Redwood(edu.stanford.nlp.util.logging.Redwood) Util(edu.stanford.nlp.util.logging.Redwood.Util) Span(edu.stanford.nlp.ie.machinereading.structure.Span) Counter(edu.stanford.nlp.stats.Counter) Stream(java.util.stream.Stream) java.io(java.io) SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) edu.stanford.nlp.classify(edu.stanford.nlp.classify) RelationTriple(edu.stanford.nlp.ie.util.RelationTriple) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) GZIPOutputStream(java.util.zip.GZIPOutputStream) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) ClauseSplitterSearchProblem(edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem)
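
A minimal sketch of invoking train(...) per the signature above; the data loader, featurizer source, model path, and depTree are placeholders rather than names from this excerpt:

// Hypothetical loader yielding (sentence, {(subject span, object span), ...}) pairs
Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData = loadTrainingPairs();
Featurizer featurizer = makeFeaturizer();  // hypothetical: however the featurizer is obtained
ClauseSplitter splitter = ClauseSplitter.train(
        trainingData,
        Optional.of(new File("clauseSplitterModel.ser.gz")),  // where to serialize the model
        Optional.empty(),                                     // skip the labeled-datum dump
        featurizer);
// The returned factory builds one searcher per dependency tree (see the lambda above)
ClauseSplitterSearchProblem searcher = splitter.apply(depTree, true);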

Aggregations

SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph)126 IndexedWord (edu.stanford.nlp.ling.IndexedWord)57 SemanticGraphCoreAnnotations (edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations)53 CoreLabel (edu.stanford.nlp.ling.CoreLabel)51 CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations)47 SemanticGraphEdge (edu.stanford.nlp.semgraph.SemanticGraphEdge)24 Tree (edu.stanford.nlp.trees.Tree)20 CoreMap (edu.stanford.nlp.util.CoreMap)19 TreeCoreAnnotations (edu.stanford.nlp.trees.TreeCoreAnnotations)18 SemgrexMatcher (edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher)16 GrammaticalRelation (edu.stanford.nlp.trees.GrammaticalRelation)16 Annotation (edu.stanford.nlp.pipeline.Annotation)14 SemgrexPattern (edu.stanford.nlp.semgraph.semgrex.SemgrexPattern)12 ArrayList (java.util.ArrayList)12 Mention (edu.stanford.nlp.coref.data.Mention)11 java.util (java.util)11 edu.stanford.nlp.util (edu.stanford.nlp.util)10 Properties (java.util.Properties)9 Collectors (java.util.stream.Collectors)9 CorefCoreAnnotations (edu.stanford.nlp.coref.CorefCoreAnnotations)8