Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.
The class FeatureExtractor, method getFeatures.
private Counter<String> getFeatures(Document doc, Mention m, Map<Integer, List<Mention>> mentionsByHeadIndex) {
  Counter<String> features = new ClassicCounter<>();
  // type features
  features.incrementCount("mention-type=" + m.mentionType);
  features.incrementCount("gender=" + m.gender);
  features.incrementCount("person-fine=" + m.person);
  features.incrementCount("head-ne-type=" + m.nerString);
  List<String> singletonFeatures = m.getSingletonFeatures(dictionaries);
  for (Map.Entry<Integer, String> e : SINGLETON_FEATURES.entrySet()) {
    if (e.getKey() < singletonFeatures.size()) {
      features.incrementCount(e.getValue() + "=" + singletonFeatures.get(e.getKey()));
    }
  }
  // length and location features
  addNumeric(features, "mention-length", m.spanToString().length());
  addNumeric(features, "mention-words", m.originalSpan.size());
  addNumeric(features, "sentence-words", m.sentenceWords.size());
  features.incrementCount("sentence-words=" + bin(m.sentenceWords.size()));
  features.incrementCount("mention-position", m.mentionNum / (double) doc.predictedMentions.size());
  features.incrementCount("sentence-position", m.sentNum / (double) doc.numSentences);
  // lexical features
  CoreLabel firstWord = firstWord(m);
  CoreLabel lastWord = lastWord(m);
  CoreLabel headWord = headWord(m);
  CoreLabel prevWord = prevWord(m);
  CoreLabel nextWord = nextWord(m);
  CoreLabel prevprevWord = prevprevWord(m);
  CoreLabel nextnextWord = nextnextWord(m);
  String headPOS = getPOS(headWord);
  String firstPOS = getPOS(firstWord);
  String lastPOS = getPOS(lastWord);
  String prevPOS = getPOS(prevWord);
  String nextPOS = getPOS(nextWord);
  String prevprevPOS = getPOS(prevprevWord);
  String nextnextPOS = getPOS(nextnextWord);
  features.incrementCount("first-word=" + wordIndicator(firstWord, firstPOS));
  features.incrementCount("last-word=" + wordIndicator(lastWord, lastPOS));
  features.incrementCount("head-word=" + wordIndicator(headWord, headPOS));
  features.incrementCount("next-word=" + wordIndicator(nextWord, nextPOS));
  features.incrementCount("prev-word=" + wordIndicator(prevWord, prevPOS));
  features.incrementCount("next-bigram=" + wordIndicator(nextWord, nextnextWord, nextPOS + "_" + nextnextPOS));
  features.incrementCount("prev-bigram=" + wordIndicator(prevprevWord, prevWord, prevprevPOS + "_" + prevPOS));
  features.incrementCount("next-pos=" + nextPOS);
  features.incrementCount("prev-pos=" + prevPOS);
  features.incrementCount("first-pos=" + firstPOS);
  features.incrementCount("last-pos=" + lastPOS);
  features.incrementCount("next-pos-bigram=" + nextPOS + "_" + nextnextPOS);
  features.incrementCount("prev-pos-bigram=" + prevprevPOS + "_" + prevPOS);
  addDependencyFeatures(features, "parent", getDependencyParent(m), true);
  addFeature(features, "ends-with-head", m.headIndex == m.endIndex - 1);
  addFeature(features, "is-generic", m.originalSpan.size() == 1 && firstPOS.equals("NNS"));
  // syntax features
  IndexedWord w = m.headIndexedWord;
  String depPath = "";
  int depth = 0;
  while (w != null) {
    SemanticGraphEdge e = getDependencyParent(m, w);
    depth++;
    if (depth <= 3 && e != null) {
      depPath += (depPath.isEmpty() ? "" : "_") + e.getRelation().toString();
      features.incrementCount("dep-path=" + depPath);
      w = e.getSource();
    } else {
      w = null;
    }
  }
  if (useConstituencyParse) {
    int fullEmbeddingLevel = headEmbeddingLevel(m.contextParseTree, m.headIndex);
    int mentionEmbeddingLevel = headEmbeddingLevel(m.mentionSubTree, m.headIndex - m.startIndex);
    if (fullEmbeddingLevel != -1 && mentionEmbeddingLevel != -1) {
      features.incrementCount("mention-embedding-level=" + bin(fullEmbeddingLevel - mentionEmbeddingLevel));
      features.incrementCount("head-embedding-level=" + bin(mentionEmbeddingLevel));
    } else {
      features.incrementCount("undetermined-embedding-level");
    }
    features.incrementCount("num-embedded-nps=" + bin(numEmbeddedNps(m.mentionSubTree)));
    String syntaxPath = "";
    Tree tree = m.contextParseTree;
    Tree head = tree.getLeaves().get(m.headIndex).ancestor(1, tree);
    depth = 0;
    for (Tree node : tree.pathNodeToNode(head, tree)) {
      syntaxPath += node.value() + "-";
      features.incrementCount("syntax-path=" + syntaxPath);
      depth++;
      if (depth >= 4 || node.value().equals("S")) {
        break;
      }
    }
  }
  // mention containment features
  addFeature(features, "contained-in-other-mention", mentionsByHeadIndex.get(m.headIndex).stream().anyMatch(m2 -> m != m2 && m.insideIn(m2)));
  addFeature(features, "contains-other-mention", mentionsByHeadIndex.get(m.headIndex).stream().anyMatch(m2 -> m != m2 && m2.insideIn(m)));
  // features from dcoref rules
  addFeature(features, "bare-plural", m.originalSpan.size() == 1 && headPOS.equals("NNS"));
  addFeature(features, "quantifier-start", dictionaries.quantifiers.contains(firstWord.word().toLowerCase()));
  addFeature(features, "negative-start", firstWord.word().toLowerCase().matches("none|no|nothing|not"));
  addFeature(features, "partitive", RuleBasedCorefMentionFinder.partitiveRule(m, m.sentenceWords, dictionaries));
  addFeature(features, "adjectival-demonym", dictionaries.isAdjectivalDemonym(m.spanToString()));
  if (doc.docType != DocType.ARTICLE && m.person == Person.YOU && nextWord != null && nextWord.word().equalsIgnoreCase("know")) {
    features.incrementCount("generic-you");
  }
  return features;
}
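The method above treats a Counter<String> as a sparse feature vector: indicator features are added with the one-argument incrementCount (count defaults to 1.0), while real-valued features pass an explicit weight. A minimal standalone sketch of that pattern, with illustrative feature names rather than CoreNLP's:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;

public class FeatureVectorSketch {
  public static void main(String[] args) {
    Counter<String> features = new ClassicCounter<>();
    // indicator feature: the key encodes name and value, count defaults to 1.0
    features.incrementCount("mention-type=" + "PRONOMINAL");
    // real-valued feature: pass the weight explicitly
    features.incrementCount("sentence-position", 3 / (double) 10);
    System.out.println(features.getCount("mention-type=PRONOMINAL")); // 1.0
    // ClassicCounter returns 0.0 for absent keys, which is what makes it usable as a sparse vector
    System.out.println(features.getCount("unseen-feature")); // 0.0
  }
}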
Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.
The class PairwiseModelTrainer, method test.
public static void test(PairwiseModel model, String predictionsName, boolean anaphoricityModel) throws Exception {
  Redwood.log("scoref-train", "Reading compression...");
  Compressor<String> compressor = IOUtils.readObjectFromFile(StatisticalCorefTrainer.compressorFile);
  Redwood.log("scoref-train", "Reading test data...");
  List<DocumentExamples> testDocuments = IOUtils.readObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);
  Redwood.log("scoref-train", "Building test set...");
  List<Pair<Example, Map<Integer, CompressedFeatureVector>>> allExamples = anaphoricityModel ? getAnaphoricityExamples(testDocuments) : getExamples(testDocuments);
  Redwood.log("scoref-train", "Testing...");
  PrintWriter writer = new PrintWriter(model.getDefaultOutputPath() + predictionsName);
  Map<Integer, Counter<Pair<Integer, Integer>>> scores = new HashMap<>();
  writeScores(allExamples, compressor, model, writer, scores);
  if (model instanceof MaxMarginMentionRanker) {
    writer.close();
    writer = new PrintWriter(model.getDefaultOutputPath() + predictionsName + "_anaphoricity");
    testDocuments = IOUtils.readObjectFromFile(StatisticalCorefTrainer.extractedFeaturesFile);
    allExamples = getAnaphoricityExamples(testDocuments);
    writeScores(allExamples, compressor, model, writer, scores);
  }
  IOUtils.writeObjectToFile(scores, model.getDefaultOutputPath() + predictionsName + ".ser");
  writer.close();
}
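Here the Counter is nested inside a map: one Counter<Pair<Integer, Integer>> of pairwise scores per document id, serialized as a whole at the end. A hedged sketch of that structure (the document id and score below are made up):

import java.util.HashMap;
import java.util.Map;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;

public class ScoresMapSketch {
  public static void main(String[] args) {
    Map<Integer, Counter<Pair<Integer, Integer>>> scores = new HashMap<>();
    int docId = 42; // hypothetical document id
    // Pair implements equals/hashCode, so a mention pair works as a counter key
    scores.computeIfAbsent(docId, k -> new ClassicCounter<>())
        .setCount(new Pair<>(3, 7), 0.86); // score for mention pair (3, 7)
    System.out.println(scores.get(docId).getCount(new Pair<>(3, 7))); // 0.86
  }
}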
Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.
The class ScorePhrases, method learnNewPhrasesPrivate.
private Counter<CandidatePhrase> learnNewPhrasesPrivate(String label, PatternsForEachToken patternsForEachToken,
    Counter<E> patternsLearnedThisIter, Counter<E> allSelectedPatterns, Set<CandidatePhrase> alreadyIdentifiedWords,
    CollectionValuedMap<E, Triple<String, Integer, Integer>> matchedTokensByPat,
    Counter<CandidatePhrase> scoreForAllWordsThisIteration, TwoDimensionalCounter<CandidatePhrase, E> terms,
    TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted,
    TwoDimensionalCounter<E, CandidatePhrase> patternsAndWords4Label, String identifier,
    Set<CandidatePhrase> ignoreWords, boolean computeProcDataFreq) throws IOException, ClassNotFoundException {
  Set<CandidatePhrase> alreadyLabeledWords = new HashSet<>();
  if (constVars.doNotApplyPatterns) {
    // if we only want the stats the lossy way: just counting matches, without applying the patterns
    ConstantsAndVariables.DataSentsIterator sentsIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
    while (sentsIter.hasNext()) {
      Pair<Map<String, DataInstance>, File> sentsf = sentsIter.next();
      this.statsWithoutApplyingPatterns(sentsf.first(), patternsForEachToken, patternsLearnedThisIter, wordsPatExtracted);
    }
  } else {
    if (patternsLearnedThisIter.size() > 0) {
      this.applyPats(patternsLearnedThisIter, label, wordsPatExtracted, matchedTokensByPat, alreadyLabeledWords);
    }
  }
  if (computeProcDataFreq) {
    if (!phraseScorer.wordFreqNorm.equals(Normalization.NONE)) {
      Redwood.log(Redwood.DBG, "computing processed freq");
      for (Entry<CandidatePhrase, Double> fq : Data.rawFreq.entrySet()) {
        Double in = fq.getValue();
        if (phraseScorer.wordFreqNorm.equals(Normalization.SQRT)) {
          in = Math.sqrt(in);
        } else if (phraseScorer.wordFreqNorm.equals(Normalization.LOG)) {
          in = 1 + Math.log(in);
        } else {
          throw new RuntimeException("can't understand the normalization");
        }
        assert !in.isNaN() : "Why is processed freq NaN when rawfreq is " + fq.getValue();
        Data.processedDataFreq.setCount(fq.getKey(), in);
      }
    } else {
      Data.processedDataFreq = Data.rawFreq;
    }
  }
  if (constVars.wordScoring.equals(WordScoring.WEIGHTEDNORM)) {
    for (CandidatePhrase en : wordsPatExtracted.firstKeySet()) {
      if (!constVars.getOtherSemanticClassesWords().contains(en)
          && (en.getPhraseLemma() == null || !constVars.getOtherSemanticClassesWords().contains(CandidatePhrase.createOrGet(en.getPhraseLemma())))
          && !alreadyLabeledWords.contains(en)) {
        terms.addAll(en, wordsPatExtracted.getCounter(en));
      }
    }
    removeKeys(terms, constVars.getStopWords());
    Counter<CandidatePhrase> phraseScores = phraseScorer.scorePhrases(label, terms, wordsPatExtracted, allSelectedPatterns, alreadyIdentifiedWords, false);
    System.out.println("count for word U.S. is " + phraseScores.getCount(CandidatePhrase.createOrGet("U.S.")));
    Set<CandidatePhrase> ignoreWordsAll;
    if (ignoreWords != null && !ignoreWords.isEmpty()) {
      ignoreWordsAll = CollectionUtils.unionAsSet(ignoreWords, constVars.getOtherSemanticClassesWords());
    } else {
      ignoreWordsAll = new HashSet<>(constVars.getOtherSemanticClassesWords());
    }
    ignoreWordsAll.addAll(constVars.getSeedLabelDictionary().get(label));
    ignoreWordsAll.addAll(constVars.getLearnedWords(label).keySet());
    System.out.println("ignoreWordsAll contains word U.S. is " + ignoreWordsAll.contains(CandidatePhrase.createOrGet("U.S.")));
    Counter<CandidatePhrase> finalwords = chooseTopWords(phraseScores, terms, phraseScores, ignoreWordsAll, constVars.thresholdWordExtract);
    phraseScorer.printReasonForChoosing(finalwords);
    scoreForAllWordsThisIteration.clear();
    Counters.addInPlace(scoreForAllWordsThisIteration, phraseScores);
    Redwood.log(ConstantsAndVariables.minimaldebug, "\n\n## Selected Words for " + label + " : " + Counters.toSortedString(finalwords, finalwords.size(), "%1$s:%2$.2f", "\t"));
    if (constVars.goldEntities != null) {
      Map<String, Boolean> goldEntities4Label = constVars.goldEntities.get(label);
      if (goldEntities4Label != null) {
        StringBuffer s = new StringBuffer();
        finalwords.keySet().stream().forEach(x -> s.append(x.getPhrase() + (goldEntities4Label.containsKey(x.getPhrase()) ? ":" + goldEntities4Label.get(x.getPhrase()) : ":UNKNOWN") + "\n"));
        Redwood.log(ConstantsAndVariables.minimaldebug, "\n\n## Gold labels for selected words for label " + label + " : " + s.toString());
      } else {
        Redwood.log(Redwood.DBG, "No gold entities provided for label " + label);
      }
    }
    if (constVars.outDir != null && !constVars.outDir.isEmpty()) {
      String outputdir = constVars.outDir + "/" + identifier + "/" + label;
      IOUtils.ensureDir(new File(outputdir));
      TwoDimensionalCounter<CandidatePhrase, CandidatePhrase> reasonForWords = new TwoDimensionalCounter<>();
      for (CandidatePhrase word : finalwords.keySet()) {
        for (E l : wordsPatExtracted.getCounter(word).keySet()) {
          for (CandidatePhrase w2 : patternsAndWords4Label.getCounter(l)) {
            reasonForWords.incrementCount(word, w2);
          }
        }
      }
      Redwood.log(ConstantsAndVariables.minimaldebug, "Saving output in " + outputdir);
      String filename = outputdir + "/words.json";
      // the JSON object is an array with one entry per iteration; each entry is a list of
      // objects, each a bean holding the entity, its score, and the reasons it was extracted
      JsonArrayBuilder obj = Json.createArrayBuilder();
      if (writtenInJustification.containsKey(label) && writtenInJustification.get(label)) {
        JsonReader jsonReader = Json.createReader(new BufferedInputStream(new FileInputStream(filename)));
        JsonArray objarr = jsonReader.readArray();
        for (JsonValue o : objarr) {
          obj.add(o);
        }
        jsonReader.close();
      }
      JsonArrayBuilder objThisIter = Json.createArrayBuilder();
      for (CandidatePhrase w : reasonForWords.firstKeySet()) {
        JsonObjectBuilder objinner = Json.createObjectBuilder();
        JsonArrayBuilder l = Json.createArrayBuilder();
        for (CandidatePhrase w2 : reasonForWords.getCounter(w).keySet()) {
          l.add(w2.getPhrase());
        }
        JsonArrayBuilder pats = Json.createArrayBuilder();
        for (E p : wordsPatExtracted.getCounter(w)) {
          pats.add(p.toStringSimple());
        }
        objinner.add("reasonwords", l);
        objinner.add("patterns", pats);
        objinner.add("score", finalwords.getCount(w));
        objinner.add("entity", w.getPhrase());
        objThisIter.add(objinner.build());
      }
      obj.add(objThisIter);
      // Redwood.log(ConstantsAndVariables.minimaldebug, channelNameLogger,
      // "Writing justification at " + filename);
      IOUtils.writeStringToFile(StringUtils.normalize(StringUtils.toAscii(obj.build().toString())), filename, "ASCII");
      writtenInJustification.put(label, true);
    }
    if (constVars.justify) {
      Redwood.log(Redwood.DBG, "\nJustification for phrases:\n");
      for (CandidatePhrase word : finalwords.keySet()) {
        Redwood.log(Redwood.DBG, "Phrase " + word + " extracted because of patterns: \t" + Counters.toSortedString(wordsPatExtracted.getCounter(word), wordsPatExtracted.getCounter(word).size(), "%1$s:%2$f", "\n"));
      }
    }
    return finalwords;
  } else if (constVars.wordScoring.equals(WordScoring.BPB)) {
    Counters.addInPlace(terms, wordsPatExtracted);
    Counter<CandidatePhrase> maxPatWeightTerms = new ClassicCounter<>();
    Map<CandidatePhrase, E> wordMaxPat = new HashMap<>();
    for (Entry<CandidatePhrase, ClassicCounter<E>> en : terms.entrySet()) {
      Counter<E> weights = new ClassicCounter<>();
      for (E k : en.getValue().keySet()) {
        weights.setCount(k, patternsLearnedThisIter.getCount(k));
      }
      maxPatWeightTerms.setCount(en.getKey(), Counters.max(weights));
      wordMaxPat.put(en.getKey(), Counters.argmax(weights));
    }
    Counters.removeKeys(maxPatWeightTerms, alreadyIdentifiedWords);
    double maxvalue = Counters.max(maxPatWeightTerms);
    Set<CandidatePhrase> words = Counters.keysAbove(maxPatWeightTerms, maxvalue - 1e-10);
    CandidatePhrase bestw = null;
    if (words.size() > 1) {
      double max = Double.NEGATIVE_INFINITY;
      for (CandidatePhrase w : words) {
        if (terms.getCount(w, wordMaxPat.get(w)) > max) {
          max = terms.getCount(w, wordMaxPat.get(w));
          bestw = w;
        }
      }
    } else if (words.size() == 1) {
      bestw = words.iterator().next();
    } else {
      return new ClassicCounter<>();
    }
    Redwood.log(ConstantsAndVariables.minimaldebug, "Selected Words: " + bestw);
    return Counters.asCounter(Arrays.asList(bestw));
  } else {
    throw new RuntimeException("wordscoring " + constVars.wordScoring + " not identified");
  }
}
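The BPB branch leans on the static helpers in edu.stanford.nlp.stats.Counters: Counters.max and Counters.argmax find the best-scored key, Counters.keysAbove collects near-ties, and Counters.asCounter builds a counter from a collection. A self-contained sketch of that selection logic, using toy keys and scores:

import java.util.Arrays;
import java.util.Set;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;

public class BestKeySketch {
  public static void main(String[] args) {
    Counter<String> weights = new ClassicCounter<>();
    weights.setCount("acetaminophen", 0.9);
    weights.setCount("ibuprofen", 0.9);
    weights.setCount("aspirin", 0.4);
    double max = Counters.max(weights);
    // keys within 1e-10 of the max, mirroring the tie handling in the snippet above
    Set<String> best = Counters.keysAbove(weights, max - 1e-10);
    System.out.println(best); // [acetaminophen, ibuprofen] (in some order)
    System.out.println(Counters.argmax(weights)); // one of the tied keys
    // asCounter turns a collection into a counter of occurrence counts
    Counter<String> selected = Counters.asCounter(Arrays.asList("ibuprofen"));
    System.out.println(selected.getCount("ibuprofen")); // 1.0
  }
}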
Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.
The class GetPatternsFromDataMultiClass, method setUpConstructor.
@SuppressWarnings("rawtypes")
private void setUpConstructor(Map<String, DataInstance> sents, Map<String, Set<CandidatePhrase>> seedSets, boolean labelUsingSeedSets,
    Map<String, Class<? extends TypesafeMap.Key<String>>> answerClass, Map<String, Class> generalizeClasses,
    Map<String, Map<Class, Object>> ignoreClasses) throws IOException, InstantiationException, IllegalAccessException,
    IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException, InterruptedException,
    ExecutionException, ClassNotFoundException {
  Data.sents = sents;
  ArgumentParser.fillOptions(Data.class, props);
  ArgumentParser.fillOptions(ConstantsAndVariables.class, props);
  PatternFactory.setUp(props, PatternFactory.PatternType.valueOf(props.getProperty(Flags.patternType)), seedSets.keySet());
  constVars = new ConstantsAndVariables(props, seedSets, answerClass, generalizeClasses, ignoreClasses);
  if (constVars.writeMatchedTokensFiles && constVars.batchProcessSents) {
    throw new RuntimeException("writeMatchedTokensFiles and batchProcessSents cannot be true at the same time (not implemented; also doesn't make sense to save a large sentences json file)");
  }
  if (constVars.debug < 1) {
    Redwood.hideChannelsEverywhere(ConstantsAndVariables.minimaldebug);
  }
  if (constVars.debug < 2) {
    Redwood.hideChannelsEverywhere(Redwood.DBG);
  }
  constVars.justify = true;
  if (constVars.debug < 3) {
    constVars.justify = false;
  }
  if (constVars.debug < 4) {
    Redwood.hideChannelsEverywhere(ConstantsAndVariables.extremedebug);
  }
  Redwood.log(Redwood.DBG, "Running with debug output");
  Redwood.log(ConstantsAndVariables.extremedebug, "Running with extreme debug output");
  wordsPatExtracted = new HashMap<>();
  for (String label : answerClass.keySet()) {
    wordsPatExtracted.put(label, new TwoDimensionalCounter<>());
  }
  scorePhrases = new ScorePhrases(props, constVars);
  createPats = new CreatePatterns(props, constVars);
  assert !(constVars.doNotApplyPatterns && (PatternFactory.useStopWordsBeforeTerm || PatternFactory.numWordsCompoundMax > 1)) : "Cannot have both doNotApplyPatterns and (useStopWordsBeforeTerm true or numWordsCompound > 1)!";
  if (constVars.invertedIndexDirectory == null) {
File f = File.createTempFile("inv", "index");
f.deleteOnExit();
f.mkdir();
constVars.invertedIndexDirectory = f.getAbsolutePath();
}
  Set<String> extremelySmallStopWordsList = CollectionUtils.asSet(".", ",", "in", "on", "of", "a", "the", "an");
  // Function specifying how to turn a CoreLabel into the key-value pairs added to the index
  Function<CoreLabel, Map<String, String>> transformCoreLabelToString = l -> {
    Map<String, String> add = new HashMap<>();
    for (Class gn : constVars.getGeneralizeClasses().values()) {
      Object b = l.get(gn);
      if (b != null && !b.toString().equals(constVars.backgroundSymbol)) {
        add.put(Token.getKeyForClass(gn), b.toString());
      }
    }
    return add;
  };
  boolean createIndex = false;
  if (constVars.loadInvertedIndex) {
    constVars.invertedIndex = SentenceIndex.loadIndex(constVars.invertedIndexClass, props, extremelySmallStopWordsList, constVars.invertedIndexDirectory, transformCoreLabelToString);
  } else {
    constVars.invertedIndex = SentenceIndex.createIndex(constVars.invertedIndexClass, null, props, extremelySmallStopWordsList, constVars.invertedIndexDirectory, transformCoreLabelToString);
    createIndex = true;
  }
  int totalNumSents = 0;
  boolean computeDataFreq = false;
  if (Data.rawFreq == null) {
    Data.rawFreq = new ClassicCounter<>();
    computeDataFreq = true;
  }
  ConstantsAndVariables.DataSentsIterator iter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
  while (iter.hasNext()) {
    Pair<Map<String, DataInstance>, File> sentsIter = iter.next();
    Map<String, DataInstance> sentsf = sentsIter.first();
    if (constVars.batchProcessSents) {
      for (Entry<String, DataInstance> en : sentsf.entrySet()) {
        Data.sentId2File.put(en.getKey(), sentsIter.second());
      }
    }
    totalNumSents += sentsf.size();
    if (computeDataFreq) {
      Data.computeRawFreqIfNull(sentsf, PatternFactory.numWordsCompoundMax);
    }
    Redwood.log(Redwood.DBG, "Initializing sents size " + sentsf.size() + " sentences, either by labeling with the seed set or just setting the right classes");
    for (String l : constVars.getAnswerClass().keySet()) {
      Redwood.log(Redwood.DBG, "labelUsingSeedSets is " + labelUsingSeedSets + " and seed set size for " + l + " is " + (seedSets == null ? "null" : seedSets.get(l).size()));
      Set<CandidatePhrase> seed = seedSets == null || !labelUsingSeedSets ? new HashSet<>() : (seedSets.containsKey(l) ? seedSets.get(l) : new HashSet<>());
      if (!matchedSeedWords.containsKey(l)) {
        matchedSeedWords.put(l, new ClassicCounter<>());
      }
      Counter<CandidatePhrase> matched = runLabelSeedWords(sentsf, constVars.getAnswerClass().get(l), l, seed, constVars, labelUsingSeedSets);
      System.out.println("matched phrases for " + l + " is " + matched);
      matchedSeedWords.get(l).addAll(matched);
      if (constVars.addIndvWordsFromPhrasesExceptLastAsNeg) {
        Redwood.log(ConstantsAndVariables.minimaldebug, "adding indv words from phrases except last as neg");
        Set<CandidatePhrase> otherseed = new HashSet<>();
        if (labelUsingSeedSets) {
          for (CandidatePhrase s : seed) {
            String[] t = s.getPhrase().split("\\s+");
            for (int i = 0; i < t.length - 1; i++) {
              if (!seed.contains(t[i])) {
                otherseed.add(CandidatePhrase.createOrGet(t[i]));
              }
            }
          }
        }
        runLabelSeedWords(sentsf, PatternsAnnotations.OtherSemanticLabel.class, "OTHERSEM", otherseed, constVars, labelUsingSeedSets);
      }
    }
    if (labelUsingSeedSets && constVars.getOtherSemanticClassesWords() != null) {
      String l = "OTHERSEM";
      if (!matchedSeedWords.containsKey(l)) {
        matchedSeedWords.put(l, new ClassicCounter<>());
      }
      matchedSeedWords.get(l).addAll(runLabelSeedWords(sentsf, PatternsAnnotations.OtherSemanticLabel.class, l, constVars.getOtherSemanticClassesWords(), constVars, labelUsingSeedSets));
    }
    if (constVars.removeOverLappingLabelsFromSeed) {
      removeOverLappingLabels(sentsf);
    }
    if (createIndex) {
      constVars.invertedIndex.add(sentsf, true);
    }
    if (sentsIter.second().exists()) {
      Redwood.log(Redwood.DBG, "Saving the labeled seed sents (if given the option) to the same file " + sentsIter.second());
      IOUtils.writeObjectToFile(sentsf, sentsIter.second());
    }
  }
  Redwood.log(Redwood.DBG, "Done loading/creating inverted index of tokens and labeling data with total of " + constVars.invertedIndex.size() + " sentences");
  // If the scorer class is LearnFeatWt then individual word class is added as a feature
  if (scorePhrases.phraseScorerClass.equals(ScorePhrasesAverageFeatures.class) && (constVars.usePatternEvalWordClass || constVars.usePhraseEvalWordClass)) {
    if (constVars.externalFeatureWeightsDir == null) {
      File f = File.createTempFile("tempfeat", ".txt");
      f.delete();
      f.deleteOnExit();
      constVars.externalFeatureWeightsDir = f.getAbsolutePath();
    }
    IOUtils.ensureDir(new File(constVars.externalFeatureWeightsDir));
    for (String label : seedSets.keySet()) {
      String externalFeatureWeightsFileLabel = constVars.externalFeatureWeightsDir + "/" + label;
      File f = new File(externalFeatureWeightsFileLabel);
      if (!f.exists()) {
        Redwood.log(Redwood.DBG, "externalweightsfile for the label " + label + " does not exist: learning weights!");
        LearnImportantFeatures lmf = new LearnImportantFeatures();
        ArgumentParser.fillOptions(lmf, props);
        lmf.answerClass = answerClass.get(label);
        lmf.answerLabel = label;
        lmf.setUp();
        lmf.getTopFeatures(new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents), constVars.perSelectRand, constVars.perSelectNeg, externalFeatureWeightsFileLabel);
      }
      Counter<Integer> distSimWeightsLabel = new ClassicCounter<>();
      for (String line : IOUtils.readLines(externalFeatureWeightsFileLabel)) {
        String[] t = line.split(":");
        if (!t[0].startsWith("Cluster")) {
          continue;
        }
        String s = t[0].replace("Cluster-", "");
        Integer clusterNum = Integer.parseInt(s);
        distSimWeightsLabel.setCount(clusterNum, Double.parseDouble(t[1]));
      }
      constVars.distSimWeights.put(label, distSimWeightsLabel);
    }
  }
  // computing semantic odds values
  if (constVars.usePatternEvalSemanticOdds || constVars.usePhraseEvalSemanticOdds) {
    Counter<CandidatePhrase> dictOddsWeightsLabel = new ClassicCounter<>();
    Counter<CandidatePhrase> otherSemanticClassFreq = new ClassicCounter<>();
    for (CandidatePhrase s : constVars.getOtherSemanticClassesWords()) {
      for (String s1 : StringUtils.getNgrams(Arrays.asList(s.getPhrase().split("\\s+")), 1, PatternFactory.numWordsCompoundMax)) {
        otherSemanticClassFreq.incrementCount(CandidatePhrase.createOrGet(s1));
      }
    }
    otherSemanticClassFreq = Counters.add(otherSemanticClassFreq, 1.0);
    // otherSemanticClassFreq.setDefaultReturnValue(1.0);
    Map<String, Counter<CandidatePhrase>> labelDictNgram = new HashMap<>();
    for (String label : seedSets.keySet()) {
      Counter<CandidatePhrase> classFreq = new ClassicCounter<>();
      for (CandidatePhrase s : seedSets.get(label)) {
        for (String s1 : StringUtils.getNgrams(Arrays.asList(s.getPhrase().split("\\s+")), 1, PatternFactory.numWordsCompoundMax)) {
          classFreq.incrementCount(CandidatePhrase.createOrGet(s1));
        }
      }
      classFreq = Counters.add(classFreq, 1.0);
      labelDictNgram.put(label, classFreq);
      // classFreq.setDefaultReturnValue(1.0);
    }
    for (String label : seedSets.keySet()) {
      Counter<CandidatePhrase> otherLabelFreq = new ClassicCounter<>();
      for (String label2 : seedSets.keySet()) {
        if (label.equals(label2)) {
          continue;
        }
        otherLabelFreq.addAll(labelDictNgram.get(label2));
      }
      otherLabelFreq.addAll(otherSemanticClassFreq);
      dictOddsWeightsLabel = Counters.divisionNonNaN(labelDictNgram.get(label), otherLabelFreq);
      constVars.dictOddsWeights.put(label, dictOddsWeightsLabel);
    }
  }
  // Redwood.log(Redwood.DBG, "All options are:" + "\n" + Maps.toString(getAllOptions(), "", "", "\t", "\n"));
}
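The semantic-odds step above add-1 smooths the n-gram counts with Counters.add and then computes per-phrase ratios with Counters.divisionNonNaN, exactly the two helpers used in the snippet. In isolation, with toy counts and illustrative phrases:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;

public class DictOddsSketch {
  public static void main(String[] args) {
    Counter<String> labelFreq = new ClassicCounter<>();
    labelFreq.setCount("cancer", 9.0);
    labelFreq.setCount("fever", 3.0);
    Counter<String> otherFreq = new ClassicCounter<>();
    otherFreq.setCount("cancer", 1.0);
    otherFreq.setCount("fever", 5.0);
    // add-1 smoothing, as in the snippet, so rare n-grams don't produce extreme odds
    labelFreq = Counters.add(labelFreq, 1.0);
    otherFreq = Counters.add(otherFreq, 1.0);
    // element-wise ratio of the two counters; NaN results are skipped rather than stored
    Counter<String> odds = Counters.divisionNonNaN(labelFreq, otherFreq);
    System.out.println(odds.getCount("cancer")); // 10.0 / 2.0 = 5.0
  }
}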
Use of edu.stanford.nlp.stats.Counter in project CoreNLP by stanfordnlp.
The class ApplyDepPatterns, method getMatchedTokensIndex.
private Collection<ExtractedPhrase> getMatchedTokensIndex(SemanticGraph graph, SemgrexPattern pattern, DataInstance sent, String label) {
  // TODO: look at the ignoreCommonTags flag
  ExtractPhraseFromPattern extract = new ExtractPhraseFromPattern(false, PatternFactory.numWordsCompoundMapped.get(label));
  Collection<IntPair> outputIndices = new ArrayList<>();
  boolean findSubTrees = true;
  List<CoreLabel> tokensC = sent.getTokens();
  // TODO: see if you can get rid of this (only used for matchedGraphs)
  List<String> tokens = tokensC.stream().map(x -> x.word()).collect(Collectors.toList());
  List<String> outputPhrases = new ArrayList<>();
  List<ExtractedPhrase> extractedPhrases = new ArrayList<>();
  Function<Pair<IndexedWord, SemanticGraph>, Counter<String>> extractFeatures = new Function<Pair<IndexedWord, SemanticGraph>, Counter<String>>() {

    @Override
    public Counter<String> apply(Pair<IndexedWord, SemanticGraph> indexedWordSemanticGraphPair) {
      // TODO: make features
      Counter<String> feat = new ClassicCounter<>();
      IndexedWord vertex = indexedWordSemanticGraphPair.first();
      SemanticGraph graph = indexedWordSemanticGraphPair.second();
      List<Pair<GrammaticalRelation, IndexedWord>> pt = graph.parentPairs(vertex);
      for (Pair<GrammaticalRelation, IndexedWord> en : pt) {
        feat.incrementCount("PARENTREL-" + en.first());
      }
      return feat;
    }
  };
  extract.getSemGrexPatternNodes(graph, tokens, outputPhrases, outputIndices, pattern, findSubTrees, extractedPhrases, constVars.matchLowerCaseContext, matchingWordRestriction);
  // System.out.println("extracted phrases are " + extractedPhrases + " and output indices are " + outputIndices);
  return extractedPhrases;
}
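Since the Function above is a single-method anonymous class, it can be written more compactly as a lambda. A behavior-equivalent sketch, pulled out as a standalone constant (the class and field names here are illustrative):

import java.util.function.Function;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.util.Pair;

public class ParentRelFeatures {
  // lambda equivalent of the anonymous Function in the snippet above
  static final Function<Pair<IndexedWord, SemanticGraph>, Counter<String>> EXTRACT_FEATURES = pair -> {
    Counter<String> feat = new ClassicCounter<>();
    // one indicator feature per grammatical relation linking the word to its parents
    for (Pair<GrammaticalRelation, IndexedWord> en : pair.second().parentPairs(pair.first())) {
      feat.incrementCount("PARENTREL-" + en.first());
    }
    return feat;
  };
}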