
Example 1 with Counter

Use of edu.stanford.nlp.stats.Counter in the CoreNLP project by stanfordnlp.

From the class FeatureExtractor, method getFeatures:

private Counter<String> getFeatures(Document doc, Mention m, Map<Integer, List<Mention>> mentionsByHeadIndex) {
    Counter<String> features = new ClassicCounter<>();
    // type features
    features.incrementCount("mention-type=" + m.mentionType);
    features.incrementCount("gender=" + m.gender);
    features.incrementCount("person-fine=" + m.person);
    features.incrementCount("head-ne-type=" + m.nerString);
    List<String> singletonFeatures = m.getSingletonFeatures(dictionaries);
    for (Map.Entry<Integer, String> e : SINGLETON_FEATURES.entrySet()) {
        if (e.getKey() < singletonFeatures.size()) {
            features.incrementCount(e.getValue() + "=" + singletonFeatures.get(e.getKey()));
        }
    }
    // length and location features
    addNumeric(features, "mention-length", m.spanToString().length());
    addNumeric(features, "mention-words", m.originalSpan.size());
    addNumeric(features, "sentence-words", m.sentenceWords.size());
    features.incrementCount("sentence-words=" + bin(m.sentenceWords.size()));
    features.incrementCount("mention-position", m.mentionNum / (double) doc.predictedMentions.size());
    features.incrementCount("sentence-position", m.sentNum / (double) doc.numSentences);
    // lexical features
    CoreLabel firstWord = firstWord(m);
    CoreLabel lastWord = lastWord(m);
    CoreLabel headWord = headWord(m);
    CoreLabel prevWord = prevWord(m);
    CoreLabel nextWord = nextWord(m);
    CoreLabel prevprevWord = prevprevWord(m);
    CoreLabel nextnextWord = nextnextWord(m);
    String headPOS = getPOS(headWord);
    String firstPOS = getPOS(firstWord);
    String lastPOS = getPOS(lastWord);
    String prevPOS = getPOS(prevWord);
    String nextPOS = getPOS(nextWord);
    String prevprevPOS = getPOS(prevprevWord);
    String nextnextPOS = getPOS(nextnextWord);
    features.incrementCount("first-word=" + wordIndicator(firstWord, firstPOS));
    features.incrementCount("last-word=" + wordIndicator(lastWord, lastPOS));
    features.incrementCount("head-word=" + wordIndicator(headWord, headPOS));
    features.incrementCount("next-word=" + wordIndicator(nextWord, nextPOS));
    features.incrementCount("prev-word=" + wordIndicator(prevWord, prevPOS));
    features.incrementCount("next-bigram=" + wordIndicator(nextWord, nextnextWord, nextPOS + "_" + nextnextPOS));
    features.incrementCount("prev-bigram=" + wordIndicator(prevprevWord, prevWord, prevprevPOS + "_" + prevPOS));
    features.incrementCount("next-pos=" + nextPOS);
    features.incrementCount("prev-pos=" + prevPOS);
    features.incrementCount("first-pos=" + firstPOS);
    features.incrementCount("last-pos=" + lastPOS);
    features.incrementCount("next-pos-bigram=" + nextPOS + "_" + nextnextPOS);
    features.incrementCount("prev-pos-bigram=" + prevprevPOS + "_" + prevPOS);
    addDependencyFeatures(features, "parent", getDependencyParent(m), true);
    addFeature(features, "ends-with-head", m.headIndex == m.endIndex - 1);
    addFeature(features, "is-generic", m.originalSpan.size() == 1 && firstPOS.equals("NNS"));
    // syntax features
    IndexedWord w = m.headIndexedWord;
    String depPath = "";
    int depth = 0;
    while (w != null) {
        SemanticGraphEdge e = getDependencyParent(m, w);
        depth++;
        if (depth <= 3 && e != null) {
            depPath += (depPath.isEmpty() ? "" : "_") + e.getRelation().toString();
            features.incrementCount("dep-path=" + depPath);
            w = e.getSource();
        } else {
            w = null;
        }
    }
    if (useConstituencyParse) {
        int fullEmbeddingLevel = headEmbeddingLevel(m.contextParseTree, m.headIndex);
        int mentionEmbeddingLevel = headEmbeddingLevel(m.mentionSubTree, m.headIndex - m.startIndex);
        if (fullEmbeddingLevel != -1 && mentionEmbeddingLevel != -1) {
            features.incrementCount("mention-embedding-level=" + bin(fullEmbeddingLevel - mentionEmbeddingLevel));
            features.incrementCount("head-embedding-level=" + bin(mentionEmbeddingLevel));
        } else {
            features.incrementCount("undetermined-embedding-level");
        }
        features.incrementCount("num-embedded-nps=" + bin(numEmbeddedNps(m.mentionSubTree)));
        String syntaxPath = "";
        Tree tree = m.contextParseTree;
        Tree head = tree.getLeaves().get(m.headIndex).ancestor(1, tree);
        depth = 0;
        for (Tree node : tree.pathNodeToNode(head, tree)) {
            syntaxPath += node.value() + "-";
            features.incrementCount("syntax-path=" + syntaxPath);
            depth++;
            if (depth >= 4 || node.value().equals("S")) {
                break;
            }
        }
    }
    // mention containment features
    addFeature(features, "contained-in-other-mention", mentionsByHeadIndex.get(m.headIndex).stream().anyMatch(m2 -> m != m2 && m.insideIn(m2)));
    addFeature(features, "contains-other-mention", mentionsByHeadIndex.get(m.headIndex).stream().anyMatch(m2 -> m != m2 && m2.insideIn(m)));
    // features from dcoref rules
    addFeature(features, "bare-plural", m.originalSpan.size() == 1 && headPOS.equals("NNS"));
    addFeature(features, "quantifier-start", dictionaries.quantifiers.contains(firstWord.word().toLowerCase()));
    addFeature(features, "negative-start", firstWord.word().toLowerCase().matches("none|no|nothing|not"));
    addFeature(features, "partitive", RuleBasedCorefMentionFinder.partitiveRule(m, m.sentenceWords, dictionaries));
    addFeature(features, "adjectival-demonym", dictionaries.isAdjectivalDemonym(m.spanToString()));
    if (doc.docType != DocType.ARTICLE && m.person == Person.YOU && nextWord != null && nextWord.word().equalsIgnoreCase("know")) {
        features.incrementCount("generic-you");
    }
    return features;
}
Also used: SpeakerAnnotation (edu.stanford.nlp.ling.CoreAnnotations.SpeakerAnnotation), Tree (edu.stanford.nlp.trees.Tree), HashMap (java.util.HashMap), Random (java.util.Random), Dictionaries (edu.stanford.nlp.coref.data.Dictionaries), ArrayList (java.util.ArrayList), HashSet (java.util.HashSet), Number (edu.stanford.nlp.coref.data.Dictionaries.Number), CorefCluster (edu.stanford.nlp.coref.data.CorefCluster), Mention (edu.stanford.nlp.coref.data.Mention), RuleBasedCorefMentionFinder (edu.stanford.nlp.coref.md.RuleBasedCorefMentionFinder), Counter (edu.stanford.nlp.stats.Counter), Map (java.util.Map), Pair (edu.stanford.nlp.util.Pair), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), CorefRules (edu.stanford.nlp.coref.CorefRules), IndexedWord (edu.stanford.nlp.ling.IndexedWord), CoreLabel (edu.stanford.nlp.ling.CoreLabel), Properties (java.util.Properties), CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations), SemanticGraphEdge (edu.stanford.nlp.semgraph.SemanticGraphEdge), Iterator (java.util.Iterator), IOUtils (edu.stanford.nlp.io.IOUtils), DocType (edu.stanford.nlp.coref.data.Document.DocType), Set (java.util.Set), Person (edu.stanford.nlp.coref.data.Dictionaries.Person), List (java.util.List), MentionType (edu.stanford.nlp.coref.data.Dictionaries.MentionType), StringUtils (edu.stanford.nlp.util.StringUtils), CorefProperties (edu.stanford.nlp.coref.CorefProperties), Document (edu.stanford.nlp.coref.data.Document), CorefUtils (edu.stanford.nlp.coref.CorefUtils)
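For reference, the Counter idiom this method is built on is small: incrementCount with one argument adds an indicator feature with weight 1.0, and the two-argument form adds a real-valued feature, as with mention-position above. A minimal, self-contained sketch (the feature names are hypothetical; it assumes only that CoreNLP is on the classpath):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class CounterFeatureSketch {
    public static void main(String[] args) {
        // ClassicCounter is a HashMap-backed Counter; absent keys count as 0.0.
        Counter<String> features = new ClassicCounter<>();
        // Indicator feature: increment by 1.0.
        features.incrementCount("mention-type=NOMINAL");
        // Real-valued feature: increment by an explicit weight.
        features.incrementCount("mention-position", 3 / 10.0);
        System.out.println(features.getCount("mention-type=NOMINAL")); // 1.0
        System.out.println(features.getCount("unseen-feature"));       // 0.0
        System.out.println(features.totalCount());                     // 1.3
    }
}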

Example 2 with Counter

Use of edu.stanford.nlp.stats.Counter in the CoreNLP project by stanfordnlp.

From the class PairwiseModelTrainer, method test:

public static void test(PairwiseModel model, String predictionsName, boolean anaphoricityModel) throws Exception {
    Redwood.log("scoref-train", "Reading compression...");
    Compressor<String> compressor = IOUtils.readObjectFromFile(StatisticalCorefTrainer.compressorFile);
    Redwood.log("scoref-train", "Reading test data...");
    List<DocumentExamples> testDocuments = IOUtils.readObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);
    Redwood.log("scoref-train", "Building test set...");
    List<Pair<Example, Map<Integer, CompressedFeatureVector>>> allExamples = anaphoricityModel ? getAnaphoricityExamples(testDocuments) : getExamples(testDocuments);
    Redwood.log("scoref-train", "Testing...");
    PrintWriter writer = new PrintWriter(model.getDefaultOutputPath() + predictionsName);
    Map<Integer, Counter<Pair<Integer, Integer>>> scores = new HashMap<>();
    writeScores(allExamples, compressor, model, writer, scores);
    if (model instanceof MaxMarginMentionRanker) {
        writer.close();
        writer = new PrintWriter(model.getDefaultOutputPath() + predictionsName + "_anaphoricity");
        testDocuments = IOUtils.readObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);
        allExamples = getAnaphoricityExamples(testDocuments);
        writeScores(allExamples, compressor, model, writer, scores);
    }
    IOUtils.writeObjectToFile(scores, model.getDefaultOutputPath() + predictionsName + ".ser");
    writer.close();
}
Also used: HashMap (java.util.HashMap), Counter (edu.stanford.nlp.stats.Counter), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), Pair (edu.stanford.nlp.util.Pair), PrintWriter (java.io.PrintWriter)
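The scores object written out here is an ordinary serializable map of per-document counters keyed by mention-index pairs. A rough sketch of that round trip, using only the IOUtils helpers the snippet itself calls; the document id, mention indices, score, and file name are made up:

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import java.util.HashMap;
import java.util.Map;

public class ScoreSerializationSketch {
    public static void main(String[] args) throws Exception {
        // One counter of pairwise scores per document, keyed by (mention i, mention j).
        Map<Integer, Counter<Pair<Integer, Integer>>> scores = new HashMap<>();
        Counter<Pair<Integer, Integer>> docScores = new ClassicCounter<>();
        docScores.setCount(new Pair<>(0, 3), 0.87); // hypothetical pair score
        scores.put(42, docScores);                  // hypothetical document id
        // Round-trip through Java serialization, as test(...) does with predictionsName + ".ser".
        IOUtils.writeObjectToFile(scores, "predictions.ser");
        Map<Integer, Counter<Pair<Integer, Integer>>> loaded =
                IOUtils.readObjectFromFile("predictions.ser");
        System.out.println(loaded.get(42).getCount(new Pair<>(0, 3))); // 0.87
    }
}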

Example 3 with Counter

Use of edu.stanford.nlp.stats.Counter in the CoreNLP project by stanfordnlp.

From the class ScorePhrases, method learnNewPhrasesPrivate:

private Counter<CandidatePhrase> learnNewPhrasesPrivate(String label, PatternsForEachToken patternsForEachToken, Counter<E> patternsLearnedThisIter, Counter<E> allSelectedPatterns, Set<CandidatePhrase> alreadyIdentifiedWords, CollectionValuedMap<E, Triple<String, Integer, Integer>> matchedTokensByPat, Counter<CandidatePhrase> scoreForAllWordsThisIteration, TwoDimensionalCounter<CandidatePhrase, E> terms, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, TwoDimensionalCounter<E, CandidatePhrase> patternsAndWords4Label, String identifier, Set<CandidatePhrase> ignoreWords, boolean computeProcDataFreq) throws IOException, ClassNotFoundException {
    Set<CandidatePhrase> alreadyLabeledWords = new HashSet<>();
    if (constVars.doNotApplyPatterns) {
        // if we only want the stats the lossy way: count matches without
        // actually applying the patterns
        ConstantsAndVariables.DataSentsIterator sentsIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
        while (sentsIter.hasNext()) {
            Pair<Map<String, DataInstance>, File> sentsf = sentsIter.next();
            this.statsWithoutApplyingPatterns(sentsf.first(), patternsForEachToken, patternsLearnedThisIter, wordsPatExtracted);
        }
    } else {
        if (patternsLearnedThisIter.size() > 0) {
            this.applyPats(patternsLearnedThisIter, label, wordsPatExtracted, matchedTokensByPat, alreadyLabeledWords);
        }
    }
    if (computeProcDataFreq) {
        if (!phraseScorer.wordFreqNorm.equals(Normalization.NONE)) {
            Redwood.log(Redwood.DBG, "computing processed freq");
            for (Entry<CandidatePhrase, Double> fq : Data.rawFreq.entrySet()) {
                Double in = fq.getValue();
                if (phraseScorer.wordFreqNorm.equals(Normalization.SQRT))
                    in = Math.sqrt(in);
                else if (phraseScorer.wordFreqNorm.equals(Normalization.LOG))
                    in = 1 + Math.log(in);
                else
                    throw new RuntimeException("can't understand the normalization");
                assert !in.isNaN() : "Why is processed freq NaN when raw freq is " + fq.getValue();
                Data.processedDataFreq.setCount(fq.getKey(), in);
            }
        } else
            Data.processedDataFreq = Data.rawFreq;
    }
    if (constVars.wordScoring.equals(WordScoring.WEIGHTEDNORM)) {
        for (CandidatePhrase en : wordsPatExtracted.firstKeySet()) {
            if (!constVars.getOtherSemanticClassesWords().contains(en) && (en.getPhraseLemma() == null || !constVars.getOtherSemanticClassesWords().contains(CandidatePhrase.createOrGet(en.getPhraseLemma()))) && !alreadyLabeledWords.contains(en)) {
                terms.addAll(en, wordsPatExtracted.getCounter(en));
            }
        }
        removeKeys(terms, constVars.getStopWords());
        Counter<CandidatePhrase> phraseScores = phraseScorer.scorePhrases(label, terms, wordsPatExtracted, allSelectedPatterns, alreadyIdentifiedWords, false);
        System.out.println("count for word U.S. is " + phraseScores.getCount(CandidatePhrase.createOrGet("U.S.")));
        Set<CandidatePhrase> ignoreWordsAll;
        if (ignoreWords != null && !ignoreWords.isEmpty()) {
            ignoreWordsAll = CollectionUtils.unionAsSet(ignoreWords, constVars.getOtherSemanticClassesWords());
        } else
            ignoreWordsAll = new HashSet<>(constVars.getOtherSemanticClassesWords());
        ignoreWordsAll.addAll(constVars.getSeedLabelDictionary().get(label));
        ignoreWordsAll.addAll(constVars.getLearnedWords(label).keySet());
        System.out.println("ignoreWordsAll contains word U.S. is " + ignoreWordsAll.contains(CandidatePhrase.createOrGet("U.S.")));
        Counter<CandidatePhrase> finalwords = chooseTopWords(phraseScores, terms, phraseScores, ignoreWordsAll, constVars.thresholdWordExtract);
        phraseScorer.printReasonForChoosing(finalwords);
        scoreForAllWordsThisIteration.clear();
        Counters.addInPlace(scoreForAllWordsThisIteration, phraseScores);
        Redwood.log(ConstantsAndVariables.minimaldebug, "\n\n## Selected Words for " + label + " : " + Counters.toSortedString(finalwords, finalwords.size(), "%1$s:%2$.2f", "\t"));
        if (constVars.goldEntities != null) {
            Map<String, Boolean> goldEntities4Label = constVars.goldEntities.get(label);
            if (goldEntities4Label != null) {
                StringBuffer s = new StringBuffer();
                finalwords.keySet().stream().forEach(x -> s.append(x.getPhrase() + (goldEntities4Label.containsKey(x.getPhrase()) ? ":" + goldEntities4Label.get(x.getPhrase()) : ":UNKNOWN") + "\n"));
                Redwood.log(ConstantsAndVariables.minimaldebug, "\n\n## Gold labels for selected words for label " + label + " : " + s.toString());
            } else
                Redwood.log(Redwood.DBG, "No gold entities provided for label " + label);
        }
        if (constVars.outDir != null && !constVars.outDir.isEmpty()) {
            String outputdir = constVars.outDir + "/" + identifier + "/" + label;
            IOUtils.ensureDir(new File(outputdir));
            TwoDimensionalCounter<CandidatePhrase, CandidatePhrase> reasonForWords = new TwoDimensionalCounter<>();
            for (CandidatePhrase word : finalwords.keySet()) {
                for (E l : wordsPatExtracted.getCounter(word).keySet()) {
                    for (CandidatePhrase w2 : patternsAndWords4Label.getCounter(l)) {
                        reasonForWords.incrementCount(word, w2);
                    }
                }
            }
            Redwood.log(ConstantsAndVariables.minimaldebug, "Saving output in " + outputdir);
            String filename = outputdir + "/words.json";
            // the json object is an array with one entry per iteration: a list
            // of objects, each a bean of the entity and the reasons for it
            JsonArrayBuilder obj = Json.createArrayBuilder();
            if (writtenInJustification.containsKey(label) && writtenInJustification.get(label)) {
                JsonReader jsonReader = Json.createReader(new BufferedInputStream(new FileInputStream(filename)));
                JsonArray objarr = jsonReader.readArray();
                for (JsonValue o : objarr) obj.add(o);
                jsonReader.close();
            }
            JsonArrayBuilder objThisIter = Json.createArrayBuilder();
            for (CandidatePhrase w : reasonForWords.firstKeySet()) {
                JsonObjectBuilder objinner = Json.createObjectBuilder();
                JsonArrayBuilder l = Json.createArrayBuilder();
                for (CandidatePhrase w2 : reasonForWords.getCounter(w).keySet()) {
                    l.add(w2.getPhrase());
                }
                JsonArrayBuilder pats = Json.createArrayBuilder();
                for (E p : wordsPatExtracted.getCounter(w)) {
                    pats.add(p.toStringSimple());
                }
                objinner.add("reasonwords", l);
                objinner.add("patterns", pats);
                objinner.add("score", finalwords.getCount(w));
                objinner.add("entity", w.getPhrase());
                objThisIter.add(objinner.build());
            }
            obj.add(objThisIter);
            // Redwood.log(ConstantsAndVariables.minimaldebug, channelNameLogger,
            // "Writing justification at " + filename);
            IOUtils.writeStringToFile(StringUtils.normalize(StringUtils.toAscii(obj.build().toString())), filename, "ASCII");
            writtenInJustification.put(label, true);
        }
        if (constVars.justify) {
            Redwood.log(Redwood.DBG, "\nJustification for phrases:\n");
            for (CandidatePhrase word : finalwords.keySet()) {
                Redwood.log(Redwood.DBG, "Phrase " + word + " extracted because of patterns: \t" + Counters.toSortedString(wordsPatExtracted.getCounter(word), wordsPatExtracted.getCounter(word).size(), "%1$s:%2$f", "\n"));
            }
        }
        return finalwords;
    } else if (constVars.wordScoring.equals(WordScoring.BPB)) {
        Counters.addInPlace(terms, wordsPatExtracted);
        Counter<CandidatePhrase> maxPatWeightTerms = new ClassicCounter<>();
        Map<CandidatePhrase, E> wordMaxPat = new HashMap<>();
        for (Entry<CandidatePhrase, ClassicCounter<E>> en : terms.entrySet()) {
            Counter<E> weights = new ClassicCounter<>();
            for (E k : en.getValue().keySet()) weights.setCount(k, patternsLearnedThisIter.getCount(k));
            maxPatWeightTerms.setCount(en.getKey(), Counters.max(weights));
            wordMaxPat.put(en.getKey(), Counters.argmax(weights));
        }
        Counters.removeKeys(maxPatWeightTerms, alreadyIdentifiedWords);
        double maxvalue = Counters.max(maxPatWeightTerms);
        Set<CandidatePhrase> words = Counters.keysAbove(maxPatWeightTerms, maxvalue - 1e-10);
        CandidatePhrase bestw = null;
        if (words.size() > 1) {
            double max = Double.NEGATIVE_INFINITY;
            for (CandidatePhrase w : words) {
                if (terms.getCount(w, wordMaxPat.get(w)) > max) {
                    max = terms.getCount(w, wordMaxPat.get(w));
                    bestw = w;
                }
            }
        } else if (words.size() == 1)
            bestw = words.iterator().next();
        else
            return new ClassicCounter<>();
        Redwood.log(ConstantsAndVariables.minimaldebug, "Selected Words: " + bestw);
        return Counters.asCounter(Arrays.asList(bestw));
    } else
        throw new RuntimeException("wordscoring " + constVars.wordScoring + " not identified");
}
Also used: Entry (java.util.Map.Entry), Counter (edu.stanford.nlp.stats.Counter), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), TwoDimensionalCounter (edu.stanford.nlp.stats.TwoDimensionalCounter), BufferedInputStream (java.io.BufferedInputStream), JsonReader (javax.json.JsonReader), JsonArrayBuilder (javax.json.JsonArrayBuilder), JsonObjectBuilder (javax.json.JsonObjectBuilder), JsonValue (javax.json.JsonValue), FileInputStream (java.io.FileInputStream), JsonArray (javax.json.JsonArray), File (java.io.File)
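The BPB branch leans on static helpers from edu.stanford.nlp.stats.Counters. A small sketch of the ones it calls (max, argmax, keysAbove, asCounter), with toy phrases and weights:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import java.util.Arrays;
import java.util.Set;

public class CountersHelpersSketch {
    public static void main(String[] args) {
        Counter<String> weights = new ClassicCounter<>();
        weights.setCount("gold", 0.9);
        weights.setCount("silver", 0.9);
        weights.setCount("bronze", 0.4);
        // Largest count, and one key that attains it.
        System.out.println(Counters.max(weights));    // 0.9
        System.out.println(Counters.argmax(weights)); // gold or silver
        // All keys within epsilon of the max, mirroring keysAbove(maxPatWeightTerms, maxvalue - 1e-10).
        Set<String> best = Counters.keysAbove(weights, Counters.max(weights) - 1e-10);
        System.out.println(best); // [gold, silver] in some order
        // asCounter turns a collection into a counter of occurrences.
        Counter<String> winner = Counters.asCounter(Arrays.asList("gold"));
        System.out.println(winner.getCount("gold")); // 1.0
    }
}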

Example 4 with Counter

Use of edu.stanford.nlp.stats.Counter in the CoreNLP project by stanfordnlp.

From the class GetPatternsFromDataMultiClass, method setUpConstructor:

@SuppressWarnings("rawtypes")
private void setUpConstructor(Map<String, DataInstance> sents, Map<String, Set<CandidatePhrase>> seedSets, boolean labelUsingSeedSets, Map<String, Class<? extends TypesafeMap.Key<String>>> answerClass, Map<String, Class> generalizeClasses, Map<String, Map<Class, Object>> ignoreClasses) throws IOException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException, InterruptedException, ExecutionException, ClassNotFoundException {
    Data.sents = sents;
    ArgumentParser.fillOptions(Data.class, props);
    ArgumentParser.fillOptions(ConstantsAndVariables.class, props);
    PatternFactory.setUp(props, PatternFactory.PatternType.valueOf(props.getProperty(Flags.patternType)), seedSets.keySet());
    constVars = new ConstantsAndVariables(props, seedSets, answerClass, generalizeClasses, ignoreClasses);
    if (constVars.writeMatchedTokensFiles && constVars.batchProcessSents) {
        throw new RuntimeException("writeMatchedTokensFiles and batchProcessSents cannot be true at the same time (not implemented; also doesn't make sense to save a large sentences json file)");
    }
    if (constVars.debug < 1) {
        Redwood.hideChannelsEverywhere(ConstantsAndVariables.minimaldebug);
    }
    if (constVars.debug < 2) {
        Redwood.hideChannelsEverywhere(Redwood.DBG);
    }
    constVars.justify = true;
    if (constVars.debug < 3) {
        constVars.justify = false;
    }
    if (constVars.debug < 4) {
        Redwood.hideChannelsEverywhere(ConstantsAndVariables.extremedebug);
    }
    Redwood.log(Redwood.DBG, "Running with debug output");
    Redwood.log(ConstantsAndVariables.extremedebug, "Running with extreme debug output");
    wordsPatExtracted = new HashMap<>();
    for (String label : answerClass.keySet()) {
        wordsPatExtracted.put(label, new TwoDimensionalCounter<>());
    }
    scorePhrases = new ScorePhrases(props, constVars);
    createPats = new CreatePatterns(props, constVars);
    assert !(constVars.doNotApplyPatterns && (PatternFactory.useStopWordsBeforeTerm || PatternFactory.numWordsCompoundMax > 1)) : " Cannot have both doNotApplyPatterns and (useStopWordsBeforeTerm true or numWordsCompound > 1)!";
    if (constVars.invertedIndexDirectory == null) {
        File f = File.createTempFile("inv", "index");
        f.deleteOnExit();
        f.mkdir();
        constVars.invertedIndexDirectory = f.getAbsolutePath();
    }
    Set<String> extremelySmallStopWordsList = CollectionUtils.asSet(".", ",", "in", "on", "of", "a", "the", "an");
    // Function specifying how to add CoreLabels to the index
    Function<CoreLabel, Map<String, String>> transformCoreLabelToString = l -> {
        Map<String, String> add = new HashMap<>();
        for (Class gn : constVars.getGeneralizeClasses().values()) {
            Object b = l.get(gn);
            if (b != null && !b.toString().equals(constVars.backgroundSymbol)) {
                add.put(Token.getKeyForClass(gn), b.toString());
            }
        }
        return add;
    };
    boolean createIndex = false;
    if (constVars.loadInvertedIndex)
        constVars.invertedIndex = SentenceIndex.loadIndex(constVars.invertedIndexClass, props, extremelySmallStopWordsList, constVars.invertedIndexDirectory, transformCoreLabelToString);
    else {
        constVars.invertedIndex = SentenceIndex.createIndex(constVars.invertedIndexClass, null, props, extremelySmallStopWordsList, constVars.invertedIndexDirectory, transformCoreLabelToString);
        createIndex = true;
    }
    int totalNumSents = 0;
    boolean computeDataFreq = false;
    if (Data.rawFreq == null) {
        Data.rawFreq = new ClassicCounter<>();
        computeDataFreq = true;
    }
    ConstantsAndVariables.DataSentsIterator iter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
    while (iter.hasNext()) {
        Pair<Map<String, DataInstance>, File> sentsIter = iter.next();
        Map<String, DataInstance> sentsf = sentsIter.first();
        if (constVars.batchProcessSents) {
            for (Entry<String, DataInstance> en : sentsf.entrySet()) {
                Data.sentId2File.put(en.getKey(), sentsIter.second());
            }
        }
        totalNumSents += sentsf.size();
        if (computeDataFreq) {
            Data.computeRawFreqIfNull(sentsf, PatternFactory.numWordsCompoundMax);
        }
        Redwood.log(Redwood.DBG, "Initializing sents size " + sentsf.size() + " sentences, either by labeling with the seed set or just setting the right classes");
        for (String l : constVars.getAnswerClass().keySet()) {
            Redwood.log(Redwood.DBG, "labelUsingSeedSets is " + labelUsingSeedSets + " and seed set size for " + l + " is " + (seedSets == null ? "null" : seedSets.get(l).size()));
            Set<CandidatePhrase> seed = seedSets == null || !labelUsingSeedSets ? new HashSet<>() : (seedSets.containsKey(l) ? seedSets.get(l) : new HashSet<>());
            if (!matchedSeedWords.containsKey(l)) {
                matchedSeedWords.put(l, new ClassicCounter<>());
            }
            Counter<CandidatePhrase> matched = runLabelSeedWords(sentsf, constVars.getAnswerClass().get(l), l, seed, constVars, labelUsingSeedSets);
            System.out.println("matched phrases for " + l + " is " + matched);
            matchedSeedWords.get(l).addAll(matched);
            if (constVars.addIndvWordsFromPhrasesExceptLastAsNeg) {
                Redwood.log(ConstantsAndVariables.minimaldebug, "adding indv words from phrases except last as neg");
                Set<CandidatePhrase> otherseed = new HashSet<>();
                if (labelUsingSeedSets) {
                    for (CandidatePhrase s : seed) {
                        String[] t = s.getPhrase().split("\\s+");
                        for (int i = 0; i < t.length - 1; i++) {
                            if (!seed.contains(t[i])) {
                                otherseed.add(CandidatePhrase.createOrGet(t[i]));
                            }
                        }
                    }
                }
                runLabelSeedWords(sentsf, PatternsAnnotations.OtherSemanticLabel.class, "OTHERSEM", otherseed, constVars, labelUsingSeedSets);
            }
        }
        if (labelUsingSeedSets && constVars.getOtherSemanticClassesWords() != null) {
            String l = "OTHERSEM";
            if (!matchedSeedWords.containsKey(l)) {
                matchedSeedWords.put(l, new ClassicCounter<>());
            }
            matchedSeedWords.get(l).addAll(runLabelSeedWords(sentsf, PatternsAnnotations.OtherSemanticLabel.class, l, constVars.getOtherSemanticClassesWords(), constVars, labelUsingSeedSets));
        }
        if (constVars.removeOverLappingLabelsFromSeed) {
            removeOverLappingLabels(sentsf);
        }
        if (createIndex)
            constVars.invertedIndex.add(sentsf, true);
        if (sentsIter.second().exists()) {
            Redwood.log(Redwood.DBG, "Saving the labeled seed sents (if given the option) to the same file " + sentsIter.second());
            IOUtils.writeObjectToFile(sentsf, sentsIter.second());
        }
    }
    Redwood.log(Redwood.DBG, "Done loading/creating inverted index of tokens and labeling data with total of " + constVars.invertedIndex.size() + " sentences");
    // If the scorer class is ScorePhrasesAverageFeatures, the individual word class is added as a feature
    if (scorePhrases.phraseScorerClass.equals(ScorePhrasesAverageFeatures.class) && (constVars.usePatternEvalWordClass || constVars.usePhraseEvalWordClass)) {
        if (constVars.externalFeatureWeightsDir == null) {
            File f = File.createTempFile("tempfeat", ".txt");
            f.delete();
            f.deleteOnExit();
            constVars.externalFeatureWeightsDir = f.getAbsolutePath();
        }
        IOUtils.ensureDir(new File(constVars.externalFeatureWeightsDir));
        for (String label : seedSets.keySet()) {
            String externalFeatureWeightsFileLabel = constVars.externalFeatureWeightsDir + "/" + label;
            File f = new File(externalFeatureWeightsFileLabel);
            if (!f.exists()) {
                Redwood.log(Redwood.DBG, "externalweightsfile for the label " + label + " does not exist: learning weights!");
                LearnImportantFeatures lmf = new LearnImportantFeatures();
                ArgumentParser.fillOptions(lmf, props);
                lmf.answerClass = answerClass.get(label);
                lmf.answerLabel = label;
                lmf.setUp();
                lmf.getTopFeatures(new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents), constVars.perSelectRand, constVars.perSelectNeg, externalFeatureWeightsFileLabel);
            }
            Counter<Integer> distSimWeightsLabel = new ClassicCounter<>();
            for (String line : IOUtils.readLines(externalFeatureWeightsFileLabel)) {
                String[] t = line.split(":");
                if (!t[0].startsWith("Cluster"))
                    continue;
                String s = t[0].replace("Cluster-", "");
                Integer clusterNum = Integer.parseInt(s);
                distSimWeightsLabel.setCount(clusterNum, Double.parseDouble(t[1]));
            }
            constVars.distSimWeights.put(label, distSimWeightsLabel);
        }
    }
    // computing semantic odds values
    if (constVars.usePatternEvalSemanticOdds || constVars.usePhraseEvalSemanticOdds) {
        Counter<CandidatePhrase> dictOddsWeightsLabel = new ClassicCounter<>();
        Counter<CandidatePhrase> otherSemanticClassFreq = new ClassicCounter<>();
        for (CandidatePhrase s : constVars.getOtherSemanticClassesWords()) {
            for (String s1 : StringUtils.getNgrams(Arrays.asList(s.getPhrase().split("\\s+")), 1, PatternFactory.numWordsCompoundMax)) otherSemanticClassFreq.incrementCount(CandidatePhrase.createOrGet(s1));
        }
        otherSemanticClassFreq = Counters.add(otherSemanticClassFreq, 1.0);
        // otherSemanticClassFreq.setDefaultReturnValue(1.0);
        Map<String, Counter<CandidatePhrase>> labelDictNgram = new HashMap<>();
        for (String label : seedSets.keySet()) {
            Counter<CandidatePhrase> classFreq = new ClassicCounter<>();
            for (CandidatePhrase s : seedSets.get(label)) {
                for (String s1 : StringUtils.getNgrams(Arrays.asList(s.getPhrase().split("\\s+")), 1, PatternFactory.numWordsCompoundMax)) classFreq.incrementCount(CandidatePhrase.createOrGet(s1));
            }
            classFreq = Counters.add(classFreq, 1.0);
            labelDictNgram.put(label, classFreq);
        // classFreq.setDefaultReturnValue(1.0);
        }
        for (String label : seedSets.keySet()) {
            Counter<CandidatePhrase> otherLabelFreq = new ClassicCounter<>();
            for (String label2 : seedSets.keySet()) {
                if (label.equals(label2))
                    continue;
                otherLabelFreq.addAll(labelDictNgram.get(label2));
            }
            otherLabelFreq.addAll(otherSemanticClassFreq);
            dictOddsWeightsLabel = Counters.divisionNonNaN(labelDictNgram.get(label), otherLabelFreq);
            constVars.dictOddsWeights.put(label, dictOddsWeightsLabel);
        }
    }
//Redwood.log(Redwood.DBG, "All options are:" + "\n" + Maps.toString(getAllOptions(), "","","\t","\n"));
}
Also used: ZipOutputStream (java.util.zip.ZipOutputStream), java.util, Key (edu.stanford.nlp.util.TypesafeMap.Key), TreeAnnotation (edu.stanford.nlp.trees.TreeCoreAnnotations.TreeAnnotation), edu.stanford.nlp.util, Tree (edu.stanford.nlp.trees.Tree), edu.stanford.nlp.patterns.surface, IOBUtils (edu.stanford.nlp.sequences.IOBUtils), Constructor (java.lang.reflect.Constructor), Function (java.util.function.Function), SQLException (java.sql.SQLException), Interval (org.joda.time.Interval), Counter (edu.stanford.nlp.stats.Counter), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), StanfordCoreNLP (edu.stanford.nlp.pipeline.StanfordCoreNLP), SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), ZipEntry (java.util.zip.ZipEntry), javax.json, TwoDimensionalCounter (edu.stanford.nlp.stats.TwoDimensionalCounter), IndexedWord (edu.stanford.nlp.ling.IndexedWord), TokenSequencePattern (edu.stanford.nlp.ling.tokensregex.TokenSequencePattern), Period (org.joda.time.Period), CoreLabel (edu.stanford.nlp.ling.CoreLabel), ScorePhraseMeasures (edu.stanford.nlp.patterns.ConstantsAndVariables.ScorePhraseMeasures), CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations), DataInstanceDep (edu.stanford.nlp.patterns.dep.DataInstanceDep), SemanticGraphEdge (edu.stanford.nlp.semgraph.SemanticGraphEdge), GrammaticalRelation (edu.stanford.nlp.trees.GrammaticalRelation), Counters (edu.stanford.nlp.stats.Counters), java.util.concurrent, IOUtils (edu.stanford.nlp.io.IOUtils), Redwood (edu.stanford.nlp.util.logging.Redwood), DecimalFormat (java.text.DecimalFormat), Field (java.lang.reflect.Field), InvocationTargetException (java.lang.reflect.InvocationTargetException), java.io, Annotation (edu.stanford.nlp.pipeline.Annotation), Entry (java.util.Map.Entry), Env (edu.stanford.nlp.ling.tokensregex.Env), RegExFileFilter (edu.stanford.nlp.io.RegExFileFilter), PriorityQueue (edu.stanford.nlp.util.PriorityQueue), GoldAnswerAnnotation (edu.stanford.nlp.ling.CoreAnnotations.GoldAnswerAnnotation)
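The semantic-odds computation at the end is add-one smoothing followed by element-wise division. A compact sketch of those two steps, with toy n-gram counts and plain strings standing in for CandidatePhrase:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;

public class DictOddsSketch {
    public static void main(String[] args) {
        Counter<String> labelFreq = new ClassicCounter<>();
        labelFreq.setCount("acme", 4.0);
        labelFreq.setCount("corp", 2.0);
        Counter<String> otherFreq = new ClassicCounter<>();
        otherFreq.setCount("acme", 1.0);
        otherFreq.setCount("corp", 3.0);
        // Counters.add returns a new counter with the constant added to every
        // key, which is the add-one smoothing step above.
        otherFreq = Counters.add(otherFreq, 1.0); // acme -> 2.0, corp -> 4.0
        // Element-wise division that skips entries which would produce NaN.
        Counter<String> odds = Counters.divisionNonNaN(labelFreq, otherFreq);
        System.out.println(odds.getCount("acme")); // 4.0 / 2.0 = 2.0
        System.out.println(odds.getCount("corp")); // 2.0 / 4.0 = 0.5
    }
}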

Example 5 with Counter

Use of edu.stanford.nlp.stats.Counter in the CoreNLP project by stanfordnlp.

From the class ApplyDepPatterns, method getMatchedTokensIndex:

private Collection<ExtractedPhrase> getMatchedTokensIndex(SemanticGraph graph, SemgrexPattern pattern, DataInstance sent, String label) {
    //TODO: look at the ignoreCommonTags flag
    ExtractPhraseFromPattern extract = new ExtractPhraseFromPattern(false, PatternFactory.numWordsCompoundMapped.get(label));
    Collection<IntPair> outputIndices = new ArrayList<>();
    boolean findSubTrees = true;
    List<CoreLabel> tokensC = sent.getTokens();
    //TODO: see if you can get rid of this (only used for matchedGraphs)
    List<String> tokens = tokensC.stream().map(x -> x.word()).collect(Collectors.toList());
    List<String> outputPhrases = new ArrayList<>();
    List<ExtractedPhrase> extractedPhrases = new ArrayList<>();
    Function<Pair<IndexedWord, SemanticGraph>, Counter<String>> extractFeatures = new Function<Pair<IndexedWord, SemanticGraph>, Counter<String>>() {

        @Override
        public Counter<String> apply(Pair<IndexedWord, SemanticGraph> indexedWordSemanticGraphPair) {
            //TODO: make features;
            Counter<String> feat = new ClassicCounter<>();
            IndexedWord vertex = indexedWordSemanticGraphPair.first();
            SemanticGraph graph = indexedWordSemanticGraphPair.second();
            List<Pair<GrammaticalRelation, IndexedWord>> pt = graph.parentPairs(vertex);
            for (Pair<GrammaticalRelation, IndexedWord> en : pt) {
                feat.incrementCount("PARENTREL-" + en.first());
            }
            return feat;
        }
    };
    extract.getSemGrexPatternNodes(graph, tokens, outputPhrases, outputIndices, pattern, findSubTrees, extractedPhrases, constVars.matchLowerCaseContext, matchingWordRestriction);
    //System.out.println("extracted phrases are " + extractedPhrases + " and output indices are " + outputIndices);
    return extractedPhrases;
}
Also used: CoreLabel (edu.stanford.nlp.ling.CoreLabel), java.util, CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations), GrammaticalRelation (edu.stanford.nlp.trees.GrammaticalRelation), SurfacePattern (edu.stanford.nlp.patterns.surface.SurfacePattern), edu.stanford.nlp.util, Callable (java.util.concurrent.Callable), SemgrexMatcher (edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher), Function (java.util.function.Function), Collectors (java.util.stream.Collectors), Counter (edu.stanford.nlp.stats.Counter), edu.stanford.nlp.patterns, TokenSequenceMatcher (edu.stanford.nlp.ling.tokensregex.TokenSequenceMatcher), SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph), SemgrexPattern (edu.stanford.nlp.semgraph.semgrex.SemgrexPattern), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), TwoDimensionalCounter (edu.stanford.nlp.stats.TwoDimensionalCounter), IndexedWord (edu.stanford.nlp.ling.IndexedWord), TokenSequencePattern (edu.stanford.nlp.ling.tokensregex.TokenSequencePattern)
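The extractFeatures function maps a (word, graph) pair to a Counter with one indicator feature per grammatical relation linking the word to a parent. A stripped-down sketch of the same pattern on a toy dependency graph; the sentence and the use of SemanticGraph.valueOf are assumptions for illustration. (A lambda works here; the snippet above needs an anonymous class because its apply body redeclares graph, which a lambda could not shadow.)

import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.util.Pair;
import java.util.function.Function;

public class ParentRelFeatureSketch {
    public static void main(String[] args) {
        // Toy graph in SemanticGraph's compact bracket notation (assumed input;
        // real code would take the graph from a parsed sentence).
        SemanticGraph graph = SemanticGraph.valueOf("[ate nsubj>Bill obj>muffins]");
        // One feature per parent relation of the given word, as in the snippet.
        Function<Pair<IndexedWord, SemanticGraph>, Counter<String>> extractFeatures = pair -> {
            Counter<String> feat = new ClassicCounter<>();
            for (Pair<GrammaticalRelation, IndexedWord> en : pair.second().parentPairs(pair.first())) {
                feat.incrementCount("PARENTREL-" + en.first());
            }
            return feat;
        };
        for (IndexedWord w : graph.vertexSet()) {
            System.out.println(w.word() + " -> " + extractFeatures.apply(new Pair<>(w, graph)));
        }
    }
}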

Aggregations

ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 13 uses
Counter (edu.stanford.nlp.stats.Counter): 13 uses
CoreLabel (edu.stanford.nlp.ling.CoreLabel): 7 uses
IOUtils (edu.stanford.nlp.io.IOUtils): 6 uses
CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 6 uses
edu.stanford.nlp.util: 6 uses
java.util: 6 uses
Redwood (edu.stanford.nlp.util.logging.Redwood): 5 uses
edu.stanford.nlp.classify: 4 uses
RuntimeIOException (edu.stanford.nlp.io.RuntimeIOException): 4 uses
SemanticGraph (edu.stanford.nlp.semgraph.SemanticGraph): 4 uses
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 4 uses
Function (java.util.function.Function): 4 uses
IndexedWord (edu.stanford.nlp.ling.IndexedWord): 3 uses
RVFDatum (edu.stanford.nlp.ling.RVFDatum): 3 uses
TokenSequencePattern (edu.stanford.nlp.ling.tokensregex.TokenSequencePattern): 3 uses
Counters (edu.stanford.nlp.stats.Counters): 3 uses
TwoDimensionalCounter (edu.stanford.nlp.stats.TwoDimensionalCounter): 3 uses
Util (edu.stanford.nlp.util.logging.Redwood.Util): 3 uses
java.io: 3 uses