Use of edu.stanford.nlp.classify.RVFDataset in the CoreNLP project by stanfordnlp: the train method of the EntityClassifier class.
private static void train(List<SceneGraphImage> images, String modelPath, Embedding embeddings) throws IOException {
  RVFDataset<String, String> dataset = new RVFDataset<String, String>();
  SceneGraphSentenceMatcher sentenceMatcher = new SceneGraphSentenceMatcher(embeddings);
  for (SceneGraphImage img : images) {
    for (SceneGraphImageRegion region : img.regions) {
      SemanticGraph sg = region.getEnhancedSemanticGraph();
      SemanticGraphEnhancer.enhance(sg);
      List<Triple<IndexedWord, IndexedWord, String>> relationTriples = sentenceMatcher.getRelationTriples(region);
      for (Triple<IndexedWord, IndexedWord, String> relation : relationTriples) {
        IndexedWord w1 = sg.getNodeByIndexSafe(relation.first.index());
        if (w1 != null) {
          dataset.add(getDatum(w1, relation.first.get(SceneGraphCoreAnnotations.GoldEntityAnnotation.class), embeddings));
        }
      }
    }
  }
  LinearClassifierFactory<String, String> classifierFactory = new LinearClassifierFactory<String, String>(new QNMinimizer(15), 1e-4, false, REG_STRENGTH);
  Classifier<String, String> classifier = classifierFactory.trainClassifier(dataset);
  IOUtils.writeObjectToFile(classifier, modelPath);
  System.err.println(classifier.evaluateAccuracy(dataset));
}
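For context, the pattern above (populate an RVFDataset with real-valued feature counters, then fit a regularized linear model with LinearClassifierFactory) can be reproduced in isolation. The following is a minimal, self-contained sketch and is not part of CoreNLP: the class name, feature names, labels, and the regularization value standing in for REG_STRENGTH are all invented for illustration.

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class RVFDatasetTrainingSketch {
  public static void main(String[] args) {
    // Populate a dataset of real-valued-feature datums; feature names and labels are made up.
    RVFDataset<String, String> dataset = new RVFDataset<String, String>();
    for (int i = 0; i < 100; i++) {
      Counter<String> features = new ClassicCounter<String>();
      features.setCount("bias", 1.0);
      features.setCount("lengthBucket", i % 7);
      dataset.add(new RVFDatum<String, String>(features, (i % 7 > 3) ? "ENTITY" : "OTHER"));
    }
    // Same factory setup as train() above: QN minimizer, 1e-4 tolerance, and an
    // L2 regularization strength (REG_STRENGTH in the original; 1.0 is a placeholder).
    LinearClassifierFactory<String, String> factory =
        new LinearClassifierFactory<String, String>(new QNMinimizer(15), 1e-4, false, 1.0);
    Classifier<String, String> classifier = factory.trainClassifier(dataset);
    // Training-set accuracy, mirroring the original method's final print statement.
    System.err.println(classifier.evaluateAccuracy(dataset));
  }
}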
Use of edu.stanford.nlp.classify.RVFDataset in the CoreNLP project by stanfordnlp: the featurize method of the SupervisedSieveTraining class.
// goldList null if not training
public static FeaturesData featurize(SieveData sd, List<XMLToAnnotation.GoldQuoteInfo> goldList, boolean isTraining) {
  Annotation doc = sd.doc;
  // use to access functions
  Sieve sieve = new Sieve(doc, sd.characterMap, sd.pronounCorefMap, sd.animacyList);
  List<CoreMap> quotes = doc.get(CoreAnnotations.QuotationsAnnotation.class);
  List<CoreMap> sentences = doc.get(CoreAnnotations.SentencesAnnotation.class);
  List<CoreLabel> tokens = doc.get(CoreAnnotations.TokensAnnotation.class);
  Map<Integer, List<CoreMap>> paragraphToQuotes = getQuotesInParagraph(doc);
  GeneralDataset<String, String> dataset = new RVFDataset<>();
  // necessary for 'ScoreBestMention'
  // maps quote to corresponding indices in the dataset
  Map<Integer, Pair<Integer, Integer>> mapQuoteToDataRange = new HashMap<>();
  Map<Integer, Sieve.MentionData> mapDatumToMention = new HashMap<>();
  if (isTraining && goldList.size() != quotes.size()) {
    throw new RuntimeException("Gold Quote List size doesn't match quote list size!");
  }
  for (int quoteIdx = 0; quoteIdx < quotes.size(); quoteIdx++) {
    int initialSize = dataset.size();
    CoreMap quote = quotes.get(quoteIdx);
    XMLToAnnotation.GoldQuoteInfo gold = null;
    if (isTraining) {
      gold = goldList.get(quoteIdx);
      if (gold.speaker.isEmpty()) {
        continue;
      }
    }
    CoreMap quoteFirstSentence = sentences.get(quote.get(CoreAnnotations.SentenceBeginAnnotation.class));
    Pair<Integer, Integer> quoteRun = new Pair<>(quote.get(CoreAnnotations.TokenBeginAnnotation.class), quote.get(CoreAnnotations.TokenEndAnnotation.class));
    // int quoteChapter = quoteFirstSentence.get(ChapterAnnotator.ChapterAnnotation.class);
    int quoteParagraphIdx = quoteFirstSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class);
    // add mentions before quote up to the previous paragraph
    int rightValue = quoteRun.first - 1;
    int leftValue = quoteRun.first - 1;
    // move left value to be the first token idx of the previous paragraph
    for (int sentIdx = quote.get(CoreAnnotations.SentenceBeginAnnotation.class); sentIdx >= 0; sentIdx--) {
      CoreMap sentence = sentences.get(sentIdx);
      if (sentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx) {
        continue;
      }
      if (sentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx - 1) {
        // quoteParagraphIdx - 1 for this and prev
        leftValue = sentence.get(CoreAnnotations.TokenBeginAnnotation.class);
      } else {
        break;
      }
    }
    List<Sieve.MentionData> mentionsInPreviousParagraph = new ArrayList<>();
    if (leftValue > -1 && rightValue > -1)
      mentionsInPreviousParagraph = eliminateDuplicates(sieve.findClosestMentionsInSpanBackward(new Pair<>(leftValue, rightValue)));
    // mentions in next paragraph
    leftValue = quoteRun.second + 1;
    rightValue = quoteRun.second + 1;
    for (int sentIdx = quote.get(CoreAnnotations.SentenceEndAnnotation.class); sentIdx < sentences.size(); sentIdx++) {
      CoreMap sentence = sentences.get(sentIdx);
      // }
      if (sentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx) {
        // quoteParagraphIdx + 1
        rightValue = sentence.get(CoreAnnotations.TokenEndAnnotation.class) - 1;
      } else {
        break;
      }
    }
    List<Sieve.MentionData> mentionsInNextParagraph = new ArrayList<>();
    if (leftValue < tokens.size() && rightValue < tokens.size())
      mentionsInNextParagraph = sieve.findClosestMentionsInSpanForward(new Pair<>(leftValue, rightValue));
    List<Sieve.MentionData> candidateMentions = new ArrayList<>();
    candidateMentions.addAll(mentionsInPreviousParagraph);
    candidateMentions.addAll(mentionsInNextParagraph);
    // System.out.println(candidateMentions.size());
    int rankedDistance = 1;
    int numBackwards = mentionsInPreviousParagraph.size();
    for (Sieve.MentionData mention : candidateMentions) {
      // List<CoreLabel> mentionCandidateTokens = doc.get(CoreAnnotations.TokensAnnotation.class).subList(mention.begin, mention.end + 1);
      // CoreMap mentionCandidateSentence = sentences.get(mentionCandidateTokens.get(0).sentIndex());
      // if (mentionCandidateSentence.get(ChapterAnnotator.ChapterAnnotation.class) != quoteChapter) {
      // continue;
      // }
      Counter<String> features = new ClassicCounter<>();
      boolean isLeft = true;
      int distance = quoteRun.first - mention.end;
      if (distance < 0) {
        isLeft = false;
        distance = mention.begin - quoteRun.second;
      }
      if (distance < 0) {
        // disregard mention-in-quote cases.
        continue;
      }
      features.setCount("wordDistance", distance);
      List<CoreLabel> betweenTokens;
      if (isLeft) {
        betweenTokens = tokens.subList(mention.end + 1, quoteRun.first);
      } else {
        betweenTokens = tokens.subList(quoteRun.second + 1, mention.begin);
      }
      // Punctuation in between
      for (CoreLabel token : betweenTokens) {
        if (punctuation.contains(token.word())) {
          features.setCount("punctuationPresence:" + token.word(), 1);
        }
      }
      // number of mentions away
      features.setCount("rankedDistance", rankedDistance);
      rankedDistance++;
      if (rankedDistance == numBackwards) {
        // reset for the forward
        rankedDistance = 1;
      }
      // int quoteParagraphIdx = quoteFirstSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class);
      // third distance: # of paragraphs away
      int mentionParagraphIdx = -1;
      CoreMap sentenceInMentionParagraph = null;
      int quoteParagraphBeginToken = getParagraphBeginToken(quoteFirstSentence, sentences);
      int quoteParagraphEndToken = getParagraphEndToken(quoteFirstSentence, sentences);
      if (isLeft) {
        if (quoteParagraphBeginToken <= mention.begin && mention.end <= quoteParagraphEndToken) {
          features.setCount("leftParagraphDistance", 0);
          mentionParagraphIdx = quoteParagraphIdx;
          sentenceInMentionParagraph = quoteFirstSentence;
        } else {
          int paragraphDistance = 1;
          int currParagraphIdx = quoteParagraphIdx - paragraphDistance;
          CoreMap currSentence = quoteFirstSentence;
          int currSentenceIdx = currSentence.get(CoreAnnotations.SentenceIndexAnnotation.class);
          while (currParagraphIdx >= 0) {
            // extract begin and end tokens of
            while (currSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) != currParagraphIdx) {
              currSentenceIdx--;
              currSentence = sentences.get(currSentenceIdx);
            }
            int prevParagraphBegin = getParagraphBeginToken(currSentence, sentences);
            int prevParagraphEnd = getParagraphEndToken(currSentence, sentences);
            if (prevParagraphBegin <= mention.begin && mention.end <= prevParagraphEnd) {
              mentionParagraphIdx = currParagraphIdx;
              sentenceInMentionParagraph = currSentence;
              features.setCount("leftParagraphDistance", paragraphDistance);
              if (paragraphDistance % 2 == 0)
                features.setCount("leftParagraphDistanceEven", 1);
              break;
            }
            paragraphDistance++;
            currParagraphIdx--;
          }
        }
      } else // right
      {
        if (quoteParagraphBeginToken <= mention.begin && mention.end <= quoteParagraphEndToken) {
          features.setCount("rightParagraphDistance", 0);
          sentenceInMentionParagraph = quoteFirstSentence;
          mentionParagraphIdx = quoteParagraphIdx;
        } else {
          int paragraphDistance = 1;
          int nextParagraphIndex = quoteParagraphIdx + paragraphDistance;
          CoreMap currSentence = quoteFirstSentence;
          int currSentenceIdx = currSentence.get(CoreAnnotations.SentenceIndexAnnotation.class);
          while (currSentenceIdx < sentences.size()) {
            while (currSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) != nextParagraphIndex) {
              currSentenceIdx++;
              currSentence = sentences.get(currSentenceIdx);
            }
            int nextParagraphBegin = getParagraphBeginToken(currSentence, sentences);
            int nextParagraphEnd = getParagraphEndToken(currSentence, sentences);
            if (nextParagraphBegin <= mention.begin && mention.end <= nextParagraphEnd) {
              sentenceInMentionParagraph = currSentence;
              features.setCount("rightParagraphDistance", paragraphDistance);
              break;
            }
            paragraphDistance++;
            nextParagraphIndex++;
          }
        }
      }
      // 2. mention features
      if (sentenceInMentionParagraph != null) {
        int mentionParagraphBegin = getParagraphBeginToken(sentenceInMentionParagraph, sentences);
        int mentionParagraphEnd = getParagraphEndToken(sentenceInMentionParagraph, sentences);
        if (!(mentionParagraphBegin == quoteParagraphBeginToken && mentionParagraphEnd == quoteParagraphEndToken)) {
          List<CoreMap> quotesInMentionParagraph = paragraphToQuotes.getOrDefault(mentionParagraphIdx, new ArrayList<>());
          Pair<ArrayList<String>, ArrayList<Pair<Integer, Integer>>> namesInMentionParagraph = sieve.scanForNames(new Pair<>(mentionParagraphBegin, mentionParagraphEnd));
          features.setCount("quotesInMentionParagraph", quotesInMentionParagraph.size());
          features.setCount("wordsInMentionParagraph", mentionParagraphEnd - mentionParagraphBegin + 1);
          features.setCount("namesInMentionParagraph", namesInMentionParagraph.first.size());
          // mention ordering in paragraph it is in
          for (int i = 0; i < namesInMentionParagraph.second.size(); i++) {
            if (ExtractQuotesUtil.rangeContains(new Pair<>(mention.begin, mention.end), namesInMentionParagraph.second.get(i)))
              features.setCount("orderInParagraph", i);
          }
          if (quotesInMentionParagraph.size() == 1) {
            CoreMap qInMentionParagraph = quotesInMentionParagraph.get(0);
            if (qInMentionParagraph.get(CoreAnnotations.TokenBeginAnnotation.class) == mentionParagraphBegin && qInMentionParagraph.get(CoreAnnotations.TokenEndAnnotation.class) - 1 == mentionParagraphEnd) {
              features.setCount("mentionParagraphIsInConversation", 1);
            } else {
              features.setCount("mentionParagraphIsInConversation", -1);
            }
          }
          for (CoreMap quoteIMP : quotesInMentionParagraph) {
            if (ExtractQuotesUtil.rangeContains(new Pair<>(quoteIMP.get(CoreAnnotations.TokenBeginAnnotation.class), quoteIMP.get(CoreAnnotations.TokenEndAnnotation.class) - 1), new Pair<>(mention.begin, mention.end)))
              features.setCount("mentionInQuote", 1);
          }
          if (features.getCount("mentionInQuote") != 1)
            features.setCount("mentionNotInQuote", 1);
        }
      }
      // or there will be an array index crash
      if (mention.begin > 0) {
        CoreLabel prevWord = tokens.get(mention.begin - 1);
        features.setCount("prevWordType:" + prevWord.tag(), 1);
        if (punctuationForFeatures.contains(prevWord.lemma()))
          features.setCount("prevWordPunct:" + prevWord.lemma(), 1);
      }
      if (mention.end + 1 < tokens.size()) {
        CoreLabel nextWord = tokens.get(mention.end + 1);
        features.setCount("nextWordType:" + nextWord.tag(), 1);
        if (punctuationForFeatures.contains(nextWord.lemma()))
          features.setCount("nextWordPunct:" + nextWord.lemma(), 1);
      }
      // features.setCount("prevAndNext:" + prevWord.tag()+ ";" + nextWord.tag(), 1);
      // quote paragraph features
      List<CoreMap> quotesInQuoteParagraph = paragraphToQuotes.get(quoteParagraphIdx);
      features.setCount("QuotesInQuoteParagraph", quotesInQuoteParagraph.size());
      features.setCount("WordsInQuoteParagraph", quoteParagraphEndToken - quoteParagraphBeginToken + 1);
      features.setCount("NamesInQuoteParagraph", sieve.scanForNames(new Pair<>(quoteParagraphBeginToken, quoteParagraphEndToken)).first.size());
      // quote features
      features.setCount("quoteLength", quote.get(CoreAnnotations.TokenEndAnnotation.class) - quote.get(CoreAnnotations.TokenBeginAnnotation.class) + 1);
      for (int i = 0; i < quotesInQuoteParagraph.size(); i++) {
        if (quotesInQuoteParagraph.get(i).equals(quote)) {
          features.setCount("quotePosition", i + 1);
        }
      }
      if (features.getCount("quotePosition") == 0)
        throw new RuntimeException("Check this (equality not working)");
      Pair<ArrayList<String>, ArrayList<Pair<Integer, Integer>>> namesData = sieve.scanForNames(quoteRun);
      for (String name : namesData.first) {
        features.setCount("charactersInQuote:" + sd.characterMap.get(name).get(0).name, 1);
      }
      // if quote encompasses entire paragraph
      if (quote.get(CoreAnnotations.TokenBeginAnnotation.class) == quoteParagraphBeginToken && quote.get(CoreAnnotations.TokenEndAnnotation.class) == quoteParagraphEndToken) {
        features.setCount("isImplicitSpeaker", 1);
      } else {
        features.setCount("isImplicitSpeaker", -1);
      }
      // Vocative detection
      if (mention.type.equals("name")) {
        List<Person> pList = sd.characterMap.get(sieve.tokenRangeToString(new Pair<>(mention.begin, mention.end)));
        Person p = null;
        if (pList != null)
          p = pList.get(0);
        else {
          Pair<ArrayList<String>, ArrayList<Pair<Integer, Integer>>> scanForNamesResultPair = sieve.scanForNames(new Pair<>(mention.begin, mention.end));
          if (scanForNamesResultPair.first.size() != 0) {
            String scanForNamesResultString = scanForNamesResultPair.first.get(0);
            if (scanForNamesResultString != null && sd.characterMap.containsKey(scanForNamesResultString)) {
              p = sd.characterMap.get(scanForNamesResultString).get(0);
            }
          }
        }
        if (p != null) {
          for (String name : namesData.first) {
            if (p.aliases.contains(name))
              features.setCount("nameInQuote", 1);
          }
          if (quoteParagraphIdx > 0) {
            // Paragraph prevParagraph = paragraphs.get(ex.paragraph_idx - 1);
            List<CoreMap> quotesInPrevParagraph = paragraphToQuotes.getOrDefault(quoteParagraphIdx - 1, new ArrayList<>());
            List<Pair<Integer, Integer>> exclusionList = new ArrayList<>();
            for (CoreMap quoteIPP : quotesInPrevParagraph) {
              Pair<Integer, Integer> quoteRange = new Pair<>(quoteIPP.get(CoreAnnotations.TokenBeginAnnotation.class), quoteIPP.get(CoreAnnotations.TokenEndAnnotation.class));
              exclusionList.add(quoteRange);
              for (String name : sieve.scanForNames(quoteRange).first) {
                if (p.aliases.contains(name))
                  features.setCount("nameInPrevParagraphQuote", 1);
              }
            }
            int sentenceIdx = quoteFirstSentence.get(CoreAnnotations.SentenceIndexAnnotation.class);
            CoreMap sentenceInPrevParagraph = null;
            for (int i = sentenceIdx - 1; i >= 0; i--) {
              CoreMap currSentence = sentences.get(i);
              if (currSentence.get(CoreAnnotations.ParagraphIndexAnnotation.class) == quoteParagraphIdx - 1) {
                sentenceInPrevParagraph = currSentence;
                break;
              }
            }
            int prevParagraphBegin = getParagraphBeginToken(sentenceInPrevParagraph, sentences);
            int prevParagraphEnd = getParagraphEndToken(sentenceInPrevParagraph, sentences);
            List<Pair<Integer, Integer>> prevParagraphNonQuoteRuns = getRangeExclusion(new Pair<>(prevParagraphBegin, prevParagraphEnd), exclusionList);
            for (Pair<Integer, Integer> nonQuoteRange : prevParagraphNonQuoteRuns) {
              for (String name : sieve.scanForNames(nonQuoteRange).first) {
                if (p.aliases.contains(name))
                  features.setCount("nameInPrevParagraphNonQuote", 1);
              }
            }
          }
        }
      }
      if (isTraining) {
        if (QuoteAttributionUtils.rangeContains(new Pair<>(gold.mentionStartTokenIndex, gold.mentionEndTokenIndex), new Pair<>(mention.begin, mention.end))) {
          RVFDatum<String, String> datum = new RVFDatum<>(features, "isMention");
          datum.setID(Integer.toString(dataset.size()));
          mapDatumToMention.put(dataset.size(), mention);
          dataset.add(datum);
        } else {
          RVFDatum<String, String> datum = new RVFDatum<>(features, "isNotMention");
          datum.setID(Integer.toString(dataset.size()));
          dataset.add(datum);
          mapDatumToMention.put(dataset.size(), mention);
        }
      } else {
        RVFDatum<String, String> datum = new RVFDatum<>(features, "none");
        datum.setID(Integer.toString(dataset.size()));
        mapDatumToMention.put(dataset.size(), mention);
        dataset.add(datum);
      }
    }
    mapQuoteToDataRange.put(quoteIdx, new Pair<>(initialSize, dataset.size() - 1));
  }
  return new FeaturesData(mapQuoteToDataRange, mapDatumToMention, dataset);
}
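The comment about 'ScoreBestMention' hints at how this FeaturesData is consumed: each quote owns a contiguous range of datum indices, and a classifier's scores over that range select the winning mention. Below is a rough sketch of that lookup, not CoreNLP's actual scoring code; the helper name pickBestMention is invented, and it assumes a LinearClassifier trained on the 'isMention'/'isNotMention' labels produced above, plus direct access to the dataset and quote-to-range map built by featurize.

import java.util.Map;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;

public class ScoreBestMentionSketch {
  // Hypothetical helper: return the index of the datum in the quote's range whose
  // 'isMention' score is highest, or -1 if the quote produced no candidate mentions.
  public static int pickBestMention(int quoteIdx,
      RVFDataset<String, String> dataset,
      Map<Integer, Pair<Integer, Integer>> mapQuoteToDataRange,
      LinearClassifier<String, String> classifier) {
    Pair<Integer, Integer> range = mapQuoteToDataRange.get(quoteIdx);
    // Quotes that were skipped, or that had no candidates, have no usable range.
    if (range == null || range.second < range.first) {
      return -1;
    }
    int bestIdx = -1;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int i = range.first; i <= range.second; i++) {
      RVFDatum<String, String> datum = dataset.getRVFDatum(i);
      Counter<String> scores = classifier.scoresOf(datum);
      double score = scores.getCount("isMention");
      if (score > bestScore) {
        bestScore = score;
        bestIdx = i;
      }
    }
    // At prediction time, mapDatumToMention.get(bestIdx) would then recover the MentionData.
    return bestIdx;
  }
}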