Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp:
class KBPStatisticalExtractor, method surfaceFeatures.
@SuppressWarnings("UnusedParameters")
private static void surfaceFeatures(KBPInput input, Sentence simpleSentence, ClassicCounter<String> feats) {
  List<String> lemmaSpan = spanBetweenMentions(input, CoreLabel::lemma);
  List<String> nerSpan = spanBetweenMentions(input, CoreLabel::ner);
  List<String> posSpan = spanBetweenMentions(input, CoreLabel::tag);
  // Unigram features of the sentence
  List<CoreLabel> tokens = input.sentence.asCoreLabels(Sentence::lemmas, Sentence::nerTags);
  for (CoreLabel token : tokens) {
    indicator(feats, "sentence_unigram", token.lemma());
  }
  // Full lemma span ( -0.3 F1 )
  // if (lemmaSpan.size() <= 5) {
  //   indicator(feats, "full_lemma_span", withMentionsPositioned(input, StringUtils.join(lemmaSpan, " ")));
  // }
  // Lemma n-grams
  String lastLemma = "_^_";
  for (String lemma : lemmaSpan) {
    indicator(feats, "lemma_bigram", withMentionsPositioned(input, lastLemma + " " + lemma));
    indicator(feats, "lemma_unigram", withMentionsPositioned(input, lemma));
    lastLemma = lemma;
  }
  indicator(feats, "lemma_bigram", withMentionsPositioned(input, lastLemma + " _$_"));
  // NER + lemma bi-grams
  for (int i = 0; i < lemmaSpan.size() - 1; ++i) {
    if (!"O".equals(nerSpan.get(i)) && "O".equals(nerSpan.get(i + 1)) && "IN".equals(posSpan.get(i + 1))) {
      indicator(feats, "ner/lemma_bigram", withMentionsPositioned(input, nerSpan.get(i) + " " + lemmaSpan.get(i + 1)));
    }
    if (!"O".equals(nerSpan.get(i + 1)) && "O".equals(nerSpan.get(i)) && "IN".equals(posSpan.get(i))) {
      indicator(feats, "ner/lemma_bigram", withMentionsPositioned(input, lemmaSpan.get(i) + " " + nerSpan.get(i + 1)));
    }
  }
  // Distance between mentions
  String distanceBucket = ">10";
  if (lemmaSpan.size() == 0) {
    distanceBucket = "0";
  } else if (lemmaSpan.size() <= 3) {
    distanceBucket = "<=3";
  } else if (lemmaSpan.size() <= 5) {
    distanceBucket = "<=5";
  } else if (lemmaSpan.size() <= 10) {
    distanceBucket = "<=10";
  } else if (lemmaSpan.size() <= 15) {
    distanceBucket = "<=15";
  }
  indicator(feats, "distance_between_entities_bucket", distanceBucket);
  // Punctuation features
  int numCommasInSpan = 0;
  int numQuotesInSpan = 0;
  int parenParity = 0;
  for (String lemma : lemmaSpan) {
    if (lemma.equals(",")) {
      numCommasInSpan += 1;
    }
    if (lemma.equals("\"") || lemma.equals("``") || lemma.equals("''")) {
      numQuotesInSpan += 1;
    }
    if (lemma.equals("(") || lemma.equals("-LRB-")) {
      parenParity += 1;
    }
    if (lemma.equals(")") || lemma.equals("-RRB-")) {
      parenParity -= 1;
    }
  }
  indicator(feats, "comma_parity", numCommasInSpan % 2 == 0 ? "even" : "odd");
  indicator(feats, "quote_parity", numQuotesInSpan % 2 == 0 ? "even" : "odd");
  indicator(feats, "paren_parity", "" + parenParity);
  // Is broken by entity
  Set<String> intercedingNERTags = nerSpan.stream().filter(ner -> !ner.equals("O")).collect(Collectors.toSet());
  if (!intercedingNERTags.isEmpty()) {
    indicator(feats, "has_interceding_ner", "t");
  }
  for (String ner : intercedingNERTags) {
    indicator(feats, "interceding_ner", ner);
  }
  // Left and right context
  List<CoreLabel> sentence = input.sentence.asCoreLabels(Sentence::nerTags);
  if (input.subjectSpan.start() == 0) {
    indicator(feats, "subj_left", "^");
  } else {
    indicator(feats, "subj_left", sentence.get(input.subjectSpan.start() - 1).lemma());
  }
  if (input.subjectSpan.end() == sentence.size()) {
    indicator(feats, "subj_right", "$");
  } else {
    indicator(feats, "subj_right", sentence.get(input.subjectSpan.end()).lemma());
  }
  if (input.objectSpan.start() == 0) {
    indicator(feats, "obj_left", "^");
  } else {
    indicator(feats, "obj_left", sentence.get(input.objectSpan.start() - 1).lemma());
  }
  if (input.objectSpan.end() == sentence.size()) {
    indicator(feats, "obj_right", "$");
  } else {
    indicator(feats, "obj_right", sentence.get(input.objectSpan.end()).lemma());
  }
  // Skip-word patterns
  if (lemmaSpan.size() == 1 && input.subjectSpan.isBefore(input.objectSpan)) {
    String left = input.subjectSpan.start() == 0 ? "^" : sentence.get(input.subjectSpan.start() - 1).lemma();
    indicator(feats, "X<subj>Y<obj>", left + "_" + lemmaSpan.get(0));
  }
}
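The indicator and withMentionsPositioned helpers above are defined elsewhere in KBPStatisticalExtractor and are not shown on this page. As a rough, hypothetical sketch of the pattern (the real helper and its name/value separator may differ), indicator simply treats the ClassicCounter as a sparse feature vector:

private static void indicator(ClassicCounter<String> feats, String name, String value) {
  // Bump the count for this firing indicator feature by 1.0; the ":"
  // separator here is an assumption for illustration only.
  feats.incrementCount(name + ":" + value);
}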
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp:
class QuasiDeterminizer, method computeLambda.
/**
 * Takes time linear in the number of arcs.
 */
public static ClassicCounter computeLambda(TransducerGraph graph) {
  LinkedList queue = new LinkedList();
  ClassicCounter lambda = new ClassicCounter();
  ClassicCounter length = new ClassicCounter();
  Map first = new HashMap();
  Set nodes = graph.getNodes();
  for (Object node : nodes) {
    lambda.setCount(node, 0);
    length.setCount(node, Double.POSITIVE_INFINITY);
  }
  Set endNodes = graph.getEndNodes();
  for (Object o : endNodes) {
    lambda.setCount(o, 0);
    length.setCount(o, 0);
    queue.addLast(o);
  }
  // Breadth-first search: get the first node from the queue
  Object node = null;
  try {
    node = queue.removeFirst();
  } catch (NoSuchElementException e) {
    // queue is empty; node stays null and the loop below is skipped
  }
  while (node != null) {
    double oldLen = length.getCount(node);
    Set arcs = graph.getArcsByTarget(node);
    if (arcs != null) {
      for (Object arc1 : arcs) {
        TransducerGraph.Arc arc = (TransducerGraph.Arc) arc1;
        Object newNode = arc.getSourceNode();
        Comparable a = (Comparable) arc.getInput();
        double k = ((Double) arc.getOutput()).doubleValue();
        double newLen = length.getCount(newNode);
        if (newLen == Double.POSITIVE_INFINITY) {
          // we are discovering this node
          queue.addLast(newNode);
        }
        Comparable f = (Comparable) first.get(newNode);
        if (newLen == Double.POSITIVE_INFINITY || (newLen == oldLen + 1 && a.compareTo(f) < 0)) {
          // f can't be null, since we have a newLen
          // we do this to newNode when we have new info, possibly many times,
          // ejecting the old value if necessary
          first.put(newNode, a);
          // this may already be the case
          length.setCount(newNode, oldLen + 1);
          lambda.setCount(newNode, k + lambda.getCount(node));
        }
      }
    }
    // get a new node from the queue
    node = null;
    try {
      node = queue.removeFirst();
    } catch (NoSuchElementException e) {
      // queue exhausted; node stays null and the loop ends
    }
  }
  return lambda;
}
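A ClassicCounter property that computeLambda leans on: getCount returns 0.0 for keys that were never set, so neither lambda nor length needs a containsKey guard. A small standalone demonstration (the class name CounterDefaultDemo is invented for illustration):

import edu.stanford.nlp.stats.ClassicCounter;

public class CounterDefaultDemo {
  public static void main(String[] args) {
    ClassicCounter<String> length = new ClassicCounter<>();
    length.setCount("start", Double.POSITIVE_INFINITY);
    System.out.println(length.getCount("start"));  // Infinity
    // Unseen keys read back as 0.0 rather than null or an exception.
    System.out.println(length.getCount("unseen")); // 0.0
  }
}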
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp:
class EntityExtractorResultsPrinter, method printResults.
@Override
public void printResults(PrintWriter pw, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
  ResultsPrinter.align(goldStandard, extractorOutput);
  Counter<String> correct = new ClassicCounter<>();
  Counter<String> predicted = new ClassicCounter<>();
  Counter<String> gold = new ClassicCounter<>();
  for (int i = 0; i < goldStandard.size(); i++) {
    CoreMap goldSent = goldStandard.get(i);
    CoreMap sysSent = extractorOutput.get(i);
    String sysText = sysSent.get(TextAnnotation.class);
    String goldText = goldSent.get(TextAnnotation.class);
    if (verbose) {
      log.info("SCORING THE FOLLOWING SENTENCE:");
      log.info(sysSent.get(CoreAnnotations.TokensAnnotation.class));
    }
    HashSet<String> matchedGolds = new HashSet<>();
    List<EntityMention> goldEntities = goldSent.get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
    if (goldEntities == null) {
      goldEntities = new ArrayList<>();
    }
    for (EntityMention m : goldEntities) {
      String label = makeLabel(m);
      if (excludedClasses != null && excludedClasses.contains(label))
        continue;
      gold.incrementCount(label);
    }
    List<EntityMention> sysEntities = sysSent.get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
    if (sysEntities == null) {
      sysEntities = new ArrayList<>();
    }
    for (EntityMention m : sysEntities) {
      String label = makeLabel(m);
      if (excludedClasses != null && excludedClasses.contains(label))
        continue;
      predicted.incrementCount(label);
      if (verbose)
        log.info("COMPARING PREDICTED MENTION: " + m);
      boolean found = false;
      for (EntityMention gm : goldEntities) {
        if (matchedGolds.contains(gm.getObjectId()))
          continue;
        if (verbose)
          log.info("\tagainst: " + gm);
        if (gm.equals(m, useSubTypes)) {
          if (verbose)
            log.info("\t\t\tMATCH!");
          found = true;
          matchedGolds.add(gm.getObjectId());
          if (verboseInstances) {
            log.info("TRUE POSITIVE: " + m + " matched " + gm);
            log.info("In sentence: " + sysText);
          }
          break;
        }
      }
      if (found) {
        correct.incrementCount(label);
      } else if (verboseInstances) {
        log.info("FALSE POSITIVE: " + m.toString());
        log.info("In sentence: " + sysText);
      }
    }
    if (verboseInstances) {
      for (EntityMention m : goldEntities) {
        String label = makeLabel(m);
        if (!matchedGolds.contains(m.getObjectId()) && (excludedClasses == null || !excludedClasses.contains(label))) {
          log.info("FALSE NEGATIVE: " + m.toString());
          log.info("In sentence: " + goldText);
        }
      }
    }
  }
  double totalCount = 0;
  double totalCorrect = 0;
  double totalPredicted = 0;
  pw.println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");
  List<String> labels = new ArrayList<>(gold.keySet());
  Collections.sort(labels);
  for (String label : labels) {
    if (excludedClasses != null && excludedClasses.contains(label))
      continue;
    double numCorrect = correct.getCount(label);
    double numPredicted = predicted.getCount(label);
    double trueCount = gold.getCount(label);
    double precision = (numPredicted > 0) ? (numCorrect / numPredicted) : 0;
    double recall = numCorrect / trueCount;
    double f = (precision + recall > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
    pw.println(StringUtils.padOrTrim(label, 21) + "\t" + numCorrect + "\t" + numPredicted + "\t" + trueCount + "\t" + FORMATTER.format(precision * 100) + "\t" + FORMATTER.format(100 * recall) + "\t" + FORMATTER.format(100 * f));
    totalCount += trueCount;
    totalCorrect += numCorrect;
    totalPredicted += numPredicted;
  }
  double precision = (totalPredicted > 0) ? (totalCorrect / totalPredicted) : 0;
  double recall = totalCorrect / totalCount;
  double f = (totalPredicted > 0 && totalCorrect > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
  pw.println("Total\t" + totalCorrect + "\t" + totalPredicted + "\t" + totalCount + "\t" + FORMATTER.format(100 * precision) + "\t" + FORMATTER.format(100 * recall) + "\t" + FORMATTER.format(100 * f));
}
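The three counters (correct, predicted, gold) hold per-label tallies from which precision, recall, and F1 fall out directly, as in the loop above. A toy, self-contained version of the same arithmetic (the label and numbers are invented):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class PrfDemo {
  public static void main(String[] args) {
    Counter<String> correct = new ClassicCounter<>();
    Counter<String> predicted = new ClassicCounter<>();
    Counter<String> gold = new ClassicCounter<>();
    // Invented tallies: 8 of 10 predicted PERSON mentions match,
    // out of 12 gold PERSON mentions.
    correct.incrementCount("PERSON", 8);
    predicted.incrementCount("PERSON", 10);
    gold.incrementCount("PERSON", 12);
    double p = correct.getCount("PERSON") / predicted.getCount("PERSON"); // 0.800
    double r = correct.getCount("PERSON") / gold.getCount("PERSON");      // ~0.667
    double f = 2 * p * r / (p + r);                                       // ~0.727
    System.out.printf("P=%.3f R=%.3f F1=%.3f%n", p, r, f);
  }
}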
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp:
class RothResultsByRelation, method printResults.
@Override
public void printResults(PrintWriter pw, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
  featureFactory = MachineReading.makeRelationFeatureFactory(MachineReadingProperties.relationFeatureFactoryClass, MachineReadingProperties.relationFeatures, false);
  // generic mentions work well in this domain
  mentionFactory = new RelationMentionFactory();
  ResultsPrinter.align(goldStandard, extractorOutput);
  List<RelationMention> relations = new ArrayList<>();
  final Map<RelationMention, String> predictions = new HashMap<>();
  for (int i = 0; i < goldStandard.size(); i++) {
    List<RelationMention> goldRelations = AnnotationUtils.getAllRelations(mentionFactory, goldStandard.get(i), true);
    relations.addAll(goldRelations);
    for (RelationMention rel : goldRelations) {
      predictions.put(rel, AnnotationUtils.getRelation(mentionFactory, extractorOutput.get(i), rel.getArg(0), rel.getArg(1)).getType());
    }
  }
  final Counter<Pair<Pair<String, String>, String>> pathCounts = new ClassicCounter<>();
  for (RelationMention rel : relations) {
    pathCounts.incrementCount(new Pair<>(new Pair<>(rel.getArg(0).getType(), rel.getArg(1).getType()), featureFactory.getFeature(rel, "dependency_path_lowlevel")));
  }
  Counter<String> singletonCorrect = new ClassicCounter<>();
  Counter<String> singletonPredicted = new ClassicCounter<>();
  Counter<String> singletonActual = new ClassicCounter<>();
  for (RelationMention rel : relations) {
    if (pathCounts.getCount(new Pair<>(new Pair<>(rel.getArg(0).getType(), rel.getArg(1).getType()), featureFactory.getFeature(rel, "dependency_path_lowlevel"))) == 1.0) {
      String prediction = predictions.get(rel);
      if (prediction.equals(rel.getType())) {
        singletonCorrect.incrementCount(prediction);
      }
      singletonPredicted.incrementCount(prediction);
      singletonActual.incrementCount(rel.getType());
    }
  }
  class RelComp implements Comparator<RelationMention> {
    @Override
    public int compare(RelationMention rel1, RelationMention rel2) {
      // Group together actual relations of a type with relations that were
      // predicted to be that type
      String prediction1 = predictions.get(rel1);
      String prediction2 = predictions.get(rel2);
      // String rel1group = RelationsSentence.isUnrelatedLabel(rel1.getType())
      //     ? prediction1 : rel1.getType();
      // String rel2group = RelationsSentence.isUnrelatedLabel(rel2.getType())
      //     ? prediction2 : rel2.getType();
      int entComp = (rel1.getArg(0).getType() + rel1.getArg(1).getType()).compareTo(rel2.getArg(0).getType() + rel2.getArg(1).getType());
      // int groupComp = rel1group.compareTo(rel2group);
      int typeComp = rel1.getType().compareTo(rel2.getType());
      int predictionComp = prediction1.compareTo(prediction2);
      // int pathComp = getFeature(rel1, "generalized_dependency_path").compareTo(getFeature(rel2, "generalized_dependency_path"));
      double pathCount1 = pathCounts.getCount(new Pair<>(new Pair<>(rel1.getArg(0).getType(), rel1.getArg(1).getType()), featureFactory.getFeature(rel1, "dependency_path_lowlevel")));
      double pathCount2 = pathCounts.getCount(new Pair<>(new Pair<>(rel2.getArg(0).getType(), rel2.getArg(1).getType()), featureFactory.getFeature(rel2, "dependency_path_lowlevel")));
      if (entComp != 0) {
        return entComp;
      // } else if (pathComp != 0) {
      //   return pathComp;
      } else if (pathCount1 < pathCount2) {
        return -1;
      } else if (pathCount1 > pathCount2) {
        return 1;
      } else if (typeComp != 0) {
        return typeComp;
      } else if (predictionComp != 0) {
        return predictionComp;
      } else {
        return rel1.getSentence().get(CoreAnnotations.TextAnnotation.class).compareTo(rel2.getSentence().get(CoreAnnotations.TextAnnotation.class));
      }
    }
  }
  RelComp relComp = new RelComp();
  Collections.sort(relations, relComp);
  for (RelationMention rel : relations) {
    String prediction = predictions.get(rel);
    // if (RelationsSentence.isUnrelatedLabel(prediction) &&
    //     RelationsSentence.isUnrelatedLabel(rel.getType())) {
    //   continue;
    // }
    String type1 = rel.getArg(0).getType();
    String type2 = rel.getArg(1).getType();
    String path = featureFactory.getFeature(rel, "dependency_path_lowlevel");
    if (!((type1.equals("PEOPLE") && type2.equals("PEOPLE")) || (type1.equals("PEOPLE") && type2.equals("LOCATION")) || (type1.equals("LOCATION") && type2.equals("LOCATION")) || (type1.equals("ORGANIZATION") && type2.equals("LOCATION")) || (type1.equals("PEOPLE") && type2.equals("ORGANIZATION")))) {
      continue;
    }
    if (path.equals("")) {
      continue;
    }
    pw.println("\nLABEL: " + prediction);
    pw.println(rel);
    pw.println(path);
    pw.println(featureFactory.getFeatures(rel, "dependency_path_words"));
    pw.println(featureFactory.getFeature(rel, "surface_path_POS"));
  }
}
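Here pathCounts keys a ClassicCounter by compound Pair objects, ((arg1 type, arg2 type), dependency path), and the == 1.0 test above picks out paths seen exactly once. A minimal sketch of that keying pattern (the type names and path string are invented):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;

public class PathCountDemo {
  public static void main(String[] args) {
    Counter<Pair<Pair<String, String>, String>> pathCounts = new ClassicCounter<>();
    // Pair defines equals/hashCode over its elements, so a structurally
    // equal key built later hits the same counter cell.
    Pair<Pair<String, String>, String> key =
        new Pair<>(new Pair<>("PEOPLE", "LOCATION"), "nsubj-live-nmod:in");
    pathCounts.incrementCount(key);
    double c = pathCounts.getCount(new Pair<>(new Pair<>("PEOPLE", "LOCATION"), "nsubj-live-nmod:in"));
    System.out.println(c == 1.0 ? "singleton path" : "repeated path"); // singleton path
  }
}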
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp:
class CMMClassifier, method classifySeq.
/**
* Classify a List of {@link CoreLabel}s using sequence information
* (i.e. Viterbi or Beam Search).
*
* @param document A List of {@link CoreLabel}s to be classified
*/
private void classifySeq(List<IN> document) {
  if (document.isEmpty()) {
    return;
  }
  SequenceModel ts = getSequenceModel(document);
  // TagScorer ts = new PrevOnlyScorer(document, tagIndex, this, (!flags.useTaggySequences ? (flags.usePrevSequences ? 1 : 0) : flags.maxLeft), 0, answerArrays);
  int[] tags;
  // log.info("***begin test***");
  if (flags.useViterbi) {
    ExactBestSequenceFinder ti = new ExactBestSequenceFinder();
    tags = ti.bestSequence(ts);
  } else {
    BeamBestSequenceFinder ti = new BeamBestSequenceFinder(flags.beamSize, true, true);
    tags = ti.bestSequence(ts, document.size());
  }
  // used to improve recall in task 1b
  if (flags.lowerNewgeneThreshold) {
    log.info("Using NEWGENE threshold: " + flags.newgeneThreshold);
    int[] copy = new int[tags.length];
    System.arraycopy(tags, 0, copy, 0, tags.length);
    // for each sequence marked as NEWGENE in the gazette, tag the entire
    // sequence as NEWGENE and sum the score; if the score is greater than
    // newgeneThreshold, accept
    int ngTag = classIndex.indexOf("G");
    // int bgTag = classIndex.indexOf(BACKGROUND);
    int bgTag = classIndex.indexOf(flags.backgroundSymbol);
    for (int i = 0, dSize = document.size(); i < dSize; i++) {
      CoreLabel wordInfo = document.get(i);
      if ("NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
        int start = i;
        int j;
        for (j = i; j < document.size(); j++) {
          wordInfo = document.get(j);
          if (!"NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
            break;
          }
        }
        int end = j;
        // int end = i + 1;
        int winStart = Math.max(0, start - 4);
        int winEnd = Math.min(tags.length, end + 4);
        // clear a window around the sequences
        for (j = winStart; j < winEnd; j++) {
          copy[j] = bgTag;
        }
        // score as nongene
        double bgScore = 0.0;
        for (j = start; j < end; j++) {
          double[] scores = ts.scoresOf(copy, j);
          scores = Scorer.recenter(scores);
          bgScore += scores[bgTag];
        }
        // first pass: compute all of the scores
        ClassicCounter<Pair<Integer, Integer>> prevScores = new ClassicCounter<>();
        for (j = start; j < end; j++) {
          // clear the sequence
          for (int k = start; k < end; k++) {
            copy[k] = bgTag;
          }
          // grow the sequence from j until the end
          for (int k = j; k < end; k++) {
            copy[k] = ngTag;
            // score the sequence
            double ngScore = 0.0;
            for (int m = start; m < end; m++) {
              double[] scores = ts.scoresOf(copy, m);
              scores = Scorer.recenter(scores);
              ngScore += scores[tags[m]];
            }
            prevScores.incrementCount(new Pair<>(Integer.valueOf(j), Integer.valueOf(k)), ngScore - bgScore);
          }
        }
        for (j = start; j < end; j++) {
          // grow the sequence from j until the end
          for (int k = j; k < end; k++) {
            double score = prevScores.getCount(new Pair<>(Integer.valueOf(j), Integer.valueOf(k)));
            // adding a word to the left
            Pair<Integer, Integer> al = new Pair<>(Integer.valueOf(j - 1), Integer.valueOf(k));
            // adding a word to the right
            Pair<Integer, Integer> ar = new Pair<>(Integer.valueOf(j), Integer.valueOf(k + 1));
            // subtracting a word from the left
            Pair<Integer, Integer> sl = new Pair<>(Integer.valueOf(j + 1), Integer.valueOf(k));
            // subtracting a word from the right
            Pair<Integer, Integer> sr = new Pair<>(Integer.valueOf(j), Integer.valueOf(k - 1));
            // make sure the score is greater than all its neighbors (one add or subtract)
            if (score >= flags.newgeneThreshold && (!prevScores.containsKey(al) || score > prevScores.getCount(al)) && (!prevScores.containsKey(ar) || score > prevScores.getCount(ar)) && (!prevScores.containsKey(sl) || score > prevScores.getCount(sl)) && (!prevScores.containsKey(sr) || score > prevScores.getCount(sr))) {
              StringBuilder sb = new StringBuilder();
              wordInfo = document.get(j);
              String docId = wordInfo.get(CoreAnnotations.IDAnnotation.class);
              String startIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class);
              wordInfo = document.get(k);
              String endIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class);
              for (int m = j; m <= k; m++) {
                wordInfo = document.get(m);
                sb.append(wordInfo.word());
                sb.append(' ');
              }
              /* log.info(sb.toString() + "score:" + score +
                   " al:" + prevScores.getCount(al) +
                   " ar:" + prevScores.getCount(ar) +
                   " sl:" + prevScores.getCount(sl) + " sr:" + prevScores.getCount(sr)); */
              System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + sb.toString().trim());
            }
          }
        }
        // restore the original tags
        for (j = winStart; j < winEnd; j++) {
          copy[j] = tags[j];
        }
        i = end;
      }
    }
  }
  for (int i = 0, docSize = document.size(); i < docSize; i++) {
    CoreLabel lineInfo = document.get(i);
    String answer = classIndex.get(tags[i]);
    lineInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
  }
  if (flags.justify && classifier instanceof LinearClassifier) {
    LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
    if (flags.dump) {
      lc.dump();
    }
    for (int i = 0, docSize = document.size(); i < docSize; i++) {
      CoreLabel lineInfo = document.get(i);
      log.info("@@ Position is: " + i + ": ");
      log.info(lineInfo.word() + ' ' + lineInfo.get(CoreAnnotations.AnswerAnnotation.class));
      lc.justificationOf(makeDatum(document, i, featureFactories));
    }
  }
  if (flags.useReverse) {
    Collections.reverse(document);
  }
}
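In classifySeq, prevScores uses a ClassicCounter keyed by (start, end) Pairs as a sparse table of span scores; the neighbor comparisons work because containsKey distinguishes "never scored" from a genuine score of 0.0. A minimal sketch of that pattern (the span indices and score are invented):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Pair;

public class SpanScoreDemo {
  public static void main(String[] args) {
    ClassicCounter<Pair<Integer, Integer>> prevScores = new ClassicCounter<>();
    prevScores.incrementCount(new Pair<>(2, 5), 1.7); // score for the span [2, 5]
    Pair<Integer, Integer> grownLeft = new Pair<>(1, 5); // same span grown one word left
    // containsKey tells "never scored" apart from a true 0.0 score.
    System.out.println(prevScores.containsKey(grownLeft)); // false
    System.out.println(prevScores.getCount(grownLeft));    // 0.0
  }
}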