Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
From the class ACEMentionExtractor, the method printRawDoc:
private static void printRawDoc(List<CoreMap> sentences, List<List<Mention>> allMentions, String filename, boolean gold) throws FileNotFoundException {
  StringBuilder doc = new StringBuilder();
  int previousOffset = 0;
  // Count how many mentions share each gold coref cluster, so singleton clusters
  // can later be printed without a cluster ID.
  Counter<Integer> mentionCount = new ClassicCounter<>();
  for (List<Mention> l : allMentions) {
    for (Mention m : l) {
      mentionCount.incrementCount(m.goldCorefClusterID);
    }
  }
  for (int i = 0; i < sentences.size(); i++) {
    CoreMap sentence = sentences.get(i);
    List<Mention> mentions = allMentions.get(i);
    String[] tokens = sentence.get(CoreAnnotations.TextAnnotation.class).split(" ");
    String sent = "";
    List<CoreLabel> t = sentence.get(CoreAnnotations.TokensAnnotation.class);
    // Insert a blank line when there is a gap in character offsets (e.g., a paragraph break).
    if (previousOffset + 2 < t.get(0).get(CoreAnnotations.CharacterOffsetBeginAnnotation.class)) {
      sent += "\n";
    }
    previousOffset = t.get(t.size() - 1).get(CoreAnnotations.CharacterOffsetEndAnnotation.class);
    // Count mention boundaries per token index so nested mentions print the
    // right number of opening and closing brackets.
    Counter<Integer> startCounts = new ClassicCounter<>();
    Counter<Integer> endCounts = new ClassicCounter<>();
    Map<Integer, Set<Integer>> endID = Generics.newHashMap();
    for (Mention m : mentions) {
      startCounts.incrementCount(m.startIndex);
      endCounts.incrementCount(m.endIndex);
      if (!endID.containsKey(m.endIndex)) {
        endID.put(m.endIndex, Generics.<Integer>newHashSet());
      }
      endID.get(m.endIndex).add(m.goldCorefClusterID);
    }
    for (int j = 0; j < tokens.length; j++) {
      if (endID.containsKey(j)) {
        for (Integer id : endID.get(j)) {
          // Label the closing bracket with a cluster ID only for non-singleton gold clusters.
          if (mentionCount.getCount(id) != 1 && gold) {
            sent += "]_" + id;
          } else {
            sent += "]";
          }
        }
      }
      for (int k = 0; k < startCounts.getCount(j); k++) {
        if (!sent.endsWith("[")) {
          sent += " ";
        }
        sent += "[";
      }
      sent += " ";
      sent = sent + tokens[j];
    }
    // Close any mentions that end at the last token of the sentence.
    for (int k = 0; k < endCounts.getCount(tokens.length); k++) {
      sent += "]";
    }
    sent += "\n";
    doc.append(sent);
  }
  if (gold) {
    logger.fine("New DOC: (GOLD MENTIONS) ==================================================");
  } else {
    logger.fine("New DOC: (Predicted Mentions) ==================================================");
  }
  logger.fine(doc.toString());
}
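The printer above leans on a convenient property of ClassicCounter: unseen keys count as zero, so no containsKey checks are needed before incrementing or reading. Below is a minimal, self-contained sketch of that counting idiom; the class name and cluster IDs are hypothetical, not from the source.

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class ClusterSizeDemo {
  public static void main(String[] args) {
    Counter<Integer> mentionCount = new ClassicCounter<>();
    int[] goldClusterIds = { 3, 3, 7, 9, 9, 9 };  // hypothetical cluster IDs
    for (int id : goldClusterIds) {
      mentionCount.incrementCount(id);
    }
    // Cluster 7 is a singleton, so a printer like the one above would omit its ID.
    System.out.println(mentionCount.getCount(7) == 1);  // true
    System.out.println(mentionCount.getCount(9));       // 3.0
    System.out.println(mentionCount.getCount(42));      // 0.0 for unseen keys
  }
}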
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
From the class StatisticalCorefAlgorithm, the method runCoref:
@Override
public void runCoref(Document document) {
  Compressor<String> compressor = new Compressor<>();
  if (Thread.interrupted()) {
    // Allow interrupting
    throw new RuntimeInterruptedException();
  }
  // Build candidate (antecedent, anaphor) pairs, pruned by mention-distance heuristics.
  Map<Pair<Integer, Integer>, Boolean> pairs = new HashMap<>();
  for (Map.Entry<Integer, List<Integer>> e : CorefUtils.heuristicFilter(CorefUtils.getSortedMentions(document), maxMentionDistance, maxMentionDistanceWithStringMatch).entrySet()) {
    for (int m1 : e.getValue()) {
      pairs.put(new Pair<>(m1, e.getKey()), true);
    }
  }
  DocumentExamples examples = extractor.extract(0, document, pairs, compressor);
  // Score each candidate pair with the pairwise classifier; the counter maps pair -> score.
  Counter<Pair<Integer, Integer>> pairwiseScores = new ClassicCounter<>();
  for (Example mentionPair : examples.examples) {
    if (Thread.interrupted()) {
      // Allow interrupting
      throw new RuntimeInterruptedException();
    }
    pairwiseScores.incrementCount(new Pair<>(mentionPair.mentionId1, mentionPair.mentionId2), classifier.predict(mentionPair, examples.mentionFeatures, compressor));
  }
  // Greedy best-first clustering: process pairs from highest to lowest score.
  List<Pair<Integer, Integer>> mentionPairs = new ArrayList<>(pairwiseScores.keySet());
  Collections.sort(mentionPairs, (p1, p2) -> {
    double diff = pairwiseScores.getCount(p2) - pairwiseScores.getCount(p1);
    return diff == 0 ? 0 : (int) Math.signum(diff);
  });
  // Each anaphor is linked to at most one antecedent: its highest-scoring one.
  Set<Integer> seenAnaphors = new HashSet<>();
  for (Pair<Integer, Integer> pair : mentionPairs) {
    if (seenAnaphors.contains(pair.second)) {
      continue;
    }
    if (Thread.interrupted()) {
      // Allow interrupting
      throw new RuntimeInterruptedException();
    }
    seenAnaphors.add(pair.second);
    MentionType mt1 = document.predictedMentionsByID.get(pair.first).mentionType;
    MentionType mt2 = document.predictedMentionsByID.get(pair.second).mentionType;
    // Merge only if the score clears the threshold for this (pronominal?, pronominal?) pair type.
    if (pairwiseScores.getCount(pair) > thresholds.get(new Pair<>(mt1 == MentionType.PRONOMINAL, mt2 == MentionType.PRONOMINAL))) {
      CorefUtils.mergeCoreferenceClusters(pair, document);
    }
  }
}
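The score-then-sort pattern above can also be expressed with Counters.toSortedList, which returns a counter's keys in descending order of count. Here is a small sketch under that assumption; the class name and classifier scores are hypothetical.

import java.util.List;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Pair;

public class PairScoreDemo {
  public static void main(String[] args) {
    Counter<Pair<Integer, Integer>> scores = new ClassicCounter<>();
    scores.incrementCount(new Pair<>(1, 2), 0.9);  // hypothetical classifier scores
    scores.incrementCount(new Pair<>(1, 3), 0.4);
    scores.incrementCount(new Pair<>(2, 3), 0.7);
    // Highest-scoring pairs first, mirroring the best-first loop in runCoref.
    List<Pair<Integer, Integer>> sorted = Counters.toSortedList(scores);
    System.out.println(sorted);  // (1,2) first, then (2,3), then (1,3)
  }
}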
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
From the class SieveCoreferenceSystem, the method printRawDoc:
/** Print raw document for analysis */
private static void printRawDoc(Document document, boolean gold) throws FileNotFoundException {
  List<CoreMap> sentences = document.annotation.get(CoreAnnotations.SentencesAnnotation.class);
  List<List<Mention>> allMentions;
  if (gold) {
    allMentions = document.goldOrderedMentionsBySentence;
  } else {
    allMentions = document.predictedOrderedMentionsBySentence;
  }
  // String filename = document.annotation.get()
  StringBuilder doc = new StringBuilder();
  int previousOffset = 0;
  for (int i = 0; i < sentences.size(); i++) {
    CoreMap sentence = sentences.get(i);
    List<Mention> mentions = allMentions.get(i);
    List<CoreLabel> t = sentence.get(CoreAnnotations.TokensAnnotation.class);
    // Rebuild the token array from token indices (CoreLabel indices are 1-based).
    String[] tokens = new String[t.size()];
    for (CoreLabel c : t) {
      tokens[c.index() - 1] = c.word();
    }
    // Insert a blank line when there is a gap in character offsets (e.g., a paragraph break).
    if (previousOffset + 2 < t.get(0).get(CoreAnnotations.CharacterOffsetBeginAnnotation.class)) {
      doc.append("\n");
    }
    previousOffset = t.get(t.size() - 1).get(CoreAnnotations.CharacterOffsetEndAnnotation.class);
    // Count mention boundaries per token index so nested mentions get the right
    // number of brackets; endMentions groups mentions by the index where they end.
    Counter<Integer> startCounts = new ClassicCounter<>();
    Counter<Integer> endCounts = new ClassicCounter<>();
    Map<Integer, Set<Mention>> endMentions = Generics.newHashMap();
    for (Mention m : mentions) {
      startCounts.incrementCount(m.startIndex);
      endCounts.incrementCount(m.endIndex);
      if (!endMentions.containsKey(m.endIndex)) {
        endMentions.put(m.endIndex, Generics.<Mention>newHashSet());
      }
      endMentions.get(m.endIndex).add(m);
    }
    for (int j = 0; j < tokens.length; j++) {
      if (endMentions.containsKey(j)) {
        for (Mention m : endMentions.get(j)) {
          int corefChainId = (gold) ? m.goldCorefClusterID : m.corefClusterID;
          doc.append("]_").append(corefChainId);
        }
      }
      for (int k = 0; k < startCounts.getCount(j); k++) {
        if (doc.length() > 0 && doc.charAt(doc.length() - 1) != '[') {
          doc.append(" ");
        }
        doc.append("[");
      }
      if (doc.length() > 0 && doc.charAt(doc.length() - 1) != '[') {
        doc.append(" ");
      }
      doc.append(tokens[j]);
    }
    // Close any mentions that end at the last token of the sentence.
    if (endMentions.containsKey(tokens.length)) {
      for (Mention m : endMentions.get(tokens.length)) {
        int corefChainId = (gold) ? m.goldCorefClusterID : m.corefClusterID;
        //append("_").append(m.mentionID);
        doc.append("]_").append(corefChainId);
      }
    }
    doc.append("\n");
  }
  logger.fine(document.annotation.get(CoreAnnotations.DocIDAnnotation.class));
  if (gold) {
    logger.fine("New DOC: (GOLD MENTIONS) ==================================================");
  } else {
    logger.fine("New DOC: (Predicted Mentions) ==================================================");
  }
  logger.fine(doc.toString());
}
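Both printRawDoc variants group mentions by the token index where they end, using CoreNLP's Generics.newHashMap and Generics.newHashSet factories. A stripped-down sketch of that grouping idiom follows; the spans and cluster IDs are made up for illustration.

import java.util.Map;
import java.util.Set;
import edu.stanford.nlp.util.Generics;

public class EndIndexDemo {
  public static void main(String[] args) {
    // Map from end index to the cluster IDs of all mentions closing there.
    Map<Integer, Set<Integer>> endIds = Generics.newHashMap();
    int[][] mentionSpans = { { 0, 2, 5 }, { 1, 2, 9 } };  // {start, end, clusterId}, hypothetical
    for (int[] m : mentionSpans) {
      if (!endIds.containsKey(m[1])) {
        endIds.put(m[1], Generics.<Integer>newHashSet());
      }
      endIds.get(m[1]).add(m[2]);
    }
    // Two nested mentions end at index 2, so two closing brackets would be printed.
    System.out.println(endIds.get(2));  // e.g. [5, 9] (set order may vary)
  }
}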
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
From the class GenericDataSetReader, the method modifyUsingCoreNLPNER:
private void modifyUsingCoreNLPNER(Annotation doc) {
  // Run a small pipeline to attach POS, lemma, and NER annotations to the tokens.
  Properties ann = new Properties();
  ann.setProperty("annotators", "pos, lemma, ner");
  StanfordCoreNLP pipeline = new StanfordCoreNLP(ann, false);
  pipeline.annotate(doc);
  for (CoreMap sentence : doc.get(CoreAnnotations.SentencesAnnotation.class)) {
    List<EntityMention> entities = sentence.get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
    if (entities != null) {
      List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
      for (EntityMention en : entities) {
        //System.out.println("old ner tag for " + en.getExtentString() + " was " + en.getType());
        Span s = en.getExtent();
        // Tally the NER tag of every token in the mention's extent, then retype
        // the mention with the most frequent tag (a majority vote over the span).
        Counter<String> allNertagforSpan = new ClassicCounter<>();
        for (int i = s.start(); i < s.end(); i++) {
          allNertagforSpan.incrementCount(tokens.get(i).ner());
        }
        String entityNertag = Counters.argmax(allNertagforSpan);
        en.setType(entityNertag);
        //System.out.println("new ner tag is " + entityNertag);
      }
    }
  }
}
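The retyping step is a majority vote: count each token's NER tag over the mention span and keep the most frequent one via Counters.argmax. A minimal sketch with hypothetical tags:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;

public class MajorityNerDemo {
  public static void main(String[] args) {
    String[] spanTags = { "PERSON", "PERSON", "O" };  // hypothetical per-token tags
    Counter<String> tagCounts = new ClassicCounter<>();
    for (String tag : spanTags) {
      tagCounts.incrementCount(tag);
    }
    // argmax returns the key with the highest count; ties are broken arbitrarily.
    System.out.println(Counters.argmax(tagCounts));  // PERSON
  }
}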
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
From the class RelationExtractorResultsPrinter, the method printResultsInternal:
private void printResultsInternal(PrintWriter pw, Counter<Pair<String, String>> results, ClassicCounter<String> labelCount) {
  ClassicCounter<String> correct = new ClassicCounter<>();
  ClassicCounter<String> predictionCount = new ClassicCounter<>();
  // If no gold-label counts were passed in, accumulate them from the results themselves.
  boolean countGoldLabels = false;
  if (labelCount == null) {
    labelCount = new ClassicCounter<>();
    countGoldLabels = true;
  }
  // results maps (predicted label, actual label) pairs to their frequency.
  for (Pair<String, String> predictedActual : results.keySet()) {
    String predicted = predictedActual.first;
    String actual = predictedActual.second;
    if (predicted.equals(actual)) {
      correct.incrementCount(actual, results.getCount(predictedActual));
    }
    predictionCount.incrementCount(predicted, results.getCount(predictedActual));
    if (countGoldLabels) {
      labelCount.incrementCount(actual, results.getCount(predictedActual));
    }
  }
  DecimalFormat formatter = new DecimalFormat();
  formatter.setMaximumFractionDigits(1);
  formatter.setMinimumFractionDigits(1);
  double totalCount = 0;
  double totalCorrect = 0;
  double totalPredicted = 0;
  pw.println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");
  List<String> labels = new ArrayList<>(labelCount.keySet());
  Collections.sort(labels);
  for (String label : labels) {
    double numcorrect = correct.getCount(label);
    double predicted = predictionCount.getCount(label);
    double trueCount = labelCount.getCount(label);
    double precision = (predicted > 0) ? (numcorrect / predicted) : 0;
    double recall = numcorrect / trueCount;
    double f = (precision + recall > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
    pw.println(StringUtils.padOrTrim(label, MAX_LABEL_LENGTH) + "\t" + numcorrect + "\t" + predicted + "\t" + trueCount + "\t" + formatter.format(precision * 100) + "\t" + formatter.format(100 * recall) + "\t" + formatter.format(100 * f));
    // Labels marking "no relation" are excluded from the micro-averaged totals.
    if (!RelationMention.isUnrelatedLabel(label)) {
      totalCount += trueCount;
      totalCorrect += numcorrect;
      totalPredicted += predicted;
    }
  }
  double precision = (totalPredicted > 0) ? (totalCorrect / totalPredicted) : 0;
  double recall = totalCorrect / totalCount;
  double f = (totalPredicted > 0 && totalCorrect > 0) ? 2 * precision * recall / (precision + recall) : 0.0;
  pw.println("Total\t" + totalCorrect + "\t" + totalPredicted + "\t" + totalCount + "\t" + formatter.format(100 * precision) + "\t" + formatter.format(100 * recall) + "\t" + formatter.format(100 * f));
}
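The three counters above implement standard per-label precision/recall/F1 bookkeeping: correct over predicted gives precision, correct over gold gives recall. A tiny sketch of the arithmetic with hypothetical counts:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class PrfDemo {
  public static void main(String[] args) {
    Counter<String> correct = new ClassicCounter<>();
    Counter<String> predicted = new ClassicCounter<>();
    Counter<String> gold = new ClassicCounter<>();
    correct.incrementCount("ORG", 8);    // hypothetical counts
    predicted.incrementCount("ORG", 10);
    gold.incrementCount("ORG", 16);
    double p = correct.getCount("ORG") / predicted.getCount("ORG");  // 0.8
    double r = correct.getCount("ORG") / gold.getCount("ORG");       // 0.5
    double f = (p + r > 0) ? 2 * p * r / (p + r) : 0.0;              // about 0.615
    System.out.printf("P=%.1f R=%.1f F1=%.1f%n", 100 * p, 100 * r, 100 * f);
  }
}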