Search in sources:

Example 1 with ClauseSplitterSearchProblem

Use of edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem in the CoreNLP project by stanfordnlp.

From the class ClauseSplitter, the method train:

/**
   * Train a clause searcher factory. That is, train a classifier that decides which
   * dependency arcs should be split off as new clauses.
   *
   * @param trainingData The training data. This is a stream of pairs of:
   *                     <ol>
   *                       <li>The sentence containing a known extraction.</li>
   *                       <li>The gold (subject span, object span) pairs for the extractions
   *                           in that sentence, as token spans.</li>
   *                     </ol>
   * @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
   * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
   * @param featurizer The featurizer to use for this classifier.
   *
   * @return A factory for creating searchers from a given dependency tree.
   */
static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, Featurizer featurizer) {
    // Parse options
    LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
    // Generally useful objects
    OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
    WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
    AtomicInteger numExamplesProcessed = new AtomicInteger(0);
    // Optionally open a gzipped writer for dumping the labeled training datums
    final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
        try {
            return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file))));
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    });
    // Step 1: Loop over data
    forceTrack("Training inference");
    trainingData.forEach(rawExample -> {
        CoreMap sentence = rawExample.first;
        Collection<Pair<Span, Span>> spans = rawExample.second;
        List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
        SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
        ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
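        // Search over candidate clauses in this sentence. Each candidate reaches the callback
        // below together with the sequence of featurized decisions that produced it. Note that
        // the classifier passed to search(...) is an untrained placeholder with all-zero
        // weights, and 10000 caps the number of search ticks.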
        problem.search(fragmentAndScore -> {
            List<Counter<String>> features = fragmentAndScore.second;
            SentenceFragment fragment = fragmentAndScore.third.get();
            Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
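            // Label this candidate: TRUE if some extraction from the clause matches a gold
            // (subject, object) pair exactly or via NER overlap, UNKNOWN if an extraction
            // matched nothing, and FALSE (the default) if the clause yielded no extractions.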
            Trilean correct = Trilean.FALSE;
            RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
                Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
                Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
                for (Pair<Span, Span> candidateGold : spans) {
                    Span subjectSpan = candidateGold.first;
                    Span objectSpan = candidateGold.second;
                    if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
                        correct = Trilean.TRUE;
                        break RELATION_TRIPLE_LOOP;
                    } else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
                        if (!correct.isTrue()) {
                            correct = Trilean.TRUE;
                            break RELATION_TRIPLE_LOOP;
                        }
                    } else {
                        if (!correct.isTrue()) {
                            correct = Trilean.UNKNOWN;
                            break RELATION_TRIPLE_LOOP;
                        }
                    }
                }
            }
            if (!features.isEmpty()) {
                List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
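                // Convert the label into datums over the decision sequence: a positive clause
                // labels its final decision CLAUSE_SPLIT and the earlier ones CLAUSE_INTERM;
                // a negative clause penalizes only the final decision as NOT_A_CLAUSE; an
                // UNKNOWN clause counts as positive only if the featurizer deems some step a
                // simple split.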
                if (correct.isTrue()) {
                    for (int i = 0; i < features.size(); ++i) {
                        if (i == features.size() - 1) {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                        } else {
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                        }
                    }
                } else if (correct.isFalse()) {
                    decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
                } else if (correct.isUnknown()) {
                    boolean isSimpleSplit = false;
                    for (Counter<String> feats : features) {
                        if (featurizer.isSimpleSplit(feats)) {
                            isSimpleSplit = true;
                            break;
                        }
                    }
                    if (isSimpleSplit) {
                        for (int i = 0; i < features.size(); ++i) {
                            if (i == features.size() - 1) {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                            } else {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                            }
                        }
                    }
                }
                for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
                    RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
                    datum.setLabel(decision.second);
                    if (datasetDumpWriter.isPresent()) {
                        datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
                    }
                    dataset.add(datum);
                }
            }
            return true;
        }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
        if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
            log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
        }
    });
    endTrack("Training inference");
    // Close the file
    if (datasetDumpWriter.isPresent()) {
        datasetDumpWriter.get().close();
    }
    // Step 2: Train classifier
    forceTrack("Training");
    Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
    endTrack("Training");
    if (modelPath.isPresent()) {
        Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
        try {
            IOUtils.writeObjectToFile(toSave, modelPath.get());
            log("SUCCESS: wrote model to " + modelPath.get().getPath());
        } catch (IOException e) {
            log("ERROR: failed to save model to path: " + modelPath.get().getPath());
            err(e);
        }
    }
    // Step 3: Check accuracy of classifier
    forceTrack("Training accuracy");
    dataset.randomize(42L);
    Util.dumpAccuracy(fullClassifier, dataset);
    endTrack("Training accuracy");
    // Step 4: Cross-validate
    int numFolds = 5;
    forceTrack(numFolds + " fold cross-validation");
    for (int fold = 0; fold < numFolds; ++fold) {
        forceTrack("Fold " + (fold + 1));
        forceTrack("Training");
        Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
        Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
        endTrack("Training");
        forceTrack("Test");
        Util.dumpAccuracy(classifier, foldData.second);
        endTrack("Test");
        endTrack("Fold " + (fold + 1));
    }
    endTrack(numFolds + " fold cross-validation");
    // Step 5: Return the trained factory
    return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
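
For context, a minimal usage sketch follows. It assumes ClauseSplitter is the BiFunction-style factory that the returned lambda above suggests; myTrainingStream, myModelFile, myFeaturizer, and sentenceTree are hypothetical placeholders, not CoreNLP names.

// A hedged sketch, not a verbatim CoreNLP recipe: `myTrainingStream`, `myModelFile`,
// `myFeaturizer`, and `sentenceTree` are hypothetical placeholders.
ClauseSplitter splitter = ClauseSplitter.train(
    myTrainingStream,          // Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>>
    Optional.of(myModelFile),  // serialize the trained model here
    Optional.empty(),          // skip the labeled training-data dump
    myFeaturizer);             // any Featurizer implementation

// Apply the factory to a parsed sentence's dependency graph; the boolean is the
// "assumed truth" flag seen in the training loop above.
ClauseSplitterSearchProblem searcher = splitter.apply(sentenceTree, true);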