Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class RelationExtractorResultsPrinter, method printResults.
@Override
public void printResults(PrintWriter pw, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
  ResultsPrinter.align(goldStandard, extractorOutput);
  // the mention factory cannot be null here
  assert relationMentionFactory != null : "ERROR: RelationExtractorResultsPrinter.relationMentionFactory cannot be null in printResults!";
  // Count predicted-actual relation type pairs
  Counter<Pair<String, String>> results = new ClassicCounter<>();
  ClassicCounter<String> labelCount = new ClassicCounter<>();
  // TODO: assumes binary relations
  for (int goldSentenceIndex = 0; goldSentenceIndex < goldStandard.size(); goldSentenceIndex++) {
    for (RelationMention goldRelation : AnnotationUtils.getAllRelations(relationMentionFactory, goldStandard.get(goldSentenceIndex), createUnrelatedRelations)) {
      CoreMap extractorSentence = extractorOutput.get(goldSentenceIndex);
      List<RelationMention> extractorRelations = AnnotationUtils.getRelations(relationMentionFactory, extractorSentence, goldRelation.getArg(0), goldRelation.getArg(1));
      labelCount.incrementCount(goldRelation.getType());
      for (RelationMention extractorRelation : extractorRelations) {
        results.incrementCount(new Pair<>(extractorRelation.getType(), goldRelation.getType()));
      }
    }
  }
  printResultsInternal(pw, results, labelCount);
}
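The method above tallies (predicted, gold) relation-type pairs in a Counter<Pair<String, String>> plus per-label gold totals before handing both to printResultsInternal. As a minimal sketch of what can be done with such a confusion counter (this is not the actual printResultsInternal, and the relation labels are made up), per-label precision and recall can be read off the pair counts directly:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;

public class ConfusionCounterSketch {
  public static void main(String[] args) {
    // (predicted, gold) pairs, accumulated as in printResults above.
    Counter<Pair<String, String>> results = new ClassicCounter<>();
    results.incrementCount(Pair.makePair("per:employee_of", "per:employee_of"), 8.0);
    results.incrementCount(Pair.makePair("per:employee_of", "_NR"), 2.0);
    results.incrementCount(Pair.makePair("_NR", "per:employee_of"), 3.0);

    String label = "per:employee_of";
    double tp = 0, predicted = 0, gold = 0;
    for (Pair<String, String> cell : results.keySet()) {
      double count = results.getCount(cell);
      if (label.equals(cell.first)) predicted += count;
      if (label.equals(cell.second)) gold += count;
      if (label.equals(cell.first) && label.equals(cell.second)) tp += count;
    }
    System.out.printf("%s P=%.2f R=%.2f%n", label, tp / predicted, tp / gold);
  }
}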
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class ClauseSplitter, method train.
/**
 * Train a clause searcher factory. That is, train a classifier for which arcs should be
 * new clauses.
 *
 * @param trainingData The training data. This is a stream of triples of:
 *                     <ol>
 *                       <li>The sentence containing a known extraction.</li>
 *                       <li>The span of the subject in the sentence, as a token span.</li>
 *                       <li>The span of the object in the sentence, as a token span.</li>
 *                     </ol>
 * @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
 * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
 * @param featurizer The featurizer to use for this classifier.
 *
 * @return A factory for creating searchers from a given dependency tree.
 */
static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, Featurizer featurizer) {
  // Parse options
  LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
  // Generally useful objects
  OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
  WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
  AtomicInteger numExamplesProcessed = new AtomicInteger(0);
  final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
    try {
      return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(trainingDataDump.get()))));
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  });
  // Step 1: Loop over data
  forceTrack("Training inference");
  trainingData.forEach(rawExample -> {
    CoreMap sentence = rawExample.first;
    Collection<Pair<Span, Span>> spans = rawExample.second;
    List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
    SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
    ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
    problem.search(fragmentAndScore -> {
      List<Counter<String>> features = fragmentAndScore.second;
      SentenceFragment fragment = fragmentAndScore.third.get();
      Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
      Trilean correct = Trilean.FALSE;
      RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
        Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
        Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
        for (Pair<Span, Span> candidateGold : spans) {
          Span subjectSpan = candidateGold.first;
          Span objectSpan = candidateGold.second;
          if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
            correct = Trilean.TRUE;
            break RELATION_TRIPLE_LOOP;
          } else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
            if (!correct.isTrue()) {
              correct = Trilean.TRUE;
              break RELATION_TRIPLE_LOOP;
            }
          } else {
            if (!correct.isTrue()) {
              correct = Trilean.UNKNOWN;
              break RELATION_TRIPLE_LOOP;
            }
          }
        }
      }
      if (!features.isEmpty()) {
        List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
        if (correct.isTrue()) {
          for (int i = 0; i < features.size(); ++i) {
            if (i == features.size() - 1) {
              decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
            } else {
              decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
            }
          }
        } else if (correct.isFalse()) {
          decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
        } else if (correct.isUnknown()) {
          boolean isSimpleSplit = false;
          for (Counter<String> feats : features) {
            if (featurizer.isSimpleSplit(feats)) {
              isSimpleSplit = true;
              break;
            }
          }
          if (isSimpleSplit) {
            for (int i = 0; i < features.size(); ++i) {
              if (i == features.size() - 1) {
                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
              } else {
                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
              }
            }
          }
        }
        for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
          RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
          datum.setLabel(decision.second);
          if (datasetDumpWriter.isPresent()) {
            datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
          }
          dataset.add(datum);
        }
      }
      return true;
    }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
    if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
      log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
    }
  });
  endTrack("Training inference");
  // Close the file
  if (datasetDumpWriter.isPresent()) {
    datasetDumpWriter.get().close();
  }
  // Step 2: Train classifier
  forceTrack("Training");
  Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
  endTrack("Training");
  if (modelPath.isPresent()) {
    Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
    try {
      IOUtils.writeObjectToFile(toSave, modelPath.get());
      log("SUCCESS: wrote model to " + modelPath.get().getPath());
    } catch (IOException e) {
      log("ERROR: failed to save model to path: " + modelPath.get().getPath());
      err(e);
    }
  }
  // Step 3: Check accuracy of classifier
  forceTrack("Training accuracy");
  dataset.randomize(42L);
  Util.dumpAccuracy(fullClassifier, dataset);
  endTrack("Training accuracy");
  int numFolds = 5;
  forceTrack(numFolds + " fold cross-validation");
  for (int fold = 0; fold < numFolds; ++fold) {
    forceTrack("Fold " + (fold + 1));
    forceTrack("Training");
    Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
    Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
    endTrack("Training");
    forceTrack("Test");
    Util.dumpAccuracy(classifier, foldData.second);
    endTrack("Test");
    endTrack("Fold " + (fold + 1));
  }
  endTrack(numFolds + " fold cross-validation");
  // Step 5: return factory
  return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
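At its core, the training loop above wraps each feature Counter<String> in an RVFDatum, labels it, and adds it to a dataset that LinearClassifierFactory trains on. A stripped-down sketch of that ClassicCounter-to-classifier pattern follows; it uses an RVFDataset and invented string labels and feature names rather than the WeightedDataset and ClauseClassifierLabel of the real method:

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class CounterDatumSketch {
  public static void main(String[] args) {
    RVFDataset<String, String> dataset = new RVFDataset<>();

    // One positive example: a ClassicCounter of real-valued features wrapped in an RVFDatum.
    Counter<String> split = new ClassicCounter<>();
    split.incrementCount("edge:nsubj");
    split.setCount("pathLength", 1.0);
    RVFDatum<String, String> positive = new RVFDatum<>(split);
    positive.setLabel("CLAUSE_SPLIT");
    dataset.add(positive);

    // One negative example with different features.
    Counter<String> noSplit = new ClassicCounter<>();
    noSplit.incrementCount("edge:det");
    noSplit.setCount("pathLength", 3.0);
    RVFDatum<String, String> negative = new RVFDatum<>(noSplit);
    negative.setLabel("NOT_A_CLAUSE");
    dataset.add(negative);

    // Train and apply a linear classifier, as train() does with its real feature counters.
    Classifier<String, String> classifier = new LinearClassifierFactory<String, String>().trainClassifier(dataset);
    System.out.println(classifier.classOf(positive));
  }
}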
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class FactoredLexicon, method main.
/**
 * @param args
 */
public static void main(String[] args) {
  if (args.length != 4) {
    System.err.printf("Usage: java %s language features train_file dev_file%n", FactoredLexicon.class.getName());
    System.exit(-1);
  }
  // Command line options
  Language language = Language.valueOf(args[0]);
  TreebankLangParserParams tlpp = language.params;
  Treebank trainTreebank = tlpp.diskTreebank();
  trainTreebank.loadPath(args[2]);
  Treebank devTreebank = tlpp.diskTreebank();
  devTreebank.loadPath(args[3]);
  MorphoFeatureSpecification morphoSpec;
  Options options = getOptions(language);
  if (language.equals(Language.Arabic)) {
    morphoSpec = new ArabicMorphoFeatureSpecification();
    String[] languageOptions = { "-arabicFactored" };
    tlpp.setOptionFlag(languageOptions, 0);
  } else if (language.equals(Language.French)) {
    morphoSpec = new FrenchMorphoFeatureSpecification();
    String[] languageOptions = { "-frenchFactored" };
    tlpp.setOptionFlag(languageOptions, 0);
  } else {
    throw new UnsupportedOperationException();
  }
  String featureList = args[1];
  String[] features = featureList.trim().split(",");
  for (String feature : features) {
    morphoSpec.activate(MorphoFeatureType.valueOf(feature));
  }
  System.out.println("Language: " + language.toString());
  System.out.println("Features: " + args[1]);
  // Create word and tag indices
  // Save trees in a collection since the interface requires that....
  System.out.print("Loading training trees...");
  List<Tree> trainTrees = new ArrayList<>(19000);
  Index<String> wordIndex = new HashIndex<>();
  Index<String> tagIndex = new HashIndex<>();
  for (Tree tree : trainTreebank) {
    for (Tree subTree : tree) {
      if (!subTree.isLeaf()) {
        tlpp.transformTree(subTree, tree);
      }
    }
    trainTrees.add(tree);
  }
  System.out.printf("Done! (%d trees)%n", trainTrees.size());
  // Setup and train the lexicon.
  System.out.print("Collecting sufficient statistics for lexicon...");
  FactoredLexicon lexicon = new FactoredLexicon(options, morphoSpec, wordIndex, tagIndex);
  lexicon.initializeTraining(trainTrees.size());
  lexicon.train(trainTrees, null);
  lexicon.finishTraining();
  System.out.println("Done!");
  trainTrees = null;
  // Load the tuning set
  System.out.print("Loading tuning set...");
  List<FactoredLexiconEvent> tuningSet = getTuningSet(devTreebank, lexicon, tlpp);
  System.out.printf("...Done! (%d events)%n", tuningSet.size());
  // Print the probabilities that we obtain
  // TODO(spenceg): Implement tagging accuracy with FactLex
  int nCorrect = 0;
  Counter<String> errors = new ClassicCounter<>();
  for (FactoredLexiconEvent event : tuningSet) {
    Iterator<IntTaggedWord> itr = lexicon.ruleIteratorByWord(event.word(), event.getLoc(), event.featureStr());
    Counter<Integer> logScores = new ClassicCounter<>();
    boolean noRules = true;
    int goldTagId = -1;
    while (itr.hasNext()) {
      noRules = false;
      IntTaggedWord iTW = itr.next();
      if (iTW.tag() == event.tagId()) {
        log.info("GOLD-");
        goldTagId = iTW.tag();
      }
      float tagScore = lexicon.score(iTW, event.getLoc(), event.word(), event.featureStr());
      logScores.incrementCount(iTW.tag(), tagScore);
    }
    if (noRules) {
      System.err.printf("NO TAGGINGS: %s %s%n", event.word(), event.featureStr());
    } else {
      // Score the tagging
      int hypTagId = Counters.argmax(logScores);
      if (hypTagId == goldTagId) {
        ++nCorrect;
      } else {
        String goldTag = goldTagId < 0 ? "UNSEEN" : lexicon.tagIndex.get(goldTagId);
        errors.incrementCount(goldTag);
      }
    }
    log.info();
  }
  // Output accuracy
  double acc = (double) nCorrect / (double) tuningSet.size();
  System.err.printf("%n%nACCURACY: %.2f%n%n", acc * 100.0);
  log.info("% of errors by type:");
  List<String> biggestKeys = new ArrayList<>(errors.keySet());
  Collections.sort(biggestKeys, Counters.toComparator(errors, false, true));
  Counters.normalize(errors);
  for (String key : biggestKeys) {
    System.err.printf("%s\t%.2f%n", key, errors.getCount(key) * 100.0);
  }
}
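The error report at the end of main relies on three ClassicCounter idioms: incrementCount to tally errors per gold tag, Counters.toComparator to sort keys by descending count, and Counters.normalize to turn counts into fractions. A self-contained sketch of just that reporting step, with invented tag counts:

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;

public class ErrorBreakdownSketch {
  public static void main(String[] args) {
    Counter<String> errors = new ClassicCounter<>();
    errors.incrementCount("NN", 7.0);
    errors.incrementCount("VB", 2.0);
    errors.incrementCount("JJ", 1.0);

    // Sort tags by descending error count, then normalize so each count becomes a fraction.
    List<String> biggestKeys = new ArrayList<>(errors.keySet());
    Collections.sort(biggestKeys, Counters.toComparator(errors, false, true));
    Counters.normalize(errors);
    for (String key : biggestKeys) {
      System.err.printf("%s\t%.2f%n", key, errors.getCount(key) * 100.0);
    }
  }
}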
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class SisterAnnotationStats, method sideCounters.
protected void sideCounters(String label, List rewrite, List sideSisters, Map sideRules) {
  for (Object sideSister : sideSisters) {
    String sis = (String) sideSister;
    if (!((Map) sideRules.get(label)).containsKey(sis)) {
      ((Map) sideRules.get(label)).put(sis, new ClassicCounter());
    }
    ((ClassicCounter) ((HashMap) sideRules.get(label)).get(sis)).incrementCount(rewrite);
  }
}
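sideCounters keeps a raw-typed Map from parent label to sister label to a ClassicCounter over rewrites. A generics-friendly sketch of the same nested-counter bookkeeping, with illustrative labels rather than the actual SisterAnnotationStats data structures:

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import edu.stanford.nlp.stats.ClassicCounter;

public class SisterCounterSketch {
  public static void main(String[] args) {
    // label -> sister -> counts of rewrites (daughter sequences)
    Map<String, Map<String, ClassicCounter<List<String>>>> sideRules = new HashMap<>();

    String label = "NP";
    String sister = "VP";
    List<String> rewrite = Arrays.asList("DT", "NN");

    sideRules.computeIfAbsent(label, k -> new HashMap<>())
             .computeIfAbsent(sister, k -> new ClassicCounter<>())
             .incrementCount(rewrite);

    System.out.println(sideRules.get(label).get(sister).getCount(rewrite)); // 1.0
  }
}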
Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.
The class LearnImportantFeatures, method getDatum.
private RVFDatum<String, String> getDatum(CoreLabel[] sent, int i) {
  Counter<String> feat = new ClassicCounter<>();
  CoreLabel l = sent[i];
  String label;
  if (l.get(answerClass).toString().equals(answerLabel))
    label = answerLabel;
  else
    label = "O";
  CollectionValuedMap<String, CandidatePhrase> matchedPhrases = l.get(PatternsAnnotations.MatchedPhrases.class);
  if (matchedPhrases == null) {
    matchedPhrases = new CollectionValuedMap<>();
    matchedPhrases.add(label, CandidatePhrase.createOrGet(l.word()));
  }
  for (CandidatePhrase w : matchedPhrases.allValues()) {
    Integer num = this.clusterIds.get(w.getPhrase());
    if (num == null)
      num = -1;
    feat.setCount("Cluster-" + num, 1.0);
  }
  // feat.incrementCount("WORD-" + l.word());
  // feat.incrementCount("LEMMA-" + l.lemma());
  // feat.incrementCount("TAG-" + l.tag());
  int window = 0;
  for (int j = Math.max(0, i - window); j < i; j++) {
    CoreLabel lj = sent[j];
    feat.incrementCount("PREV-" + "WORD-" + lj.word());
    feat.incrementCount("PREV-" + "LEMMA-" + lj.lemma());
    feat.incrementCount("PREV-" + "TAG-" + lj.tag());
  }
  for (int j = i + 1; j < sent.length && j <= i + window; j++) {
    CoreLabel lj = sent[j];
    feat.incrementCount("NEXT-" + "WORD-" + lj.word());
    feat.incrementCount("NEXT-" + "LEMMA-" + lj.lemma());
    feat.incrementCount("NEXT-" + "TAG-" + lj.tag());
  }
  // System.out.println("adding " + l.word() + " as " + label);
  return new RVFDatum<>(feat, label);
}
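getDatum builds its feature vector as a ClassicCounter<String> and pairs it with a label in an RVFDatum; note that with window set to 0 the PREV-/NEXT- loops contribute nothing, but the pattern generalizes to any window size. A self-contained sketch of the windowed variant over plain strings, with a nonzero window and made-up feature names and labels:

import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class WindowFeatureSketch {
  // Mirrors the PREV-/NEXT- windowing above, but over plain strings instead of CoreLabels.
  static RVFDatum<String, String> getDatum(String[] words, int i, String label, int window) {
    Counter<String> feat = new ClassicCounter<>();
    for (int j = Math.max(0, i - window); j < i; j++) {
      feat.incrementCount("PREV-WORD-" + words[j]);
    }
    for (int j = i + 1; j < words.length && j <= i + window; j++) {
      feat.incrementCount("NEXT-WORD-" + words[j]);
    }
    return new RVFDatum<>(feat, label);
  }

  public static void main(String[] args) {
    String[] sent = { "Stanford", "University", "is", "in", "California" };
    RVFDatum<String, String> datum = getDatum(sent, 1, "ORG", 1);
    System.out.println(datum.label() + " " + datum.asFeaturesCounter());
  }
}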