
Example 46 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class RelationExtractorResultsPrinter, method printResults.

@Override
public void printResults(PrintWriter pw, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
    ResultsPrinter.align(goldStandard, extractorOutput);
    // the mention factory cannot be null here
    assert relationMentionFactory != null : "ERROR: RelationExtractorResultsPrinter.relationMentionFactory cannot be null in printResults!";
    // Count predicted-actual relation type pairs
    Counter<Pair<String, String>> results = new ClassicCounter<>();
    ClassicCounter<String> labelCount = new ClassicCounter<>();
    // TODO: assumes binary relations
    for (int goldSentenceIndex = 0; goldSentenceIndex < goldStandard.size(); goldSentenceIndex++) {
        for (RelationMention goldRelation : AnnotationUtils.getAllRelations(relationMentionFactory, goldStandard.get(goldSentenceIndex), createUnrelatedRelations)) {
            CoreMap extractorSentence = extractorOutput.get(goldSentenceIndex);
            List<RelationMention> extractorRelations = AnnotationUtils.getRelations(relationMentionFactory, extractorSentence, goldRelation.getArg(0), goldRelation.getArg(1));
            labelCount.incrementCount(goldRelation.getType());
            for (RelationMention extractorRelation : extractorRelations) {
                results.incrementCount(new Pair<>(extractorRelation.getType(), goldRelation.getType()));
            }
        }
    }
    printResultsInternal(pw, results, labelCount);
}
Also used: RelationMention(edu.stanford.nlp.ie.machinereading.structure.RelationMention) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) CoreMap(edu.stanford.nlp.util.CoreMap) Pair(edu.stanford.nlp.util.Pair)
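
The pattern above (counting (predicted, gold) label pairs in one counter and per-label totals in another) relies on only a handful of Counter operations. Below is a minimal, self-contained sketch of those operations, using just the calls visible in these examples; the relation labels are invented for illustration:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Pair;

public class CounterBasics {
    public static void main(String[] args) {
        // Count (predicted, gold) relation-type pairs, as printResults does
        Counter<Pair<String, String>> results = new ClassicCounter<>();
        results.incrementCount(new Pair<>("WORKS_AT", "WORKS_AT"));
        results.incrementCount(new Pair<>("WORKS_AT", "_NR"));
        results.incrementCount(new Pair<>("WORKS_AT", "WORKS_AT"));
        // Unseen keys count as 0.0 rather than null
        System.out.println(results.getCount(new Pair<>("LIVES_IN", "_NR"))); // 0.0
        System.out.println(results.getCount(new Pair<>("WORKS_AT", "WORKS_AT"))); // 2.0
        System.out.println(results.totalCount()); // 3.0
        // Counters.argmax returns the key with the highest count
        System.out.println(Counters.argmax(results));
    }
}

Any key type with sensible hashCode/equals works, which is why Pair keys are common throughout these examples.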

Example 47 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class ClauseSplitter, method train.

/**
   * Train a clause searcher factory. That is, train a classifier that decides which
   * dependency arcs should begin new clauses.
   *
   * @param trainingData The training data: a stream of pairs, each containing
   *                     <ol>
   *                       <li>a sentence containing a known extraction, and</li>
   *                       <li>a collection of (subject span, object span) pairs in that
   *                           sentence, given as token spans.</li>
   *                     </ol>
   * @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
   * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
   * @param featurizer The featurizer to use for this classifier.
   *
   * @return A factory for creating searchers from a given dependency tree.
   */
static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, Featurizer featurizer) {
    // Parse options
    LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
    // Generally useful objects
    OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
    WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
    AtomicInteger numExamplesProcessed = new AtomicInteger(0);
    final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
        try {
            return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file))));
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    });
    // Step 1: Loop over data
    forceTrack("Training inference");
    trainingData.forEach(rawExample -> {
        CoreMap sentence = rawExample.first;
        Collection<Pair<Span, Span>> spans = rawExample.second;
        List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
        SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
        ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
        problem.search(fragmentAndScore -> {
            List<Counter<String>> features = fragmentAndScore.second;
            SentenceFragment fragment = fragmentAndScore.third.get();
            Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
            Trilean correct = Trilean.FALSE;
            RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
                Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
                Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
                for (Pair<Span, Span> candidateGold : spans) {
                    Span subjectSpan = candidateGold.first;
                    Span objectSpan = candidateGold.second;
                    if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
                        correct = Trilean.TRUE;
                        break RELATION_TRIPLE_LOOP;
                    } else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
                        if (!correct.isTrue()) {
                            correct = Trilean.TRUE;
                            break RELATION_TRIPLE_LOOP;
                        }
                    } else {
                        if (!correct.isTrue()) {
                            correct = Trilean.UNKNOWN;
                            break RELATION_TRIPLE_LOOP;
                        }
                    }
                }
            }
            if (!features.isEmpty()) {
                List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
                if (correct.isTrue()) {
                    for (int i = 0; i < features.size(); ++i) {
                        if (i == features.size() - 1) {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                        } else {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                        }
                    }
                } else if (correct.isFalse()) {
                    decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
                } else if (correct.isUnknown()) {
                    boolean isSimpleSplit = false;
                    for (Counter<String> feats : features) {
                        if (featurizer.isSimpleSplit(feats)) {
                            isSimpleSplit = true;
                            break;
                        }
                    }
                    if (isSimpleSplit) {
                        for (int i = 0; i < features.size(); ++i) {
                            if (i == features.size() - 1) {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                            } else {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                            }
                        }
                    }
                }
                for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
                    RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
                    datum.setLabel(decision.second);
                    if (datasetDumpWriter.isPresent()) {
                        datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
                    }
                    dataset.add(datum);
                }
            }
            return true;
        }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
        if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
            log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
        }
    });
    endTrack("Training inference");
    // Close the file
    if (datasetDumpWriter.isPresent()) {
        datasetDumpWriter.get().close();
    }
    // Step 2: Train classifier
    forceTrack("Training");
    Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
    endTrack("Training");
    if (modelPath.isPresent()) {
        Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
        try {
            IOUtils.writeObjectToFile(toSave, modelPath.get());
            log("SUCCESS: wrote model to " + modelPath.get().getPath());
        } catch (IOException e) {
            log("ERROR: failed to save model to path: " + modelPath.get().getPath());
            err(e);
        }
    }
    // Step 3: Check accuracy of classifier
    forceTrack("Training accuracy");
    dataset.randomize(42L);
    Util.dumpAccuracy(fullClassifier, dataset);
    endTrack("Training accuracy");
    int numFolds = 5;
    forceTrack(numFolds + " fold cross-validation");
    for (int fold = 0; fold < numFolds; ++fold) {
        forceTrack("Fold " + (fold + 1));
        forceTrack("Training");
        Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
        Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
        endTrack("Training");
        forceTrack("Test");
        Util.dumpAccuracy(classifier, foldData.second);
        endTrack("Test");
        endTrack("Fold " + (fold + 1));
    }
    endTrack(numFolds + " fold cross-validation");
    // Step 5: return factory
    return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
Also used: CoreLabel(edu.stanford.nlp.ling.CoreLabel) java.util(java.util) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) IOUtils(edu.stanford.nlp.io.IOUtils) edu.stanford.nlp.util(edu.stanford.nlp.util) BiFunction(java.util.function.BiFunction) Redwood(edu.stanford.nlp.util.logging.Redwood) Util(edu.stanford.nlp.util.logging.Redwood.Util) Span(edu.stanford.nlp.ie.machinereading.structure.Span) Counter(edu.stanford.nlp.stats.Counter) Stream(java.util.stream.Stream) java.io(java.io) SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) edu.stanford.nlp.classify(edu.stanford.nlp.classify) RelationTriple(edu.stanford.nlp.ie.util.RelationTriple) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) GZIPOutputStream(java.util.zip.GZIPOutputStream) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) ClauseSplitterSearchProblem(edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem)
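
Stripped of the search loop, train reduces to the standard CoreNLP classification pattern: build a feature counter per decision, wrap it in an RVFDatum, accumulate datums in a dataset, and hand the dataset to a LinearClassifierFactory. Here is a hedged sketch of just that pattern, using only classes the method itself uses; the labels and feature names are invented:

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class TrainSketch {
    public static void main(String[] args) {
        WeightedDataset<String, String> dataset = new WeightedDataset<>();
        // Each datum is a feature counter plus a label, exactly as in train()
        Counter<String> feats = new ClassicCounter<>();
        feats.incrementCount("edge-type=ccomp");  // hypothetical feature name
        RVFDatum<String, String> datum = new RVFDatum<>(feats);
        datum.setLabel("CLAUSE_SPLIT");
        dataset.add(datum);
        // Add a contrasting example so both labels are observed
        Counter<String> feats2 = new ClassicCounter<>();
        feats2.incrementCount("edge-type=det");   // hypothetical feature name
        RVFDatum<String, String> datum2 = new RVFDatum<>(feats2);
        datum2.setLabel("NOT_A_CLAUSE");
        dataset.add(datum2);
        Classifier<String, String> classifier =
                new LinearClassifierFactory<String, String>().trainClassifier(dataset);
        System.out.println(classifier.classOf(datum));
    }
}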

Example 48 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class FactoredLexicon, method main.

/**
   * @param args
   */
public static void main(String[] args) {
    if (args.length != 4) {
        System.err.printf("Usage: java %s language features train_file dev_file%n", FactoredLexicon.class.getName());
        System.exit(-1);
    }
    // Command line options
    Language language = Language.valueOf(args[0]);
    TreebankLangParserParams tlpp = language.params;
    Treebank trainTreebank = tlpp.diskTreebank();
    trainTreebank.loadPath(args[2]);
    Treebank devTreebank = tlpp.diskTreebank();
    devTreebank.loadPath(args[3]);
    MorphoFeatureSpecification morphoSpec;
    Options options = getOptions(language);
    if (language.equals(Language.Arabic)) {
        morphoSpec = new ArabicMorphoFeatureSpecification();
        String[] languageOptions = { "-arabicFactored" };
        tlpp.setOptionFlag(languageOptions, 0);
    } else if (language.equals(Language.French)) {
        morphoSpec = new FrenchMorphoFeatureSpecification();
        String[] languageOptions = { "-frenchFactored" };
        tlpp.setOptionFlag(languageOptions, 0);
    } else {
        throw new UnsupportedOperationException();
    }
    String featureList = args[1];
    String[] features = featureList.trim().split(",");
    for (String feature : features) {
        morphoSpec.activate(MorphoFeatureType.valueOf(feature));
    }
    System.out.println("Language: " + language.toString());
    System.out.println("Features: " + args[1]);
    // Create word and tag indices
    // Save trees in a collection since the interface requires that....
    System.out.print("Loading training trees...");
    List<Tree> trainTrees = new ArrayList<>(19000);
    Index<String> wordIndex = new HashIndex<>();
    Index<String> tagIndex = new HashIndex<>();
    for (Tree tree : trainTreebank) {
        for (Tree subTree : tree) {
            if (!subTree.isLeaf()) {
                tlpp.transformTree(subTree, tree);
            }
        }
        trainTrees.add(tree);
    }
    System.out.printf("Done! (%d trees)%n", trainTrees.size());
    // Setup and train the lexicon.
    System.out.print("Collecting sufficient statistics for lexicon...");
    FactoredLexicon lexicon = new FactoredLexicon(options, morphoSpec, wordIndex, tagIndex);
    lexicon.initializeTraining(trainTrees.size());
    lexicon.train(trainTrees, null);
    lexicon.finishTraining();
    System.out.println("Done!");
    trainTrees = null;
    // Load the tuning set
    System.out.print("Loading tuning set...");
    List<FactoredLexiconEvent> tuningSet = getTuningSet(devTreebank, lexicon, tlpp);
    System.out.printf("...Done! (%d events)%n", tuningSet.size());
    // Print the probabilities that we obtain
    // TODO(spenceg): Implement tagging accuracy with FactLex
    int nCorrect = 0;
    Counter<String> errors = new ClassicCounter<>();
    for (FactoredLexiconEvent event : tuningSet) {
        Iterator<IntTaggedWord> itr = lexicon.ruleIteratorByWord(event.word(), event.getLoc(), event.featureStr());
        Counter<Integer> logScores = new ClassicCounter<>();
        boolean noRules = true;
        int goldTagId = -1;
        while (itr.hasNext()) {
            noRules = false;
            IntTaggedWord iTW = itr.next();
            if (iTW.tag() == event.tagId()) {
                log.info("GOLD-");
                goldTagId = iTW.tag();
            }
            float tagScore = lexicon.score(iTW, event.getLoc(), event.word(), event.featureStr());
            logScores.incrementCount(iTW.tag(), tagScore);
        }
        if (noRules) {
            System.err.printf("NO TAGGINGS: %s %s%n", event.word(), event.featureStr());
        } else {
            // Score the tagging
            int hypTagId = Counters.argmax(logScores);
            if (hypTagId == goldTagId) {
                ++nCorrect;
            } else {
                String goldTag = goldTagId < 0 ? "UNSEEN" : lexicon.tagIndex.get(goldTagId);
                errors.incrementCount(goldTag);
            }
        }
        log.info();
    }
    // Output accuracy
    double acc = (double) nCorrect / (double) tuningSet.size();
    System.err.printf("%n%nACCURACY: %.2f%n%n", acc * 100.0);
    log.info("% of errors by type:");
    List<String> biggestKeys = new ArrayList<>(errors.keySet());
    Collections.sort(biggestKeys, Counters.toComparator(errors, false, true));
    Counters.normalize(errors);
    for (String key : biggestKeys) {
        System.err.printf("%s\t%.2f%n", key, errors.getCount(key) * 100.0);
    }
}
Also used: FrenchMorphoFeatureSpecification(edu.stanford.nlp.international.french.FrenchMorphoFeatureSpecification) Treebank(edu.stanford.nlp.trees.Treebank) ArrayList(java.util.ArrayList) Language(edu.stanford.nlp.international.Language) Tree(edu.stanford.nlp.trees.Tree) ArabicMorphoFeatureSpecification(edu.stanford.nlp.international.arabic.ArabicMorphoFeatureSpecification) MorphoFeatureSpecification(edu.stanford.nlp.international.morph.MorphoFeatureSpecification) HashIndex(edu.stanford.nlp.util.HashIndex) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)
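
The error-reporting tail of main is a reusable idiom: sort the counter's keys by descending count, then normalize the counter in place so each count becomes a fraction of the total. A small sketch of that idiom in isolation (the tag names are invented):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class ErrorReportSketch {
    public static void main(String[] args) {
        Counter<String> errors = new ClassicCounter<>();
        errors.incrementCount("NN", 3.0);
        errors.incrementCount("VB");
        // Sort descending by count (ascending = false, compare by magnitude)
        List<String> biggestKeys = new ArrayList<>(errors.keySet());
        Collections.sort(biggestKeys, Counters.toComparator(errors, false, true));
        Counters.normalize(errors);  // counts now sum to 1.0
        for (String key : biggestKeys) {
            System.out.printf("%s\t%.2f%%%n", key, errors.getCount(key) * 100.0);
        }
    }
}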

Example 49 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class SisterAnnotationStats, method sideCounters.

protected void sideCounters(String label, List rewrite, List sideSisters, Map sideRules) {
    for (Object sideSister : sideSisters) {
        String sis = (String) sideSister;
        if (!((Map) sideRules.get(label)).containsKey(sis)) {
            ((Map) sideRules.get(label)).put(sis, new ClassicCounter());
        }
        ((ClassicCounter) ((Map) sideRules.get(label)).get(sis)).incrementCount(rewrite);
    }
}
Also used: ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)
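
The raw types here are pre-generics code. A generified sketch of the same logic, under the assumption (consistent with the casts above) that the outer map is keyed by parent label, the inner map by sister label, and each counter counts rewrite lists:

import edu.stanford.nlp.stats.ClassicCounter;
import java.util.List;
import java.util.Map;

public class SideCountersSketch {
    // Same logic as sideCounters above, with the raw types made explicit
    protected void sideCounters(String label, List<String> rewrite,
                                List<String> sideSisters,
                                Map<String, Map<String, ClassicCounter<List<String>>>> sideRules) {
        for (String sis : sideSisters) {
            Map<String, ClassicCounter<List<String>>> rulesForLabel = sideRules.get(label);
            if (!rulesForLabel.containsKey(sis)) {
                rulesForLabel.put(sis, new ClassicCounter<>());
            }
            rulesForLabel.get(sis).incrementCount(rewrite);
        }
    }
}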

Example 50 with ClassicCounter

Use of edu.stanford.nlp.stats.ClassicCounter in project CoreNLP by stanfordnlp.

From the class LearnImportantFeatures, method getDatum.

private RVFDatum<String, String> getDatum(CoreLabel[] sent, int i) {
    Counter<String> feat = new ClassicCounter<>();
    CoreLabel l = sent[i];
    String label;
    if (l.get(answerClass).toString().equals(answerLabel))
        label = answerLabel;
    else
        label = "O";
    CollectionValuedMap<String, CandidatePhrase> matchedPhrases = l.get(PatternsAnnotations.MatchedPhrases.class);
    if (matchedPhrases == null) {
        matchedPhrases = new CollectionValuedMap<>();
        matchedPhrases.add(label, CandidatePhrase.createOrGet(l.word()));
    }
    for (CandidatePhrase w : matchedPhrases.allValues()) {
        Integer num = this.clusterIds.get(w.getPhrase());
        if (num == null)
            num = -1;
        feat.setCount("Cluster-" + num, 1.0);
    }
    // feat.incrementCount("WORD-" + l.word());
    // feat.incrementCount("LEMMA-" + l.lemma());
    // feat.incrementCount("TAG-" + l.tag());
    int window = 0;
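    // NOTE: with window == 0 neither context loop below executes, so no
    // PREV-/NEXT- features are actually added; increase window to enable them.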
    for (int j = Math.max(0, i - window); j < i; j++) {
        CoreLabel lj = sent[j];
        feat.incrementCount("PREV-" + "WORD-" + lj.word());
        feat.incrementCount("PREV-" + "LEMMA-" + lj.lemma());
        feat.incrementCount("PREV-" + "TAG-" + lj.tag());
    }
    for (int j = i + 1; j < sent.length && j <= i + window; j++) {
        CoreLabel lj = sent[j];
        feat.incrementCount("NEXT-" + "WORD-" + lj.word());
        feat.incrementCount("NEXT-" + "LEMMA-" + lj.lemma());
        feat.incrementCount("NEXT-" + "TAG-" + lj.tag());
    }
    // System.out.println("adding " + l.word() + " as " + label);
    return new RVFDatum<>(feat, label);
}
Also used: CoreLabel(edu.stanford.nlp.ling.CoreLabel) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) PatternsAnnotations(edu.stanford.nlp.patterns.PatternsAnnotations) CandidatePhrase(edu.stanford.nlp.patterns.CandidatePhrase)

Aggregations

ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 69
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 27
ArrayList (java.util.ArrayList): 21
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 18
Tree (edu.stanford.nlp.trees.Tree): 13
Pair (edu.stanford.nlp.util.Pair): 11
Counter (edu.stanford.nlp.stats.Counter): 10
List (java.util.List): 10
Mention (edu.stanford.nlp.coref.data.Mention): 8
Language (edu.stanford.nlp.international.Language): 7
RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException): 7
CoreMap (edu.stanford.nlp.util.CoreMap): 7
IOUtils (edu.stanford.nlp.io.IOUtils): 6
Label (edu.stanford.nlp.ling.Label): 6
TreebankLangParserParams (edu.stanford.nlp.parser.lexparser.TreebankLangParserParams): 6
PrintWriter (java.io.PrintWriter): 6
java.util (java.util): 6
HashSet (java.util.HashSet): 6
RVFDatum (edu.stanford.nlp.ling.RVFDatum): 5
DiskTreebank (edu.stanford.nlp.trees.DiskTreebank): 5