Use of edu.stanford.nlp.classify.Classifier in project CoreNLP by stanfordnlp.
The example below is the annotate method of the class KBPAnnotator.
/**
 * Annotate this document for KBP relations.
 * @param annotation The document to annotate.
 */
@Override
public void annotate(Annotation annotation) {
  // get a list of sentences for this annotation
  List<CoreMap> sentences = annotation.get(CoreAnnotations.SentencesAnnotation.class);
  // Create simple document
  Document doc = new Document(kbpProperties, serializer.toProto(annotation));
  // Get the mentions in the document
  List<CoreMap> mentions = new ArrayList<>();
  for (CoreMap sentence : sentences) {
    mentions.addAll(sentence.get(CoreAnnotations.MentionsAnnotation.class));
  }
  // Compute coreferent clusters
  // (map an index to a KBP mention)
  Map<Pair<Integer, Integer>, CoreMap> mentionByStartIndex = new HashMap<>();
  for (CoreMap mention : mentions) {
    for (CoreLabel token : mention.get(CoreAnnotations.TokensAnnotation.class)) {
      mentionByStartIndex.put(Pair.makePair(token.sentIndex(), token.index()), mention);
    }
  }
  // (collect coreferent KBP mentions)
  // map from canonical mention -> other mentions
  Map<CoreMap, Set<CoreMap>> mentionsMap = new HashMap<>();
  if (annotation.get(CorefCoreAnnotations.CorefChainAnnotation.class) != null) {
    for (Map.Entry<Integer, CorefChain> chain : annotation.get(CorefCoreAnnotations.CorefChainAnnotation.class).entrySet()) {
      CoreMap firstMention = null;
      for (CorefChain.CorefMention mention : chain.getValue().getMentionsInTextualOrder()) {
        CoreMap kbpMention = null;
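        // CorefMention positions are 1-indexed: sentNum starts at 1 (hence "sentNum - 1" below) and
        // startIndex/endIndex are 1-indexed token positions with endIndex exclusive, which matches the
        // mentionByStartIndex keys built from token.sentIndex() (0-indexed) and token.index() (1-indexed).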
        for (int i = mention.startIndex; i < mention.endIndex; ++i) {
          if (mentionByStartIndex.containsKey(Pair.makePair(mention.sentNum - 1, i))) {
            kbpMention = mentionByStartIndex.get(Pair.makePair(mention.sentNum - 1, i));
            break;
          }
        }
        if (firstMention == null) {
          firstMention = kbpMention;
        }
        if (kbpMention != null) {
          if (!mentionsMap.containsKey(firstMention)) {
            mentionsMap.put(firstMention, new LinkedHashSet<>());
          }
          mentionsMap.get(firstMention).add(kbpMention);
        }
      }
    }
  }
  // (coreference acronyms)
  acronymMatch(mentions, mentionsMap);
  // (ensure valid NER tag for canonical mention)
  for (CoreMap key : new HashSet<>(mentionsMap.keySet())) {
    if (key.get(CoreAnnotations.NamedEntityTagAnnotation.class) == null) {
      CoreMap newKey = null;
      for (CoreMap candidate : mentionsMap.get(key)) {
        if (candidate.get(CoreAnnotations.NamedEntityTagAnnotation.class) != null) {
          newKey = candidate;
          break;
        }
      }
      if (newKey != null) {
        mentionsMap.put(newKey, mentionsMap.remove(key));
      } else {
        // case: no mention in this chain has an NER tag.
        mentionsMap.remove(key);
      }
    }
  }
  // Propagate Entity Link
  for (Map.Entry<CoreMap, Set<CoreMap>> entry : mentionsMap.entrySet()) {
    String entityLink = entry.getKey().get(CoreAnnotations.WikipediaEntityAnnotation.class);
    if (entityLink != null) {
      for (CoreMap mention : entry.getValue()) {
        for (CoreLabel token : mention.get(CoreAnnotations.TokensAnnotation.class)) {
          token.set(CoreAnnotations.WikipediaEntityAnnotation.class, entityLink);
        }
      }
    }
  }
  // create a mapping of char offset pairs to KBPMention
  HashMap<Pair<Integer, Integer>, CoreMap> charOffsetToKBPMention = new HashMap<>();
  for (CoreMap mention : mentions) {
    int nerMentionCharBegin = mention.get(CoreAnnotations.CharacterOffsetBeginAnnotation.class);
    int nerMentionCharEnd = mention.get(CoreAnnotations.CharacterOffsetEndAnnotation.class);
    charOffsetToKBPMention.put(new Pair<>(nerMentionCharBegin, nerMentionCharEnd), mention);
  }
  // Create a canonical mention map
  Map<CoreMap, CoreMap> mentionToCanonicalMention;
  if (kbpLanguage.equals(LanguageInfo.HumanLanguage.SPANISH)) {
    mentionToCanonicalMention = spanishCorefSystem.canonicalMentionMapFromEntityMentions(mentions);
    if (VERBOSE) {
      log.info("---");
      log.info("basic spanish coref results");
      for (CoreMap originalMention : mentionToCanonicalMention.keySet()) {
        if (!originalMention.equals(mentionToCanonicalMention.get(originalMention))) {
          log.info("mapped: " + originalMention + " to: " + mentionToCanonicalMention.get(originalMention));
        }
      }
    }
  } else {
    mentionToCanonicalMention = new HashMap<>();
  }
  // check if there is coref info
  Set<Map.Entry<Integer, CorefChain>> corefChains;
  if (annotation.get(CorefCoreAnnotations.CorefChainAnnotation.class) != null
      && !kbpLanguage.equals(LanguageInfo.HumanLanguage.SPANISH))
    corefChains = annotation.get(CorefCoreAnnotations.CorefChainAnnotation.class).entrySet();
  else
    corefChains = new HashSet<>();
  for (Map.Entry<Integer, CorefChain> indexCorefChainPair : corefChains) {
    CorefChain corefChain = indexCorefChainPair.getValue();
    Pair<List<CoreMap>, CoreMap> corefChainKBPMentionsAndBestIndex = corefChainToKBPMentions(corefChain, annotation, charOffsetToKBPMention);
    List<CoreMap> corefChainKBPMentions = corefChainKBPMentionsAndBestIndex.first();
    CoreMap bestKBPMentionForChain = corefChainKBPMentionsAndBestIndex.second();
    if (bestKBPMentionForChain != null) {
      for (CoreMap kbpMention : corefChainKBPMentions) {
        if (kbpMention != null) {
          // System.err.println("---");
          // ad hoc filters ; assume acceptable unless a filter blocks it
          boolean acceptableLink = true;
          // block people matches without a token overlap, exempting pronominal to non-pronominal
          // good: Ashton --> Catherine Ashton
          // good: she --> Catherine Ashton
          // bad: Morsi --> Catherine Ashton
          String kbpMentionNERTag = kbpMention.get(CoreAnnotations.NamedEntityTagAnnotation.class);
          String bestKBPMentionForChainNERTag = bestKBPMentionForChain.get(CoreAnnotations.NamedEntityTagAnnotation.class);
          if (kbpMentionNERTag != null && bestKBPMentionForChainNERTag != null
              && kbpMentionNERTag.equals("PERSON") && bestKBPMentionForChainNERTag.equals("PERSON")
              && !kbpIsPronominalMention(kbpMention.get(CoreAnnotations.TokensAnnotation.class).get(0))
              && !kbpIsPronominalMention(bestKBPMentionForChain.get(CoreAnnotations.TokensAnnotation.class).get(0))) {
            // System.err.println("testing PERSON to PERSON coref link");
            boolean tokenMatchFound = false;
            for (CoreLabel kbpToken : kbpMention.get(CoreAnnotations.TokensAnnotation.class)) {
              for (CoreLabel bestKBPToken : bestKBPMentionForChain.get(CoreAnnotations.TokensAnnotation.class)) {
                if (kbpToken.word().toLowerCase().equals(bestKBPToken.word().toLowerCase())) {
                  tokenMatchFound = true;
                  break;
                }
              }
              if (tokenMatchFound)
                break;
            }
            if (!tokenMatchFound)
              acceptableLink = false;
          }
          // check the coref link passed the filters
          if (acceptableLink)
            mentionToCanonicalMention.put(kbpMention, bestKBPMentionForChain);
          // System.err.println("kbp mention: " + kbpMention.get(CoreAnnotations.TextAnnotation.class));
          // System.err.println("coref mention: " + bestKBPMentionForChain.get(CoreAnnotations.TextAnnotation.class));
        }
      }
    }
  }
  // (add missing mentions)
  mentions.stream().filter(mention -> mentionToCanonicalMention.get(mention) == null)
      .forEach(mention -> mentionToCanonicalMention.put(mention, mention));
  // handle acronym coreference
  HashMap<String, List<CoreMap>> acronymClusters = new HashMap<>();
  HashMap<String, List<CoreMap>> acronymInstances = new HashMap<>();
  for (CoreMap acronymMention : mentionToCanonicalMention.keySet()) {
    String acronymNERTag = acronymMention.get(CoreAnnotations.NamedEntityTagAnnotation.class);
    if ((acronymMention == mentionToCanonicalMention.get(acronymMention)) && acronymNERTag != null
        && (acronymNERTag.equals(KBPRelationExtractor.NERTag.ORGANIZATION.name)
            || acronymNERTag.equals(KBPRelationExtractor.NERTag.LOCATION.name))) {
      String acronymText = acronymMention.get(CoreAnnotations.TextAnnotation.class);
      // define acronyms as not containing spaces (e.g. ACLU)
      if (!acronymText.contains(" ")) {
        int numCoreferentsChecked = 0;
        for (CoreMap coreferentMention : mentions) {
          // only check first 1000
          if (numCoreferentsChecked > 1000)
            break;
          // don't check a mention against itself
          if (acronymMention == coreferentMention)
            continue;
          // don't check other mentions without " "
          String coreferentText = coreferentMention.get(CoreAnnotations.TextAnnotation.class);
          if (!coreferentText.contains(" "))
            continue;
          numCoreferentsChecked++;
          List<String> coreferentTokenStrings = coreferentMention.get(CoreAnnotations.TokensAnnotation.class)
              .stream().map(CoreLabel::word).collect(Collectors.toList());
          // afterwards find the best mention in acronymClusters, and match it to every mention in acronymInstances
          if (AcronymMatcher.isAcronym(acronymText, coreferentTokenStrings)) {
            if (!acronymClusters.containsKey(acronymText))
              acronymClusters.put(acronymText, new ArrayList<>());
            if (!acronymInstances.containsKey(acronymText))
              acronymInstances.put(acronymText, new ArrayList<>());
            acronymClusters.get(acronymText).add(coreferentMention);
            acronymInstances.get(acronymText).add(acronymMention);
          }
        }
      }
    }
  }
  // process each acronym (e.g. ACLU)
  for (String acronymText : acronymInstances.keySet()) {
    // find longest ORG or null
    CoreMap bestORG = null;
    for (CoreMap coreferentMention : acronymClusters.get(acronymText)) {
      if (!coreferentMention.get(CoreAnnotations.NamedEntityTagAnnotation.class).equals(KBPRelationExtractor.NERTag.ORGANIZATION.name))
        continue;
      if (bestORG == null)
        bestORG = coreferentMention;
      else if (coreferentMention.get(CoreAnnotations.TextAnnotation.class).length() > bestORG.get(CoreAnnotations.TextAnnotation.class).length())
        bestORG = coreferentMention;
    }
    // find longest LOC or null
    CoreMap bestLOC = null;
    for (CoreMap coreferentMention : acronymClusters.get(acronymText)) {
      if (!coreferentMention.get(CoreAnnotations.NamedEntityTagAnnotation.class).equals(KBPRelationExtractor.NERTag.LOCATION.name))
        continue;
      if (bestLOC == null)
        bestLOC = coreferentMention;
      else if (coreferentMention.get(CoreAnnotations.TextAnnotation.class).length() > bestLOC.get(CoreAnnotations.TextAnnotation.class).length())
        bestLOC = coreferentMention;
    }
    // link ACLU to "American Civil Liberties Union" ; make sure NER types match
    for (CoreMap acronymMention : acronymInstances.get(acronymText)) {
      String mentionType = acronymMention.get(CoreAnnotations.NamedEntityTagAnnotation.class);
      if (mentionType.equals(KBPRelationExtractor.NERTag.ORGANIZATION.name) && bestORG != null)
        mentionToCanonicalMention.put(acronymMention, bestORG);
      if (mentionType.equals(KBPRelationExtractor.NERTag.LOCATION.name) && bestLOC != null)
        mentionToCanonicalMention.put(acronymMention, bestLOC);
    }
  }
  // Cluster mentions by sentence
  @SuppressWarnings("unchecked")
  List<CoreMap>[] mentionsBySentence = new List[annotation.get(CoreAnnotations.SentencesAnnotation.class).size()];
  for (int i = 0; i < mentionsBySentence.length; ++i) {
    mentionsBySentence[i] = new ArrayList<>();
  }
  for (CoreMap mention : mentionToCanonicalMention.keySet()) {
    mentionsBySentence[mention.get(CoreAnnotations.SentenceIndexAnnotation.class)].add(mention);
  }
  // Classify
  for (int sentenceI = 0; sentenceI < mentionsBySentence.length; ++sentenceI) {
    HashMap<String, RelationTriple> relationStringsToTriples = new HashMap<>();
    // the triples that will be annotated onto this sentence
    List<RelationTriple> finalTriplesList = new ArrayList<>();
    List<CoreMap> candidates = mentionsBySentence[sentenceI];
    // determine sentence length
    int sentenceLength = annotation.get(CoreAnnotations.SentencesAnnotation.class).get(sentenceI).get(CoreAnnotations.TokensAnnotation.class).size();
    // check if sentence is too long, if it's too long don't run kbp
    if (maxLength != -1 && sentenceLength > maxLength) {
      // set the triples annotation to an empty list of RelationTriples
      annotation.get(CoreAnnotations.SentencesAnnotation.class).get(sentenceI).set(CoreAnnotations.KBPTriplesAnnotation.class, finalTriplesList);
      // continue to next sentence
      continue;
    }
    // sentence isn't too long, so continue processing this sentence
    for (int subjI = 0; subjI < candidates.size(); ++subjI) {
      CoreMap subj = candidates.get(subjI);
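      // CoreLabel.index() is 1-indexed within the sentence, so the Span below runs from
      // index() - 1 of the first token (inclusive) to index() of the last token (exclusive).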
      int subjBegin = subj.get(CoreAnnotations.TokensAnnotation.class).get(0).index() - 1;
      int subjEnd = subj.get(CoreAnnotations.TokensAnnotation.class).get(subj.get(CoreAnnotations.TokensAnnotation.class).size() - 1).index();
      Optional<KBPRelationExtractor.NERTag> subjNER = KBPRelationExtractor.NERTag.fromString(subj.get(CoreAnnotations.NamedEntityTagAnnotation.class));
      if (subjNER.isPresent()) {
        for (int objI = 0; objI < candidates.size(); ++objI) {
          if (subjI == objI) {
            continue;
          }
          if (Thread.interrupted()) {
            throw new RuntimeInterruptedException();
          }
          CoreMap obj = candidates.get(objI);
          int objBegin = obj.get(CoreAnnotations.TokensAnnotation.class).get(0).index() - 1;
          int objEnd = obj.get(CoreAnnotations.TokensAnnotation.class).get(obj.get(CoreAnnotations.TokensAnnotation.class).size() - 1).index();
          Optional<KBPRelationExtractor.NERTag> objNER = KBPRelationExtractor.NERTag.fromString(obj.get(CoreAnnotations.NamedEntityTagAnnotation.class));
          if (objNER.isPresent() && KBPRelationExtractor.RelationType.plausiblyHasRelation(subjNER.get(), objNER.get())) {
            // type check
            KBPRelationExtractor.KBPInput input = new KBPRelationExtractor.KBPInput(
                new Span(subjBegin, subjEnd), new Span(objBegin, objEnd),
                subjNER.get(), objNER.get(), doc.sentence(sentenceI));
            // -- BEGIN Classify
            Pair<String, Double> prediction = extractor.classify(input);
            // Handle the classifier output
            if (!KBPStatisticalExtractor.NO_RELATION.equals(prediction.first)) {
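              // Build a linked relation triple: the subject/object token spans come from the
              // sentence-level mentions, their canonical spans from mentionToCanonicalMention, the
              // relation is a single synthetic token holding the predicted relation name, and the
              // classifier confidence, dependency graph and Wikipedia entity links are attached too.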
              RelationTriple triple = new RelationTriple.WithLink(
                  subj.get(CoreAnnotations.TokensAnnotation.class),
                  mentionToCanonicalMention.get(subj).get(CoreAnnotations.TokensAnnotation.class),
                  Collections.singletonList(new CoreLabel(new Word(convertRelationNameToLatest(prediction.first)))),
                  obj.get(CoreAnnotations.TokensAnnotation.class),
                  mentionToCanonicalMention.get(obj).get(CoreAnnotations.TokensAnnotation.class),
                  prediction.second,
                  sentences.get(sentenceI).get(SemanticGraphCoreAnnotations.CollapsedCCProcessedDependenciesAnnotation.class),
                  subj.get(CoreAnnotations.WikipediaEntityAnnotation.class),
                  obj.get(CoreAnnotations.WikipediaEntityAnnotation.class));
              String tripleString = triple.subjectGloss() + "\t" + triple.relationGloss() + "\t" + triple.objectGloss();
              // ad hoc checks for problems
              boolean acceptableTriple = true;
              if (triple.objectGloss().equals(triple.subjectGloss()) && triple.relationGloss().endsWith("alternate_names"))
                acceptableTriple = false;
              // the same triple can be produced more than once with different confidence scores,
              // so keep only the highest-confidence version
              if (acceptableTriple && !relationStringsToTriples.containsKey(tripleString))
                relationStringsToTriples.put(tripleString, triple);
              else if (acceptableTriple && triple.confidence > relationStringsToTriples.get(tripleString).confidence)
                relationStringsToTriples.put(tripleString, triple);
            }
          }
        }
      }
    }
    finalTriplesList = new ArrayList<>(relationStringsToTriples.values());
    // Set triples
    annotation.get(CoreAnnotations.SentencesAnnotation.class).get(sentenceI).set(CoreAnnotations.KBPTriplesAnnotation.class, finalTriplesList);
  }
}
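
For orientation, here is a minimal usage sketch (not part of the CoreNLP sources) of how this annotator is typically driven: build a StanfordCoreNLP pipeline whose annotator list ends in kbp, annotate a document, and read the KBPTriplesAnnotation that the method above sets on every sentence. The annotator list and the example text are assumptions and may need adjusting for your CoreNLP version.

import edu.stanford.nlp.ie.util.RelationTriple;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.util.CoreMap;
import java.util.Properties;

public class KBPUsageSketch {
  public static void main(String[] args) {
    Properties props = new Properties();
    // assumed annotator list; coref is optional here, since annotate() tolerates a missing CorefChainAnnotation
    props.setProperty("annotators", "tokenize,ssplit,pos,lemma,ner,parse,coref,kbp");
    StanfordCoreNLP pipeline = new StanfordCoreNLP(props);
    Annotation annotation = new Annotation("Chris Manning is a professor at Stanford University.");
    pipeline.annotate(annotation);
    // read the triples that KBPAnnotator.annotate() attached to each sentence
    for (CoreMap sentence : annotation.get(CoreAnnotations.SentencesAnnotation.class)) {
      for (RelationTriple triple : sentence.get(CoreAnnotations.KBPTriplesAnnotation.class)) {
        System.out.println(triple.subjectGloss() + "\t" + triple.relationGloss()
            + "\t" + triple.objectGloss() + "\t" + triple.confidence);
      }
    }
  }
}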