
Example 1 with RVFDataset

Use of edu.stanford.nlp.classify.RVFDataset in the CoreNLP project by stanfordnlp.

From the class EntityClassifier, method train.

private static void train(List<SceneGraphImage> images, String modelPath, Embedding embeddings) throws IOException {
    RVFDataset<String, String> dataset = new RVFDataset<String, String>();
    SceneGraphSentenceMatcher sentenceMatcher = new SceneGraphSentenceMatcher(embeddings);
    for (SceneGraphImage img : images) {
        for (SceneGraphImageRegion region : img.regions) {
            SemanticGraph sg = region.getEnhancedSemanticGraph();
            SemanticGraphEnhancer.enhance(sg);
            List<Triple<IndexedWord, IndexedWord, String>> relationTriples = sentenceMatcher.getRelationTriples(region);
            for (Triple<IndexedWord, IndexedWord, String> relation : relationTriples) {
                IndexedWord w1 = sg.getNodeByIndexSafe(relation.first.index());
                if (w1 != null) {
                    dataset.add(getDatum(w1, relation.first.get(SceneGraphCoreAnnotations.GoldEntityAnnotation.class), embeddings));
                }
            }
        }
    }
    LinearClassifierFactory<String, String> classifierFactory = new LinearClassifierFactory<String, String>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
    Classifier<String, String> classifier = classifierFactory.trainClassifier(dataset);
    IOUtils.writeObjectToFile(classifier, modelPath);
    System.err.println(classifier.evaluateAccuracy(dataset));
}
Also used: RVFDataset (edu.stanford.nlp.classify.RVFDataset), SceneGraphImage (edu.stanford.nlp.scenegraph.image.SceneGraphImage), QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer), Triple (edu.stanford.nlp.util.Triple), LinearClassifierFactory (edu.stanford.nlp.classify.LinearClassifierFactory), SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph), IndexedWord (edu.stanford.nlp.ling.IndexedWord), SceneGraphImageRegion (edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion)
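
The example relies on a common CoreNLP pattern: build a real-valued feature Counter per example, wrap it in an RVFDatum, collect the datums in an RVFDataset, and train with a LinearClassifierFactory. Below is a minimal, self-contained sketch of that pattern; the feature names and labels are invented for illustration, and the factory uses its default optimizer settings rather than the QNMinimizer configuration shown above.

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class RVFDatasetSketch {
    public static void main(String[] args) {
        RVFDataset<String, String> dataset = new RVFDataset<String, String>();

        // Toy example 1: real-valued features for a capitalized 4-letter token.
        Counter<String> f1 = new ClassicCounter<String>();
        f1.setCount("length", 4.0);
        f1.setCount("capitalized", 1.0);
        dataset.add(new RVFDatum<String, String>(f1, "ENTITY"));

        // Toy example 2: a lowercase 2-letter token.
        Counter<String> f2 = new ClassicCounter<String>();
        f2.setCount("length", 2.0);
        dataset.add(new RVFDatum<String, String>(f2, "OTHER"));

        // Train a log-linear classifier and report training-set accuracy,
        // mirroring the last two lines of the train() method above.
        LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<String, String>();
        Classifier<String, String> classifier = factory.trainClassifier(dataset);
        System.err.println(classifier.evaluateAccuracy(dataset));
    }
}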

Example 2 with RVFDataset

Use of edu.stanford.nlp.classify.RVFDataset in the CoreNLP project by stanfordnlp.

From the class SupervisedSieveTraining, method featurize.

// goldList null if not training
public static FeaturesData featurize(SieveData sd, List<XMLToAnnotation.GoldQuoteInfo> goldList, boolean isTraining) {
    Annotation doc = sd.doc;
    // used to access Sieve helper functions (mention finding, name scanning)
    Sieve sieve = new Sieve(doc, sd.characterMap, sd.pronounCorefMap, sd.animacyList);
    List<CoreMap> quotes = doc.get(CoreAnnotations.QuotationsAnnotation.class);
    List<CoreMap> sentences = doc.get(CoreAnnotations.SentencesAnnotation.class);
    List<CoreLabel> tokens = doc.get(CoreAnnotations.TokensAnnotation.class);
    Map<Integer, List<CoreMap>> paragraphToQuotes = getQuotesInParagraph(doc);
    GeneralDataset<String, String> dataset = new RVFDataset<>();
    // necessary for 'ScoreBestMention'
    // maps quote to corresponding indices in the dataset
    Map<Integer, Pair<Integer, Integer>> mapQuoteToDataRange = new HashMap<>();
    Map<Integer, Sieve.MentionData> mapDatumToMention = new HashMap<>();
    if (isTraining && goldList.size() != quotes.size()) {
        throw new RuntimeException("Gold Quote List size doesn't match quote list size!");
    }
    for (int quoteIdx = 0; quoteIdx < quotes.size(); quoteIdx++) {
        int initialSize = dataset.size();
        CoreMap quote = quotes.get(quoteIdx);
        XMLToAnnotation.GoldQuoteInfo gold = null;
        if (isTraining) {
            gold = goldList.get(quoteIdx);
            if (gold.speaker.isEmpty()) {
                continue;
            }
        }
        CoreMap quoteFirstSentence = sentences.get(quote.get(CoreAnnotations.SentenceBeginAnnotation.class));
        Pair<Integer, Integer> quoteRun = new Pair<>(quote.get(CoreAnnotations.TokenBeginAnnotation.class), quote.get(CoreAnnotations.TokenEndAnnotation.class));
        // int quoteChapter = quoteFirstSentence.get(ChapterAnnotator.ChapterAnnotation.class);
        int quoteParagraphIdx = quoteFirstSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class);
        // add mentions before quote up to the previous paragraph
        int rightValue = quoteRun.first - 1;
        int leftValue = quoteRun.first - 1;
        // move left value to be the first token idx of the previous paragraph
        for (int sentIdx = quote.get(CoreAnnotations.SentenceBeginAnnotation.class); sentIdx >= 0; sentIdx--) {
            CoreMap sentence = sentences.get(sentIdx);
            if (sentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx) {
                continue;
            }
            if (sentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx - 1) {
                // quoteParagraphIdx - 1 for this and prev
                leftValue = sentence.get(CoreAnnotations.TokenBeginAnnotation.class);
            } else {
                break;
            }
        }
        List<Sieve.MentionData> mentionsInPreviousParagraph = new ArrayList<>();
        if (leftValue > -1 && rightValue > -1)
            mentionsInPreviousParagraph = eliminateDuplicates(sieve.findClosestMentionsInSpanBackward(new Pair<>(leftValue, rightValue)));
        // mentions in next paragraph
        leftValue = quoteRun.second + 1;
        rightValue = quoteRun.second + 1;
        for (int sentIdx = quote.get(CoreAnnotations.SentenceEndAnnotation.class); sentIdx < sentences.size(); sentIdx++) {
            CoreMap sentence = sentences.get(sentIdx);
            if (sentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx) {
                // quoteParagraphIdx + 1
                rightValue = sentence.get(CoreAnnotations.TokenEndAnnotation.class) - 1;
            } else {
                break;
            }
        }
        List<Sieve.MentionData> mentionsInNextParagraph = new ArrayList<>();
        if (leftValue < tokens.size() && rightValue < tokens.size())
            mentionsInNextParagraph = sieve.findClosestMentionsInSpanForward(new Pair<>(leftValue, rightValue));
        List<Sieve.MentionData> candidateMentions = new ArrayList<>();
        candidateMentions.addAll(mentionsInPreviousParagraph);
        candidateMentions.addAll(mentionsInNextParagraph);
        // System.out.println(candidateMentions.size());
        int rankedDistance = 1;
        int numBackwards = mentionsInPreviousParagraph.size();
        for (Sieve.MentionData mention : candidateMentions) {
            // List<CoreLabel> mentionCandidateTokens = doc.get(CoreAnnotations.TokensAnnotation.class).subList(mention.begin, mention.end + 1);
            // CoreMap mentionCandidateSentence = sentences.get(mentionCandidateTokens.get(0).sentIndex());
            // if (mentionCandidateSentence.get(ChapterAnnotator.ChapterAnnotation.class) != quoteChapter) {
            // continue;
            // }
            Counter<String> features = new ClassicCounter<>();
            boolean isLeft = true;
            int distance = quoteRun.first - mention.end;
            if (distance < 0) {
                isLeft = false;
                distance = mention.begin - quoteRun.second;
            }
            if (distance < 0) {
                // disregard mention-in-quote cases.
                continue;
            }
            features.setCount("wordDistance", distance);
            List<CoreLabel> betweenTokens;
            if (isLeft) {
                betweenTokens = tokens.subList(mention.end + 1, quoteRun.first);
            } else {
                betweenTokens = tokens.subList(quoteRun.second + 1, mention.begin);
            }
            // Punctuation in between
            for (CoreLabel token : betweenTokens) {
                if (punctuation.contains(token.word())) {
                    features.setCount("punctuationPresence:" + token.word(), 1);
                }
            }
            // number of mentions away
            features.setCount("rankedDistance", rankedDistance);
            rankedDistance++;
            if (rankedDistance == numBackwards) {
                // reset for the forward
                rankedDistance = 1;
            }
            // int quoteParagraphIdx = quoteFirstSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class);
            // third distance: # of paragraphs away
            int mentionParagraphIdx = -1;
            CoreMap sentenceInMentionParagraph = null;
            int quoteParagraphBeginToken = getParagraphBeginToken(quoteFirstSentence, sentences);
            int quoteParagraphEndToken = getParagraphEndToken(quoteFirstSentence, sentences);
            if (isLeft) {
                if (quoteParagraphBeginToken <= mention.begin && mention.end <= quoteParagraphEndToken) {
                    features.setCount("leftParagraphDistance", 0);
                    mentionParagraphIdx = quoteParagraphIdx;
                    sentenceInMentionParagraph = quoteFirstSentence;
                } else {
                    int paragraphDistance = 1;
                    int currParagraphIdx = quoteParagraphIdx - paragraphDistance;
                    CoreMap currSentence = quoteFirstSentence;
                    int currSentenceIdx = currSentence.get(CoreAnnotations.SentenceIndexAnnotation.class);
                    while (currParagraphIdx >= 0) {
                        // walk back to a sentence in the candidate paragraph, then extract its begin and end tokens
                        while (currSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) != currParagraphIdx) {
                            currSentenceIdx--;
                            currSentence = sentences.get(currSentenceIdx);
                        }
                        int prevParagraphBegin = getParagraphBeginToken(currSentence, sentences);
                        int prevParagraphEnd = getParagraphEndToken(currSentence, sentences);
                        if (prevParagraphBegin <= mention.begin && mention.end <= prevParagraphEnd) {
                            mentionParagraphIdx = currParagraphIdx;
                            sentenceInMentionParagraph = currSentence;
                            features.setCount("leftParagraphDistance", paragraphDistance);
                            if (paragraphDistance % 2 == 0)
                                features.setCount("leftParagraphDistanceEven", 1);
                            break;
                        }
                        paragraphDistance++;
                        currParagraphIdx--;
                    }
                }
            } else { // right: the mention occurs after the quote
                if (quoteParagraphBeginToken <= mention.begin && mention.end <= quoteParagraphEndToken) {
                    features.setCount("rightParagraphDistance", 0);
                    sentenceInMentionParagraph = quoteFirstSentence;
                    mentionParagraphIdx = quoteParagraphIdx;
                } else {
                    int paragraphDistance = 1;
                    int nextParagraphIndex = quoteParagraphIdx + paragraphDistance;
                    CoreMap currSentence = quoteFirstSentence;
                    int currSentenceIdx = currSentence.get(CoreAnnotations.SentenceIndexAnnotation.class);
                    while (currSentenceIdx < sentences.size()) {
                        while (currSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) != nextParagraphIndex) {
                            currSentenceIdx++;
                            currSentence = sentences.get(currSentenceIdx);
                        }
                        int nextParagraphBegin = getParagraphBeginToken(currSentence, sentences);
                        int nextParagraphEnd = getParagraphEndToken(currSentence, sentences);
                        if (nextParagraphBegin <= mention.begin && mention.end <= nextParagraphEnd) {
                            sentenceInMentionParagraph = currSentence;
                            features.setCount("rightParagraphDistance", paragraphDistance);
                            break;
                        }
                        paragraphDistance++;
                        nextParagraphIndex++;
                    }
                }
            }
            // 2. mention features
            if (sentenceInMentionParagraph != null) {
                int mentionParagraphBegin = getParagraphBeginToken(sentenceInMentionParagraph, sentences);
                int mentionParagraphEnd = getParagraphEndToken(sentenceInMentionParagraph, sentences);
                if (!(mentionParagraphBegin == quoteParagraphBeginToken && mentionParagraphEnd == quoteParagraphEndToken)) {
                    List<CoreMap> quotesInMentionParagraph = paragraphToQuotes.getOrDefault(mentionParagraphIdx, new ArrayList<>());
                    Pair<ArrayList<String>, ArrayList<Pair<Integer, Integer>>> namesInMentionParagraph = sieve.scanForNames(new Pair<>(mentionParagraphBegin, mentionParagraphEnd));
                    features.setCount("quotesInMentionParagraph", quotesInMentionParagraph.size());
                    features.setCount("wordsInMentionParagraph", mentionParagraphEnd - mentionParagraphBegin + 1);
                    features.setCount("namesInMentionParagraph", namesInMentionParagraph.first.size());
                    // mention ordering in paragraph it is in
                    for (int i = 0; i < namesInMentionParagraph.second.size(); i++) {
                        if (ExtractQuotesUtil.rangeContains(new Pair<>(mention.begin, mention.end), namesInMentionParagraph.second.get(i)))
                            features.setCount("orderInParagraph", i);
                    }
                    if (quotesInMentionParagraph.size() == 1) {
                        CoreMap qInMentionParagraph = quotesInMentionParagraph.get(0);
                        if (qInMentionParagraph.get(CoreAnnotations.TokenBeginAnnotation.class) == mentionParagraphBegin && qInMentionParagraph.get(CoreAnnotations.TokenEndAnnotation.class) - 1 == mentionParagraphEnd) {
                            features.setCount("mentionParagraphIsInConversation", 1);
                        } else {
                            features.setCount("mentionParagraphIsInConversation", -1);
                        }
                    }
                    for (CoreMap quoteIMP : quotesInMentionParagraph) {
                        if (ExtractQuotesUtil.rangeContains(new Pair<>(quoteIMP.get(CoreAnnotations.TokenBeginAnnotation.class), quoteIMP.get(CoreAnnotations.TokenEndAnnotation.class) - 1), new Pair<>(mention.begin, mention.end)))
                            features.setCount("mentionInQuote", 1);
                    }
                    if (features.getCount("mentionInQuote") != 1)
                        features.setCount("mentionNotInQuote", 1);
                }
            }
            // previous/next word features; guard the document boundaries, or there will be an array index crash
            if (mention.begin > 0) {
                CoreLabel prevWord = tokens.get(mention.begin - 1);
                features.setCount("prevWordType:" + prevWord.tag(), 1);
                if (punctuationForFeatures.contains(prevWord.lemma()))
                    features.setCount("prevWordPunct:" + prevWord.lemma(), 1);
            }
            if (mention.end + 1 < tokens.size()) {
                CoreLabel nextWord = tokens.get(mention.end + 1);
                features.setCount("nextWordType:" + nextWord.tag(), 1);
                if (punctuationForFeatures.contains(nextWord.lemma()))
                    features.setCount("nextWordPunct:" + nextWord.lemma(), 1);
            }
            // features.setCount("prevAndNext:" + prevWord.tag()+ ";" + nextWord.tag(), 1);
            // quote paragraph features
            List<CoreMap> quotesInQuoteParagraph = paragraphToQuotes.get(quoteParagraphIdx);
            features.setCount("QuotesInQuoteParagraph", quotesInQuoteParagraph.size());
            features.setCount("WordsInQuoteParagraph", quoteParagraphEndToken - quoteParagraphBeginToken + 1);
            features.setCount("NamesInQuoteParagraph", sieve.scanForNames(new Pair<>(quoteParagraphBeginToken, quoteParagraphEndToken)).first.size());
            // quote features
            features.setCount("quoteLength", quote.get(CoreAnnotations.TokenEndAnnotation.class) - quote.get(CoreAnnotations.TokenBeginAnnotation.class) + 1);
            for (int i = 0; i < quotesInQuoteParagraph.size(); i++) {
                if (quotesInQuoteParagraph.get(i).equals(quote)) {
                    features.setCount("quotePosition", i + 1);
                }
            }
            if (features.getCount("quotePosition") == 0)
                throw new RuntimeException("Check this (equality not working)");
            Pair<ArrayList<String>, ArrayList<Pair<Integer, Integer>>> namesData = sieve.scanForNames(quoteRun);
            for (String name : namesData.first) {
                features.setCount("charactersInQuote:" + sd.characterMap.get(name).get(0).name, 1);
            }
            // if quote encompasses entire paragraph
            if (quote.get(CoreAnnotations.TokenBeginAnnotation.class) == quoteParagraphBeginToken && quote.get(CoreAnnotations.TokenEndAnnotation.class) == quoteParagraphEndToken) {
                features.setCount("isImplicitSpeaker", 1);
            } else {
                features.setCount("isImplicitSpeaker", -1);
            }
            // Vocative detection
            if (mention.type.equals("name")) {
                List<Person> pList = sd.characterMap.get(sieve.tokenRangeToString(new Pair<>(mention.begin, mention.end)));
                Person p = null;
                if (pList != null)
                    p = pList.get(0);
                else {
                    Pair<ArrayList<String>, ArrayList<Pair<Integer, Integer>>> scanForNamesResultPair = sieve.scanForNames(new Pair<>(mention.begin, mention.end));
                    if (scanForNamesResultPair.first.size() != 0) {
                        String scanForNamesResultString = scanForNamesResultPair.first.get(0);
                        if (scanForNamesResultString != null && sd.characterMap.containsKey(scanForNamesResultString)) {
                            p = sd.characterMap.get(scanForNamesResultString).get(0);
                        }
                    }
                }
                if (p != null) {
                    for (String name : namesData.first) {
                        if (p.aliases.contains(name))
                            features.setCount("nameInQuote", 1);
                    }
                    if (quoteParagraphIdx > 0) {
                        // Paragraph prevParagraph = paragraphs.get(ex.paragraph_idx - 1);
                        List<CoreMap> quotesInPrevParagraph = paragraphToQuotes.getOrDefault(quoteParagraphIdx - 1, new ArrayList<>());
                        List<Pair<Integer, Integer>> exclusionList = new ArrayList<>();
                        for (CoreMap quoteIPP : quotesInPrevParagraph) {
                            Pair<Integer, Integer> quoteRange = new Pair<>(quoteIPP.get(CoreAnnotations.TokenBeginAnnotation.class), quoteIPP.get(CoreAnnotations.TokenEndAnnotation.class));
                            exclusionList.add(quoteRange);
                            for (String name : sieve.scanForNames(quoteRange).first) {
                                if (p.aliases.contains(name))
                                    features.setCount("nameInPrevParagraphQuote", 1);
                            }
                        }
                        int sentenceIdx = quoteFirstSentence.get(CoreAnnotations.SentenceIndexAnnotation.class);
                        CoreMap sentenceInPrevParagraph = null;
                        for (int i = sentenceIdx - 1; i >= 0; i--) {
                            CoreMap currSentence = sentences.get(i);
                            if (currSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx - 1) {
                                sentenceInPrevParagraph = currSentence;
                                break;
                            }
                        }
                        int prevParagraphBegin = getParagraphBeginToken(sentenceInPrevParagraph, sentences);
                        int prevParagraphEnd = getParagraphEndToken(sentenceInPrevParagraph, sentences);
                        List<Pair<Integer, Integer>> prevParagraphNonQuoteRuns = getRangeExclusion(new Pair<>(prevParagraphBegin, prevParagraphEnd), exclusionList);
                        for (Pair<Integer, Integer> nonQuoteRange : prevParagraphNonQuoteRuns) {
                            for (String name : sieve.scanForNames(nonQuoteRange).first) {
                                if (p.aliases.contains(name))
                                    features.setCount("nameInPrevParagraphNonQuote", 1);
                            }
                        }
                    }
                }
            }
            if (isTraining) {
                if (QuoteAttributionUtils.rangeContains(new Pair<>(gold.mentionStartTokenIndex, gold.mentionEndTokenIndex), new Pair<>(mention.begin, mention.end))) {
                    RVFDatum<String, String> datum = new RVFDatum<>(features, "isMention");
                    datum.setID(Integer.toString(dataset.size()));
                    mapDatumToMention.put(dataset.size(), mention);
                    dataset.add(datum);
                } else {
                    RVFDatum<String, String> datum = new RVFDatum<>(features, "isNotMention");
                    datum.setID(Integer.toString(dataset.size()));
                    dataset.add(datum);
                    mapDatumToMention.put(dataset.size(), mention);
                }
            } else {
                RVFDatum<String, String> datum = new RVFDatum<>(features, "none");
                datum.setID(Integer.toString(dataset.size()));
                mapDatumToMention.put(dataset.size(), mention);
                dataset.add(datum);
            }
        }
        mapQuoteToDataRange.put(quoteIdx, new Pair<>(initialSize, dataset.size() - 1));
    }
    return new FeaturesData(mapQuoteToDataRange, mapDatumToMention, dataset);
}
Also used: RVFDataset (edu.stanford.nlp.classify.RVFDataset), RVFDatum (edu.stanford.nlp.ling.RVFDatum), Sieve (edu.stanford.nlp.quoteattribution.Sieves.Sieve), Pair (edu.stanford.nlp.util.Pair), Annotation (edu.stanford.nlp.pipeline.Annotation), CoreLabel (edu.stanford.nlp.ling.CoreLabel), CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), CoreMap (edu.stanford.nlp.util.CoreMap)
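
Downstream, the GeneralDataset assembled by featurize is consumed the same way as in Example 1: train a classifier, then score candidate mentions per quote using the recorded dataset ranges. The sketch below is an assumption-level illustration rather than the exact CoreNLP training/inference code; the method names trainSieveClassifier and isMentionScore are hypothetical, while the labels "isMention" and "isNotMention" come from the featurize method above.

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.Counter;

public class SieveClassifierSketch {

    // Train a log-linear classifier on the dataset built by featurize().
    public static LinearClassifier<String, String> trainSieveClassifier(GeneralDataset<String, String> trainingData) {
        LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
        return factory.trainClassifier(trainingData);
    }

    // Probability that a candidate mention (featurized exactly as above)
    // is the speaker mention for its quote.
    public static double isMentionScore(LinearClassifier<String, String> classifier,
                                        RVFDatum<String, String> candidate) {
        Counter<String> probabilities = classifier.probabilityOf(candidate);
        return probabilities.getCount("isMention");
    }
}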

Aggregations

RVFDataset (edu.stanford.nlp.classify.RVFDataset): 2 usages
LinearClassifierFactory (edu.stanford.nlp.classify.LinearClassifierFactory): 1 usage
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 1 usage
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 1 usage
IndexedWord (edu.stanford.nlp.ling.IndexedWord): 1 usage
RVFDatum (edu.stanford.nlp.ling.RVFDatum): 1 usage
QNMinimizer (edu.stanford.nlp.optimization.QNMinimizer): 1 usage
Annotation (edu.stanford.nlp.pipeline.Annotation): 1 usage
Sieve (edu.stanford.nlp.quoteattribution.Sieves.Sieve): 1 usage
SceneGraphImage (edu.stanford.nlp.scenegraph.image.SceneGraphImage): 1 usage
SceneGraphImageRegion (edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion): 1 usage
SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph): 1 usage
ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 1 usage
CoreMap (edu.stanford.nlp.util.CoreMap): 1 usage
Pair (edu.stanford.nlp.util.Pair): 1 usage
Triple (edu.stanford.nlp.util.Triple): 1 usage