Example 1 with DocumentExamples

Use of edu.stanford.nlp.coref.statistical.DocumentExamples in project CoreNLP by stanfordnlp.

From the class FastNeuralCorefDataExporter, the method process:

@Override
public void process(int id, Document document) {
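    // Write the gold coreference clusters as a JSON array of mention-id arrays, keyed by document id.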
    JsonArrayBuilder clusters = Json.createArrayBuilder();
    for (CorefCluster gold : document.goldCorefClusters.values()) {
        JsonArrayBuilder c = Json.createArrayBuilder();
        for (Mention m : gold.corefMentions) {
            c.add(m.mentionID);
        }
        clusters.add(c.build());
    }
    goldClusterWriter.println(Json.createObjectBuilder().add(String.valueOf(id), clusters.build()).build());
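    // Gather gold labels for all mention pairs, then keep only the pairs that survive the distance-based heuristic filter.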
    Map<Pair<Integer, Integer>, Boolean> allPairs = CorefUtils.getLabeledMentionPairs(document);
    Map<Pair<Integer, Integer>, Boolean> pairs = new HashMap<>();
    for (Map.Entry<Integer, List<Integer>> e : CorefUtils.heuristicFilter(
            CorefUtils.getSortedMentions(document),
            maxMentionDistance, maxMentionDistanceWithStringMatch).entrySet()) {
        for (int m1 : e.getValue()) {
            Pair<Integer, Integer> pair = new Pair<Integer, Integer>(m1, e.getKey());
            pairs.put(pair, allPairs.get(pair));
        }
    }
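    // Serialize each sentence as an array of its tokens.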
    JsonArrayBuilder sentences = Json.createArrayBuilder();
    for (CoreMap sentence : document.annotation.get(SentencesAnnotation.class)) {
        sentences.add(getSentenceArray(sentence.get(CoreAnnotations.TokensAnnotation.class)));
    }
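    // Record metadata for each predicted mention, including the head word's dependency relation and parent.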
    JsonObjectBuilder mentions = Json.createObjectBuilder();
    for (Mention m : document.predictedMentionsByID.values()) {
        Iterator<SemanticGraphEdge> iterator = m.enhancedDependency.incomingEdgeIterator(m.headIndexedWord);
        SemanticGraphEdge relation = iterator.hasNext() ? iterator.next() : null;
        String depRelation = relation == null ? "no-parent" : relation.getRelation().toString();
        String depParent = relation == null ? "<missing>" : relation.getSource().word();
        mentions.add(String.valueOf(m.mentionNum), Json.createObjectBuilder()
                .add("doc_id", id)
                .add("mention_id", m.mentionID)
                .add("mention_num", m.mentionNum)
                .add("sent_num", m.sentNum)
                .add("start_index", m.startIndex)
                .add("end_index", m.endIndex)
                .add("head_index", m.headIndex)
                .add("mention_type", m.mentionType.toString())
                .add("dep_relation", depRelation)
                .add("dep_parent", depParent)
                .add("sentence", getSentenceArray(m.sentenceWords))
                .build());
    }
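    // Extract compressed feature vectors for the filtered mention pairs, then serialize the per-mention features.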
    DocumentExamples examples = extractor.extract(0, document, pairs, compressor);
    JsonObjectBuilder mentionFeatures = Json.createObjectBuilder();
    for (Map.Entry<Integer, CompressedFeatureVector> e : examples.mentionFeatures.entrySet()) {
        JsonObjectBuilder features = Json.createObjectBuilder();
        for (int i = 0; i < e.getValue().keys.size(); i++) {
            features.add(String.valueOf(e.getValue().keys.get(i)), e.getValue().values.get(i));
        }
        mentionFeatures.add(String.valueOf(e.getKey()), features);
    }
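    // Serialize each pairwise example with its label and features, keyed by "<mentionId1> <mentionId2>".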
    JsonObjectBuilder mentionPairs = Json.createObjectBuilder();
    for (Example e : examples.examples) {
        JsonObjectBuilder example = Json.createObjectBuilder().add("mid1", e.mentionId1).add("mid2", e.mentionId2);
        JsonObjectBuilder features = Json.createObjectBuilder();
        for (int i = 0; i < e.pairwiseFeatures.keys.size(); i++) {
            features.add(String.valueOf(e.pairwiseFeatures.keys.get(i)), e.pairwiseFeatures.values.get(i));
        }
        example.add("label", (int) (e.label));
        example.add("features", features);
        mentionPairs.add(String.valueOf(e.mentionId1) + " " + String.valueOf(e.mentionId2), example);
    }
    JsonObject docData = Json.createObjectBuilder().add("sentences", sentences.build()).add("mentions", mentions.build()).add("pairs", mentionPairs.build()).add("mention_features", mentionFeatures.build()).build();
    dataWriter.println(docData);
    // Printing the PrintWriter itself would only show an object reference; report the document id instead.
    System.out.println("Wrote document " + id);
}
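
For orientation, the record written to dataWriter has roughly the following shape; the keys are taken from the builder calls above, but the values are illustrative, not real output. The gold clusters go to a separate writer as {"<doc id>": [[mentionID, ...], ...]}.

{
  "sentences": [["John", "said", "he", "left", "."]],
  "mentions": {
    "0": {
      "doc_id": 0, "mention_id": 1, "mention_num": 0, "sent_num": 0,
      "start_index": 0, "end_index": 1, "head_index": 0,
      "mention_type": "PROPER", "dep_relation": "nsubj", "dep_parent": "said",
      "sentence": ["John", "said", "he", "left", "."]
    }
  },
  "pairs": {
    "1 2": { "mid1": 1, "mid2": 2, "label": 1, "features": { "3": 1.0, "17": 0.5 } }
  },
  "mention_features": { "1": { "0": 1.0, "4": 2.0 } }
}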
Also used: HashMap (java.util.HashMap), JsonObject (javax.json.JsonObject), DocumentExamples (edu.stanford.nlp.coref.statistical.DocumentExamples), SemanticGraphEdge (edu.stanford.nlp.semgraph.SemanticGraphEdge), CompressedFeatureVector (edu.stanford.nlp.coref.statistical.CompressedFeatureVector), CorefCluster (edu.stanford.nlp.coref.data.CorefCluster), Mention (edu.stanford.nlp.coref.data.Mention), Example (edu.stanford.nlp.coref.statistical.Example), CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations), List (java.util.List), JsonArrayBuilder (javax.json.JsonArrayBuilder), JsonObjectBuilder (javax.json.JsonObjectBuilder), Map (java.util.Map), CoreMap (edu.stanford.nlp.util.CoreMap), Pair (edu.stanford.nlp.util.Pair)

Example 2 with DocumentExamples

Use of edu.stanford.nlp.coref.statistical.DocumentExamples in project CoreNLP by stanfordnlp.

From the class FastNeuralCorefAlgorithm, the method runCoref:

@Override
public void runCoref(Document document) {
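    // Map each mention to its candidate antecedents, pruned by distance-based heuristics.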
    Map<Integer, List<Integer>> mentionToCandidateAntecedents = CorefUtils.heuristicFilter(
            CorefUtils.getSortedMentions(document),
            maxMentionDistance, maxMentionDistanceWithStringMatch);
    Map<Pair<Integer, Integer>, Boolean> mentionPairs = new HashMap<>();
    for (Map.Entry<Integer, List<Integer>> e : mentionToCandidateAntecedents.entrySet()) {
        for (int m1 : e.getValue()) {
            mentionPairs.put(new Pair<>(m1, e.getKey()), true);
        }
    }
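    // Extract (compressed) features for every surviving candidate pair.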
    Compressor<String> compressor = new Compressor<>();
    DocumentExamples examples = featureExtractor.extract(0, document, mentionPairs, compressor);
    Counter<Pair<Integer, Integer>> pairwiseScores = new ClassicCounter<>();
    // We cache representations for mentions so we compute them O(n) rather than O(n^2) times
    Map<Integer, SimpleMatrix> antecedentCache = new HashMap<>();
    Map<Integer, SimpleMatrix> anaphorCache = new HashMap<>();
    // Score all mention pairs on how likely they are to be coreferent
    for (Example mentionPair : examples.examples) {
        if (Thread.interrupted()) {
            // Allow interrupting
            throw new RuntimeInterruptedException();
        }
        pairwiseScores.incrementCount(
                new Pair<>(mentionPair.mentionId1, mentionPair.mentionId2),
                model.score(
                        document.predictedMentionsByID.get(mentionPair.mentionId1),
                        document.predictedMentionsByID.get(mentionPair.mentionId2),
                        compressor.uncompress(examples.mentionFeatures.get(mentionPair.mentionId1)),
                        compressor.uncompress(examples.mentionFeatures.get(mentionPair.mentionId2)),
                        compressor.uncompress(mentionPair.pairwiseFeatures),
                        antecedentCache, anaphorCache));
    }
    // Score each mention for anaphoricity
    for (int anaphorId : mentionToCandidateAntecedents.keySet()) {
        if (Thread.interrupted()) {
            // Allow interrupting
            throw new RuntimeInterruptedException();
        }
        pairwiseScores.incrementCount(
                new Pair<>(-1, anaphorId),
                model.score(
                        null,
                        document.predictedMentionsByID.get(anaphorId),
                        null,
                        compressor.uncompress(examples.mentionFeatures.get(anaphorId)),
                        null,
                        antecedentCache, anaphorCache));
    }
    // Link each mention to the highest-scoring candidate antecedent
    for (Map.Entry<Integer, List<Integer>> e : mentionToCandidateAntecedents.entrySet()) {
        int antecedent = -1;
        int anaphor = e.getKey();
        double bestScore = pairwiseScores.getCount(new Pair<>(-1, anaphor)) - 50 * (greedyness - 0.5);
        for (int ca : e.getValue()) {
            double score = pairwiseScores.getCount(new Pair<>(ca, anaphor));
            if (score > bestScore) {
                bestScore = score;
                antecedent = ca;
            }
        }
        if (antecedent > 0) {
            CorefUtils.mergeCoreferenceClusters(new Pair<>(antecedent, anaphor), document);
        }
    }
}
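
The final loop is where the greedyness parameter (CoreNLP's spelling) matters: values above 0.5 lower the anaphoricity baseline, so more mentions get linked; values below 0.5 raise it, producing more singleton clusters. Below is a minimal, self-contained sketch of that decision rule; the scores and mention ids are hypothetical stand-ins for pairwiseScores and the neural model.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class GreedyLinkSketch {
    public static void main(String[] args) {
        // Hypothetical pairwise scores, keyed "antecedentId anaphorId"; the key "-1 <id>" holds
        // the anaphoricity (start-a-new-cluster) score, mirroring the Pair<>(-1, anaphorId) entries above.
        Map<String, Double> scores = new HashMap<>();
        scores.put("-1 3", 0.6);  // score for mention 3 starting a new cluster
        scores.put("1 3", 0.7);   // score for mention 1 as antecedent of mention 3
        scores.put("2 3", 0.4);   // score for mention 2 as antecedent of mention 3

        double greedyness = 0.5;  // 0.5 leaves the baseline unshifted; larger values encourage linking
        int anaphor = 3;
        List<Integer> candidateAntecedents = Arrays.asList(1, 2);

        // Start from the (shifted) no-antecedent baseline; link only if a candidate beats it.
        int antecedent = -1;
        double bestScore = scores.get("-1 " + anaphor) - 50 * (greedyness - 0.5);
        for (int ca : candidateAntecedents) {
            double score = scores.getOrDefault(ca + " " + anaphor, Double.NEGATIVE_INFINITY);
            if (score > bestScore) {
                bestScore = score;
                antecedent = ca;
            }
        }
        System.out.println(antecedent > 0
                ? "Link mention " + anaphor + " to antecedent " + antecedent
                : "Mention " + anaphor + " starts a new cluster");
    }
}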
Also used: HashMap (java.util.HashMap), RuntimeInterruptedException (edu.stanford.nlp.util.RuntimeInterruptedException), Compressor (edu.stanford.nlp.coref.statistical.Compressor), DocumentExamples (edu.stanford.nlp.coref.statistical.DocumentExamples), SimpleMatrix (org.ejml.simple.SimpleMatrix), Example (edu.stanford.nlp.coref.statistical.Example), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), List (java.util.List), Map (java.util.Map), Pair (edu.stanford.nlp.util.Pair)

Aggregations

DocumentExamples (edu.stanford.nlp.coref.statistical.DocumentExamples): 2 examples
Example (edu.stanford.nlp.coref.statistical.Example): 2 examples
Pair (edu.stanford.nlp.util.Pair): 2 examples
HashMap (java.util.HashMap): 2 examples
List (java.util.List): 2 examples
Map (java.util.Map): 2 examples
CorefCluster (edu.stanford.nlp.coref.data.CorefCluster): 1 example
Mention (edu.stanford.nlp.coref.data.Mention): 1 example
CompressedFeatureVector (edu.stanford.nlp.coref.statistical.CompressedFeatureVector): 1 example
Compressor (edu.stanford.nlp.coref.statistical.Compressor): 1 example
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 1 example
SemanticGraphEdge (edu.stanford.nlp.semgraph.SemanticGraphEdge): 1 example
ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 1 example
CoreMap (edu.stanford.nlp.util.CoreMap): 1 example
RuntimeInterruptedException (edu.stanford.nlp.util.RuntimeInterruptedException): 1 example
JsonArrayBuilder (javax.json.JsonArrayBuilder): 1 example
JsonObject (javax.json.JsonObject): 1 example
JsonObjectBuilder (javax.json.JsonObjectBuilder): 1 example
SimpleMatrix (org.ejml.simple.SimpleMatrix): 1 example