Search in sources :

Example 6 with Counter

Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.

In class KBPStatisticalExtractor, method main.

/**
 * Command-line entry point: trains a statistical KBP relation classifier if needed, then evaluates it.
 * <p>
 * Steps:
 * <ol>
 *   <li>Options are filled from {@code args} via {@link ArgumentParser}.</li>
 *   <li>The evaluation examples are read from {@code TEST_FILE}.</li>
 *   <li>If {@code MODEL_FILE} is found neither on the classpath nor on disk, the examples in
 *       {@code TRAIN_FILE} are featurized in parallel. A multinomial classifier is then trained
 *       on them and saved to {@code MODEL_FILE}.</li>
 *   <li>The model is (re)loaded. It may be either a bare {@link Classifier} or a serialized
 *       {@link KBPStatisticalExtractor}.</li>
 *   <li>The model is evaluated on the test examples. Predictions are optionally dumped to
 *       {@code PREDICTIONS}, where the value {@code "stdout"} means standard output.</li>
 * </ol>
 *
 * @param args command-line options
 * @throws IOException if a dataset or the model cannot be read or written
 * @throws ClassNotFoundException if the serialized model refers to an unknown class
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
    // Apply the standard Redwood logging configuration (silences SLF4J noise).
    RedwoodConfiguration.standard().apply();
    // Fill command-line options
    ArgumentParser.fillOptions(KBPStatisticalExtractor.class, args);
    // Load the test (or dev) data
    forceTrack("Test data");
    List<Pair<KBPInput, String>> testExamples = KBPRelationExtractor.readDataset(TEST_FILE);
    log.info("Read " + testExamples.size() + " examples");
    endTrack("Test data");
    // If we can't find an existing model, train one
    if (!IOUtils.existsInClasspathOrFileSystem(MODEL_FILE)) {
        forceTrack("Training data");
        List<Pair<KBPInput, String>> trainExamples = KBPRelationExtractor.readDataset(TRAIN_FILE);
        log.info("Read " + trainExamples.size() + " examples");
        log.info("" + trainExamples.stream().map(Pair::second).filter(NO_RELATION::equals).count() + " are " + NO_RELATION);
        endTrack("Training data");
        // Featurize + create the dataset
        forceTrack("Creating dataset");
        RVFDataset<String, String> dataset = new RVFDataset<>();
        final AtomicInteger i = new AtomicInteger(0);
        long beginTime = System.currentTimeMillis();
        trainExamples.stream().parallel().forEach(example -> {
            // Log the value returned by incrementAndGet(). Re-reading i.get() in a parallel
            // stream can observe other threads' increments and report the wrong count.
            int numFeaturized = i.incrementAndGet();
            if (numFeaturized % 1000 == 0) {
                log.info("[" + Redwood.formatTimeDifference(System.currentTimeMillis() - beginTime) + "] Featurized " + numFeaturized + " / " + trainExamples.size() + " examples");
            }
            Counter<String> features = features(example.first);
            // RVFDataset is shared across the parallel workers; guard every add.
            synchronized (dataset) {
                dataset.add(new RVFDatum<>(features, example.second));
            }
        });
        // Free up some memory
        trainExamples.clear();
        endTrack("Creating dataset");
        // Train the classifier
        log.info("Training classifier:");
        Classifier<String, String> classifier = trainMultinomialClassifier(dataset, FEATURE_THRESHOLD, SIGMA);
        // Free up some memory
        dataset.clear();
        // Save the classifier
        IOUtils.writeObjectToFile(new KBPStatisticalExtractor(classifier), MODEL_FILE);
    }
    // Read either a newly-trained or pre-trained model
    Object model = IOUtils.readObjectFromURLOrClasspathOrFileSystem(MODEL_FILE);
    KBPStatisticalExtractor classifier;
    if (model instanceof Classifier) {
        // Older models were serialized as the bare classifier; wrap them.
        //noinspection unchecked
        classifier = new KBPStatisticalExtractor((Classifier<String, String>) model);
    } else {
        classifier = ((KBPStatisticalExtractor) model);
    }
    // Evaluate the model, optionally dumping predictions to stdout or a file
    Optional<PrintStream> predictionsOut = PREDICTIONS.map(x -> {
        try {
            return "stdout".equalsIgnoreCase(x) ? System.out : new PrintStream(new FileOutputStream(x));
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    });
    try {
        classifier.computeAccuracy(testExamples.stream(), predictionsOut);
    } finally {
        // Release a predictions file we opened ourselves, but never close System.out.
        predictionsOut.filter(out -> out != System.out).ifPresent(PrintStream::close);
    }
}
Also used : edu.stanford.nlp.optimization(edu.stanford.nlp.optimization) CoreLabel(edu.stanford.nlp.ling.CoreLabel) java.util(java.util) Counters(edu.stanford.nlp.stats.Counters) IOUtils(edu.stanford.nlp.io.IOUtils) DefaultPaths(edu.stanford.nlp.pipeline.DefaultPaths) edu.stanford.nlp.util(edu.stanford.nlp.util) Redwood(edu.stanford.nlp.util.logging.Redwood) Util(edu.stanford.nlp.util.logging.Redwood.Util) Datum(edu.stanford.nlp.ling.Datum) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) Span(edu.stanford.nlp.ie.machinereading.structure.Span) Counter(edu.stanford.nlp.stats.Counter) java.io(java.io) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) edu.stanford.nlp.classify(edu.stanford.nlp.classify) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) Sentence(edu.stanford.nlp.simple.Sentence) RedwoodConfiguration(edu.stanford.nlp.util.logging.RedwoodConfiguration) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Example 7 with Counter

Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.

In class ConstantsAndVariables, method getAllOptions.

//  public void addLearnedWords(String trainLabel, Counter<CandidatePhrase> identifiedWords) {
//    if(!learnedWords.containsKey(trainLabel))
//      learnedWords.put(trainLabel, new ClassicCounter<CandidatePhrase>());
//    this.learnedWords.get(trainLabel).addAll(identifiedWords);
//  }
/**
 * Collects the option values of this object as strings, keyed by name.
 * <p>
 * Every entry of {@code props} is copied first; {@code null} property values become the string
 * {@code "null"}. Then, via reflection, each non-synthetic field declared directly on this
 * object's runtime class is added if its type is primitive or listed in
 * {@code GetPatternsFromDataMultiClass.printOptionClass}. Field entries overwrite property entries
 * with the same name. Fields that cannot be read are reported and skipped, so one bad field does
 * not hide the rest.
 *
 * @return a fresh, mutable map from option/field name to its string value
 */
public Map<String, String> getAllOptions() {
    Map<String, String> values = new HashMap<>();
    if (props != null) {
        props.forEach((x, y) -> values.put(x.toString(), y == null ? "null" : y.toString()));
    }
    // printOptionClass holds Class objects, which are not Comparable, so Arrays.binarySearch
    // would throw ClassCastException. It would also require a sorted array. Search linearly.
    List<?> printableTypes = Arrays.asList(GetPatternsFromDataMultiClass.printOptionClass);
    for (Field f : getClass().getDeclaredFields()) {
        // Compiler-generated fields (e.g. $assertionsDisabled) are not options.
        if (f.isSynthetic()) {
            continue;
        }
        // f.getType() is the field's declared type. Calling getClass() on it would yield
        // Class.class, which is never primitive.
        Class<?> type = f.getType();
        if (!type.isPrimitive() && !printableTypes.contains(type)) {
            continue;
        }
        try {
            Object fvalue = f.get(this);
            values.put(f.getName(), fvalue == null ? "null" : fvalue.toString());
        } catch (IllegalAccessException | RuntimeException e) {
            // Best effort: report the unreadable field and keep going with the others.
            e.printStackTrace();
        }
    }
    return values;
}
Also used : WordShapeClassifier(edu.stanford.nlp.process.WordShapeClassifier) java.util(java.util) Key(edu.stanford.nlp.util.TypesafeMap.Key) JsonArrayBuilder(javax.json.JsonArrayBuilder) edu.stanford.nlp.util(edu.stanford.nlp.util) DepPatternFactory(edu.stanford.nlp.patterns.dep.DepPatternFactory) WordScoring(edu.stanford.nlp.patterns.GetPatternsFromDataMultiClass.WordScoring) NodePattern(edu.stanford.nlp.ling.tokensregex.NodePattern) Counter(edu.stanford.nlp.stats.Counter) SurfacePatternFactory(edu.stanford.nlp.patterns.surface.SurfacePatternFactory) Json(javax.json.Json) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) Option(edu.stanford.nlp.util.ArgumentParser.Option) TokenSequencePattern(edu.stanford.nlp.ling.tokensregex.TokenSequencePattern) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) Counters(edu.stanford.nlp.stats.Counters) IOUtils(edu.stanford.nlp.io.IOUtils) Redwood(edu.stanford.nlp.util.logging.Redwood) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) IOException(java.io.IOException) Field(java.lang.reflect.Field) PatternScoring(edu.stanford.nlp.patterns.GetPatternsFromDataMultiClass.PatternScoring) File(java.io.File) Serializable(java.io.Serializable) Entry(java.util.Map.Entry) Env(edu.stanford.nlp.ling.tokensregex.Env) Pattern(java.util.regex.Pattern) JsonObjectBuilder(javax.json.JsonObjectBuilder) Field(java.lang.reflect.Field) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) IOException(java.io.IOException)

Example 8 with Counter

Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.

In class GetPatternsFromDataMultiClass, method loadFromSavedPatternsWordsDir.

/**
 * Restores learned patterns and phrases from the directory given by {@code Flags.patternsWordsDir}
 * and re-applies them to the model's data sentences.
 * <p>
 * Properties read from {@code props}:
 * <ul>
 *   <li>{@code labelSentsUsingModel} (default true): label sentences with the loaded phrases.</li>
 *   <li>{@code applyPatsUsingModel} (default true): apply the loaded patterns to the sentences.</li>
 *   <li>{@code Flags.numIterationsOfSavedPatternsToLoad} (default unlimited): iterations with
 *       index {@code >=} this value are dropped.</li>
 *   <li>{@code Flags.loadModelForLabels}: optional comma/semicolon-separated subset of labels.</li>
 *   <li>{@code sentsOutFile}: passed through to the labeling step.</li>
 * </ul>
 * <p>
 * For each selected label, the method reads:
 * <ul>
 *   <li>{@code env.txt};</li>
 *   <li>{@code tokenenv.txt}, for surface patterns only;</li>
 *   <li>{@code <label>/patternsEachIter.ser}, if present;</li>
 *   <li>{@code <label>/phrases.txt}, if present.</li>
 * </ul>
 * It then walks the data sentences and accumulates extracted phrases into
 * {@code model.wordsPatExtracted}. It also updates the static {@code numIterationsLoadedModel}.
 *
 * @param model the model to populate
 * @param props configuration properties (see above)
 * @return map from each loaded pattern to its label; a pattern loaded for several labels maps to
 *         the last one processed
 * @throws IOException if a saved file cannot be read or a sentence batch cannot be written
 * @throws ClassNotFoundException if a serialized pattern file refers to an unknown class
 */
public static <E extends Pattern> Map<E, String> loadFromSavedPatternsWordsDir(GetPatternsFromDataMultiClass<E> model, Properties props) throws IOException, ClassNotFoundException {
    boolean labelSentsUsingModel = Boolean.parseBoolean(props.getProperty("labelSentsUsingModel", "true"));
    boolean applyPatsUsingModel = Boolean.parseBoolean(props.getProperty("applyPatsUsingModel", "true"));
    int numIterationsOfSavedPatternsToLoad = Integer.parseInt(props.getProperty(Flags.numIterationsOfSavedPatternsToLoad, String.valueOf(Integer.MAX_VALUE)));
    Map<E, String> labelsForPatterns = new HashMap<>();
    // The value is used as a plain directory path (zipped directories are not supported).
    String patternsWordsDir = props.getProperty(Flags.patternsWordsDir);
    String sentsOutFile = props.getProperty("sentsOutFile");
    String loadModelForLabels = props.getProperty(Flags.loadModelForLabels);
    List<String> loadModelForLabelsList = null;
    if (loadModelForLabels != null)
        loadModelForLabelsList = Arrays.asList(loadModelForLabels.split("[,;]"));
    for (String label : model.constVars.getLabels()) {
        if (loadModelForLabelsList != null && !loadModelForLabelsList.contains(label))
            continue;
        assert (new File(patternsWordsDir + "/" + label).exists()) : "Why does the directory " + patternsWordsDir + "/" + label + " not exist?";
        readClassesInEnv(patternsWordsDir + "/env.txt", model.constVars.env, ConstantsAndVariables.globalEnv);
        //Read the token mapping
        if (model.constVars.patternType.equals(PatternFactory.PatternType.SURFACE))
            Token.setClass2KeyMapping(new File(patternsWordsDir + "/tokenenv.txt"));
        //Load Patterns
        File patf = new File(patternsWordsDir + "/" + label + "/patternsEachIter.ser");
        if (patf.exists()) {
            Map<Integer, Counter<E>> patterns = IOUtils.readObjectFromFile(patf);
            removeIterationsAtOrAbove(patterns, numIterationsOfSavedPatternsToLoad, "patterns");
            Counter<E> pats = Counters.flatten(patterns);
            for (E p : pats.keySet()) {
                labelsForPatterns.put(p, label);
            }
            numIterationsLoadedModel = Math.max(numIterationsLoadedModel, patterns.size());
            model.setLearnedPatterns(pats, label);
            model.setLearnedPatternsEachIter(patterns, label);
            Redwood.log(Redwood.DBG, "Loaded " + model.getLearnedPatterns().get(label).size() + " patterns from " + patf);
        }
        //Load Words
        File wordf = new File(patternsWordsDir + "/" + label + "/phrases.txt");
        if (wordf.exists()) {
            TreeMap<Integer, Counter<CandidatePhrase>> words = GetPatternsFromDataMultiClass.readLearnedWordsFromFile(wordf);
            // Trim BEFORE handing the map to the model (as is done for patterns above). If the
            // setter copies or flattens its argument, phrases from dropped iterations would
            // otherwise still be registered.
            removeIterationsAtOrAbove(words, numIterationsOfSavedPatternsToLoad, "phrases");
            model.constVars.setLearnedWordsEachIter(words, label);
            numIterationsLoadedModel = Math.max(numIterationsLoadedModel, words.size());
            Redwood.log(Redwood.DBG, "Loaded " + words.size() + " phrases from " + wordf);
        }
        CollectionValuedMap<E, Triple<String, Integer, Integer>> matchedTokensByPat = new CollectionValuedMap<>();
        Iterator<Pair<Map<String, DataInstance>, File>> sentsIter = new ConstantsAndVariables.DataSentsIterator(model.constVars.batchProcessSents);
        TwoDimensionalCounter<CandidatePhrase, E> wordsandLemmaPatExtracted = new TwoDimensionalCounter<>();
        Set<CandidatePhrase> alreadyLabeledWords = new HashSet<>();
        while (sentsIter.hasNext()) {
            Pair<Map<String, DataInstance>, File> sents = sentsIter.next();
            if (labelSentsUsingModel) {
                Redwood.log(Redwood.DBG, "labeling sentences from " + sents.second() + " with the already learned words");
                assert sents.first() != null : "Why are sents null";
                model.labelWords(label, sents.first(), model.constVars.getLearnedWords(label).keySet(), sentsOutFile, matchedTokensByPat);
                // Persist only the sentence map. DataSentsIterator yields the file's content as a
                // Map<String, DataInstance>; writing the whole Pair would change the on-disk type.
                // NOTE(review): confirm against DataSentsIterator's reader.
                if (sents.second().exists())
                    IOUtils.writeObjectToFile(sents.first(), sents.second());
            }
            if (model.constVars.restrictToMatched || applyPatsUsingModel) {
                Redwood.log(Redwood.DBG, "Applying patterns to " + sents.first().size() + " sentences");
                // Index the batch once (the original added the same sentences twice in a row).
                model.constVars.invertedIndex.add(sents.first(), true);
                model.scorePhrases.applyPats(model.getLearnedPatterns(label), label, wordsandLemmaPatExtracted, matchedTokensByPat, alreadyLabeledWords);
            }
        }
        Counters.addInPlace(model.wordsPatExtracted.get(label), wordsandLemmaPatExtracted);
        System.out.println("All Extracted phrases are " + wordsandLemmaPatExtracted.firstKeySet());
    }
    System.out.flush();
    System.err.flush();
    return labelsForPatterns;
}

/**
 * Drops every entry whose iteration index is {@code >= numIterationsToKeep}, printing a line per
 * removed iteration. It does nothing when {@code numIterationsToKeep} is {@link Integer#MAX_VALUE},
 * i.e. "no limit".
 */
private static void removeIterationsAtOrAbove(Map<Integer, ?> byIteration, int numIterationsToKeep, String what) {
    if (numIterationsToKeep == Integer.MAX_VALUE) {
        return;
    }
    Iterator<Integer> it = byIteration.keySet().iterator();
    while (it.hasNext()) {
        Integer i = it.next();
        if (i >= numIterationsToKeep) {
            System.out.println("Removing " + what + " from iteration " + i);
            it.remove();
        }
    }
}
Also used : Counter(edu.stanford.nlp.stats.Counter) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) TwoDimensionalCounter(edu.stanford.nlp.stats.TwoDimensionalCounter) TwoDimensionalCounter(edu.stanford.nlp.stats.TwoDimensionalCounter) AtomicInteger(java.util.concurrent.atomic.AtomicInteger)

Example 9 with Counter

Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.

In class ClauseSplitterSearchProblem, method search.

/**
 * The core implementation of the search.
 * <p>
 * States are popped from a priority queue (the fringe). For each popped state, every
 * outgoing dependency edge of its root word is expanded with every applicable action. Children
 * not labeled {@code NOT_A_CLAUSE} are pushed back onto the fringe. A child's priority is the log
 * probability of the single step that produced it, not a sum over the path. That value is also
 * the first element of the triple handed to {@code candidateFragments}.
 *
 * @param root The root word to search from. Traditionally, this is the root of the sentence.
 * @param candidateFragments The callback for the resulting sentence fragments.
 *                           This is a predicate of a triple of values.
 *                           The return value of the predicate determines whether we should continue
 *                           searching: returning {@code false} stops the search.
 *                           The triple is a triple of
 *                           <ol>
 *                             <li>The log probability of the most recent step, as scored by the classifier.</li>
 *                             <li>The features along the path to this fragment. The last element of this is the features from the most recent step.</li>
 *                             <li>The sentence fragment. Because it is relatively expensive to compute the resulting tree, this is returned as a lazy {@link Supplier}.</li>
 *                           </ol>
 * @param classifier The classifier for whether an arc should be on the path to a clause split, a clause split itself, or neither.
 * @param hardCodedSplits Map from a dependency relation name to a list of strings, used by
 *                        {@code orderActions} to order the action space (presumably action
 *                        signatures — confirm). A relation {@code prefix:suffix} falls back to the
 *                        key {@code "prefix:*"}. For a matching edge, the first action yielding a
 *                        candidate becomes a forced {@code CLAUSE_SPLIT} with log probability 0,
 *                        and no further actions are tried on that edge.
 * @param featurizer The featurizer to use. Make sure this matches the weights!
 * @param actionSpace The action space we are allowed to take. Each action defines a means of splitting a clause on a dependency boundary.
 * @param maxTicks The maximum number of states popped from the fringe; once exceeded, the search
 *                 returns silently.
 */
protected void search(
        // The root to search from
        IndexedWord root,
        // The output specs
        final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>> candidateFragments,
        // The learning specs
        final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
        Map<String, ? extends List<String>> hardCodedSplits,
        final Function<Triple<State, Action, State>, Counter<String>> featurizer,
        final Collection<Action> actionSpace,
        final int maxTicks) {
    // (the fringe)
    PriorityQueue<Pair<State, List<Counter<String>>>> fringe = new FixedPrioritiesPriorityQueue<>();
    // (avoid duplicate work)
    Set<IndexedWord> seenWords = new HashSet<>();
    // First state is implicitly "done"
    State firstState = new State(null, null, -9000, null, x -> {
    }, true);
    fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
    int ticks = 0;
    while (!fringe.isEmpty()) {
        if (++ticks > maxTicks) {
            //        log.info("WARNING! Timed out on search with " + ticks + " ticks");
            return;
        }
        // Useful variables
        double logProbSoFar = fringe.getPriority();
        assert logProbSoFar <= 0.0;
        Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
        State lastState = lastStatePair.first;
        List<Counter<String>> featuresSoFar = lastStatePair.second;
        IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();
        // Register thunk: report this finished state; the fragment itself is built lazily
        if (lastState.isDone) {
            if (!candidateFragments.test(Triple.makeTriple(logProbSoFar, featuresSoFar, () -> {
                SemanticGraph copy = new SemanticGraph(tree);
                lastState.thunk.andThen(x -> {
                    // Re-attach the extra edges hanging off each new root
                    for (IndexedWord newTreeRoot : x.getRoots()) {
                        if (newTreeRoot != null) {
                            for (SemanticGraphEdge extraEdge : extraEdgesByGovernor.get(newTreeRoot)) {
                                assert Util.isTree(x);
                                addSubtree(x, newTreeRoot, extraEdge.getRelation().toString(), tree, extraEdge.getDependent(), tree.getIncomingEdgesSorted(newTreeRoot));
                                assert Util.isTree(x);
                            }
                        }
                    }
                }).accept(copy);
                return new SentenceFragment(copy, assumedTruth, false);
            }))) {
                break;
            }
        }
        // Find relevant auxilliary terms
        SemanticGraphEdge subjOrNull = null;
        SemanticGraphEdge objOrNull = null;
        for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
            String relString = auxEdge.getRelation().toString();
            if (relString.contains("obj")) {
                objOrNull = auxEdge;
            } else if (relString.contains("subj")) {
                subjOrNull = auxEdge;
            }
        }
        // For each outgoing edge...
        for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
            // This fires if the governor is an indirect speech verb, and the outgoing edge is a ccomp
            if (outgoingEdge.getRelation().toString().equals("ccomp") && ((outgoingEdge.getGovernor().lemma() != null && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma())) || INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
                continue;
            }
            // Get some variables
            String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
            List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
            if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
                forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
            }
            boolean doneForcedArc = false;
            // For each action...
            for (Action action : (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
                // Check the prerequisite
                if (!action.prerequisitesMet(tree, outgoingEdge)) {
                    continue;
                }
                if (forcedArcOrder != null && doneForcedArc) {
                    break;
                }
                // 1. Compute the child state
                Optional<State> candidate = action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
                if (candidate.isPresent()) {
                    double logProbability;
                    ClauseClassifierLabel bestLabel;
                    Counter<String> features = featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
                    if (forcedArcOrder != null && !doneForcedArc) {
                        logProbability = 0.0;
                        bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
                        doneForcedArc = true;
                    } else if (features.containsKey("__undocumented_junit_no_classifier")) {
                        logProbability = Double.NEGATIVE_INFINITY;
                        bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
                    } else {
                        Counter<ClauseClassifierLabel> scores = classifier.scoresOf(new RVFDatum<>(features));
                        if (scores.size() > 0) {
                            Counters.logNormalizeInPlace(scores);
                        }
                        String rel = outgoingEdge.getRelation().toString();
                        if ("nsubj".equals(rel) || "dobj".equals(rel)) {
                            // Always at least yield on nsubj and dobj
                            scores.remove(ClauseClassifierLabel.NOT_A_CLAUSE);
                        }
                        logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
                        bestLabel = Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
                    }
                    if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
                        // Child path features = parent path features + this step's features.
                        // A plain list rather than a double-brace-initialized anonymous ArrayList
                        // subclass, which would also capture the enclosing instance.
                        List<Counter<String>> childFeatures = new ArrayList<>(featuresSoFar.size() + 1);
                        childFeatures.addAll(featuresSoFar);
                        childFeatures.add(features);
                        Pair<State, List<Counter<String>>> childState = Pair.makePair(candidate.get().withIsDone(bestLabel), childFeatures);
                        // 2. Register the child state
                        if (!seenWords.contains(childState.first.edge.getDependent())) {
                            //            log.info("  pushing " + action.signature() + " with " + argmax.first.edge);
                            fringe.add(childState, logProbability);
                        }
                    }
                }
            }
        }
        seenWords.add(rootWord);
    }
    //    log.info("Search finished in " + ticks + " ticks and " + classifierEvals + " classifier evaluations.");
}
Also used : java.util(java.util) SemanticGraphEdge(edu.stanford.nlp.semgraph.SemanticGraphEdge) GrammaticalRelation(edu.stanford.nlp.trees.GrammaticalRelation) Counters(edu.stanford.nlp.stats.Counters) Predicate(java.util.function.Predicate) Redwood(edu.stanford.nlp.util.logging.Redwood) edu.stanford.nlp.util(edu.stanford.nlp.util) ClauseClassifierLabel(edu.stanford.nlp.naturalli.ClauseSplitter.ClauseClassifierLabel) Function(java.util.function.Function) Supplier(java.util.function.Supplier) Consumer(java.util.function.Consumer) Counter(edu.stanford.nlp.stats.Counter) Stream(java.util.stream.Stream) java.io(java.io) edu.stanford.nlp.classify(edu.stanford.nlp.classify) Language(edu.stanford.nlp.international.Language) edu.stanford.nlp.ling(edu.stanford.nlp.ling) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) PriorityQueue(edu.stanford.nlp.util.PriorityQueue) ClauseClassifierLabel(edu.stanford.nlp.naturalli.ClauseSplitter.ClauseClassifierLabel) Counter(edu.stanford.nlp.stats.Counter) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) SemanticGraphEdge(edu.stanford.nlp.semgraph.SemanticGraphEdge) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph)

Example 10 with Counter

Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.

In class ClauseSplitter, method train.

/**
 * Train a clause searcher factory. That is, train a classifier for which arcs should be
 * new clauses.
 * <p>
 * For each training sentence, the clause search is run with an untrained (empty) linear
 * classifier. Each fragment it reports is labeled by comparing its OpenIE extractions against the
 * gold subject/object spans. The resulting labeled decisions become the dataset.
 * <p>
 * A linear classifier is trained on that dataset and optionally saved together with the
 * featurizer. Training accuracy and 5-fold cross-validation accuracy are then reported.
 *
 * @param trainingData The training data. This is a stream of triples of:
 *                     <ol>
 *                       <li>The sentence containing a known extraction.</li>
 *                       <li>The span of the subject in the sentence, as a token span.</li>
 *                       <li>The span of the object in the sentence, as a token span.</li>
 *                     </ol>
 * @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
 *                  A failure to save is logged, not thrown.
 * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums
 *                         (gzipped, UTF-8, one {@code label<TAB>feature->value;...} line per datum).
 * @param featurizer The featurizer to use for this classifier.
 *
 * @return A factory for creating searchers from a given dependency tree.
 * @throws RuntimeIOException if the training data dump file cannot be opened
 */
static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, Featurizer featurizer) {
    // Parse options
    LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
    // Generally useful objects
    OpenIE openie = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
    WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
    AtomicInteger numExamplesProcessed = new AtomicInteger(0);
    final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
        try {
            // Explicit UTF-8 so the dump does not depend on the platform default charset.
            return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file)), "UTF-8"));
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
    });
    // Step 1: Loop over data
    forceTrack("Training inference");
    try {
        trainingData.forEach(rawExample -> {
            CoreMap sentence = rawExample.first;
            Collection<Pair<Span, Span>> spans = rawExample.second;
            List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
            SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
            ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
            problem.search(fragmentAndScore -> {
                List<Counter<String>> features = fragmentAndScore.second;
                SentenceFragment fragment = fragmentAndScore.third.get();
                Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
                Trilean correct = Trilean.FALSE;
                // NOTE(review): every branch of the inner loop breaks out of BOTH loops on its first
                // iteration. Only the first extraction is compared, and only against the first gold
                // pair, and `correct` stays FALSE only when there are no extractions or no gold spans.
                // Preserved as-is because it defines the training labels; confirm intent.
                RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
                    Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
                    Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
                    for (Pair<Span, Span> candidateGold : spans) {
                        Span subjectSpan = candidateGold.first;
                        Span objectSpan = candidateGold.second;
                        if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) || (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
                            correct = Trilean.TRUE;
                            break RELATION_TRIPLE_LOOP;
                        } else if (Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess)) {
                            if (!correct.isTrue()) {
                                correct = Trilean.TRUE;
                                break RELATION_TRIPLE_LOOP;
                            }
                        } else {
                            if (!correct.isTrue()) {
                                correct = Trilean.UNKNOWN;
                                break RELATION_TRIPLE_LOOP;
                            }
                        }
                    }
                }
                if (!features.isEmpty()) {
                    List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
                    if (correct.isTrue()) {
                        // Last step is the split; every earlier step is an intermediate arc
                        for (int i = 0; i < features.size(); ++i) {
                            if (i == features.size() - 1) {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                            } else {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                            }
                        }
                    } else if (correct.isFalse()) {
                        // Only the final step is blamed
                        decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
                    } else if (correct.isUnknown()) {
                        // Unknown outcomes are kept only if some step is a "simple split"
                        boolean isSimpleSplit = false;
                        for (Counter<String> feats : features) {
                            if (featurizer.isSimpleSplit(feats)) {
                                isSimpleSplit = true;
                                break;
                            }
                        }
                        if (isSimpleSplit) {
                            for (int i = 0; i < features.size(); ++i) {
                                if (i == features.size() - 1) {
                                    decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                                } else {
                                    decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                                }
                            }
                        }
                    }
                    for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
                        RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
                        datum.setLabel(decision.second);
                        if (datasetDumpWriter.isPresent()) {
                            datasetDumpWriter.get().println(decision.second + "\t" + StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
                        }
                        dataset.add(datum);
                    }
                }
                // Keep searching: we want every fragment
                return true;
            }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);
            // Use the incremented value directly rather than re-reading the counter
            int processed = numExamplesProcessed.incrementAndGet();
            if (processed % 100 == 0) {
                log("processed " + processed + " training sentences: " + dataset.size() + " datums");
            }
        });
    } finally {
        // Close the dump file (finishing the gzip stream) even if inference fails
        datasetDumpWriter.ifPresent(PrintWriter::close);
    }
    endTrack("Training inference");
    // Step 2: Train classifier
    forceTrack("Training");
    Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
    endTrack("Training");
    if (modelPath.isPresent()) {
        Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
        try {
            IOUtils.writeObjectToFile(toSave, modelPath.get());
            log("SUCCESS: wrote model to " + modelPath.get().getPath());
        } catch (IOException e) {
            log("ERROR: failed to save model to path: " + modelPath.get().getPath());
            err(e);
        }
    }
    // Step 3: Check accuracy of classifier
    forceTrack("Training accuracy");
    dataset.randomize(42L);
    Util.dumpAccuracy(fullClassifier, dataset);
    endTrack("Training accuracy");
    int numFolds = 5;
    forceTrack(numFolds + " fold cross-validation");
    for (int fold = 0; fold < numFolds; ++fold) {
        forceTrack("Fold " + (fold + 1));
        forceTrack("Training");
        Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData = dataset.splitOutFold(fold, numFolds);
        Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
        endTrack("Training");
        forceTrack("Test");
        Util.dumpAccuracy(classifier, foldData.second);
        endTrack("Test");
        endTrack("Fold " + (fold + 1));
    }
    endTrack(numFolds + " fold cross-validation");
    // Step 4: return factory
    return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
}
Also used : CoreLabel(edu.stanford.nlp.ling.CoreLabel) java.util(java.util) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) IOUtils(edu.stanford.nlp.io.IOUtils) edu.stanford.nlp.util(edu.stanford.nlp.util) BiFunction(java.util.function.BiFunction) Redwood(edu.stanford.nlp.util.logging.Redwood) Util(edu.stanford.nlp.util.logging.Redwood.Util) Span(edu.stanford.nlp.ie.machinereading.structure.Span) Counter(edu.stanford.nlp.stats.Counter) Stream(java.util.stream.Stream) java.io(java.io) SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) edu.stanford.nlp.classify(edu.stanford.nlp.classify) RelationTriple(edu.stanford.nlp.ie.util.RelationTriple) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) GZIPOutputStream(java.util.zip.GZIPOutputStream) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) RVFDatum(edu.stanford.nlp.ling.RVFDatum) ClauseSplitterSearchProblem(edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem) Counter(edu.stanford.nlp.stats.Counter) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) GZIPOutputStream(java.util.zip.GZIPOutputStream) RelationTriple(edu.stanford.nlp.ie.util.RelationTriple) ClauseSplitterSearchProblem(edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem) SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) CoreLabel(edu.stanford.nlp.ling.CoreLabel) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) SemanticGraph(edu.stanford.nlp.semgraph.SemanticGraph) Span(edu.stanford.nlp.ie.machinereading.structure.Span) RVFDatum(edu.stanford.nlp.ling.RVFDatum) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) RuntimeIOException(edu.stanford.nlp.io.RuntimeIOException) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) SemanticGraphCoreAnnotations(edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations) 
ClassicCounter(edu.stanford.nlp.stats.ClassicCounter)

Aggregations

ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 13; Counter (edu.stanford.nlp.stats.Counter): 13; CoreLabel (edu.stanford.nlp.ling.CoreLabel): 7; IOUtils (edu.stanford.nlp.io.IOUtils): 6; CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 6; edu.stanford.nlp.util (edu.stanford.nlp.util): 6; java.util (java.util): 6; Redwood (edu.stanford.nlp.util.logging.Redwood): 5; edu.stanford.nlp.classify (edu.stanford.nlp.classify): 4; RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException): 4; SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph): 4; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 4; Function (java.util.function.Function): 4; IndexedWord (edu.stanford.nlp.ling.IndexedWord): 3; RVFDatum (edu.stanford.nlp.ling.RVFDatum): 3; TokenSequencePattern (edu.stanford.nlp.ling.tokensregex.TokenSequencePattern): 3; Counters (edu.stanford.nlp.stats.Counters): 3; TwoDimensionalCounter (edu.stanford.nlp.stats.TwoDimensionalCounter): 3; Util (edu.stanford.nlp.util.logging.Redwood.Util): 3; java.io (java.io): 3