Use of edu.stanford.nlp.stats.TwoDimensionalCounter in project CoreNLP by stanfordnlp: the class MWEPreprocessor, method main.
/**
* Computes MWE statistics from a French treebank file and writes them out
* as CSV tables.
*
* @param args a single argument: the path to the treebank file
*/
public static void main(String[] args) {
if (args.length != 1) {
System.err.printf("Usage: java %s file%n", MWEPreprocessor.class.getName());
System.exit(-1);
}
final File treeFile = new File(args[0]);
TwoDimensionalCounter<String, String> labelTerm = new TwoDimensionalCounter<>();
TwoDimensionalCounter<String, String> termLabel = new TwoDimensionalCounter<>();
TwoDimensionalCounter<String, String> labelPreterm = new TwoDimensionalCounter<>();
TwoDimensionalCounter<String, String> pretermLabel = new TwoDimensionalCounter<>();
TwoDimensionalCounter<String, String> unigramTagger = new TwoDimensionalCounter<>();
try {
BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(treeFile), "UTF-8"));
TreeReaderFactory trf = new FrenchTreeReaderFactory();
TreeReader tr = trf.newTreeReader(br);
for (Tree t; (t = tr.readTree()) != null; ) {
countMWEStatistics(t, unigramTagger, labelPreterm, pretermLabel, labelTerm, termLabel);
}
//Closes the underlying reader
tr.close();
System.out.println("Generating {MWE Type -> Terminal}");
printCounter(labelTerm, "label_term.csv");
System.out.println("Generating {Terminal -> MWE Type}");
printCounter(termLabel, "term_label.csv");
System.out.println("Generating {MWE Type -> POS sequence}");
printCounter(labelPreterm, "label_pos.csv");
System.out.println("Generating {POS sequence -> MWE Type}");
printCounter(pretermLabel, "pos_label.csv");
if (RESOLVE_DUMMY_TAGS) {
System.out.println("Resolving DUMMY tags");
resolveDummyTags(treeFile, pretermLabel, unigramTagger);
}
System.out.println("#Unknown Word Types: " + ManualUWModel.nUnknownWordTypes);
System.out.println("#Missing POS: " + nMissingPOS);
System.out.println("#Missing Phrasal: " + nMissingPhrasal);
System.out.println("Done!");
} catch (IOException e) {
//UnsupportedEncodingException and FileNotFoundException are IOException
//subtypes, so a single catch block covers all three cases
e.printStackTrace();
}
}
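For readers new to this class: the method above drives five TwoDimensionalCounter instances through countMWEStatistics and then dumps them. A minimal, self-contained sketch of that counter pattern (illustrative class name and data, not CoreNLP source) looks like this:

import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalCounter;

public class TwoDimCounterSketch {
  public static void main(String[] args) {
    // Rows are MWE labels, columns are terminals; each observation bumps a cell.
    TwoDimensionalCounter<String, String> labelTerm = new TwoDimensionalCounter<>();
    labelTerm.incrementCount("MWN", "pomme");
    labelTerm.incrementCount("MWN", "pomme");
    labelTerm.incrementCount("MWADV", "peu");
    // One row per first key; each row is itself a Counter over second keys,
    // which is what a printCounter-style dump iterates over.
    for (String label : labelTerm.firstKeySet()) {
      System.out.printf("%s\t%s%n", label,
          Counters.toSortedString(labelTerm.getCounter(label),
              labelTerm.getCounter(label).size(), "%1$s:%2$.0f", ","));
    }
  }
}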
Use of edu.stanford.nlp.stats.TwoDimensionalCounter in project CoreNLP by stanfordnlp: the class ScorePhrases, method runParallelApplyPats.
void runParallelApplyPats(Map<String, DataInstance> sents, String label, E pattern, TwoDimensionalCounter<CandidatePhrase, E> wordsandLemmaPatExtracted, CollectionValuedMap<E, Triple<String, Integer, Integer>> matchedTokensByPat, Set<CandidatePhrase> alreadyLabeledWords) {
Redwood.log(Redwood.DBG, "Applying pattern " + pattern + " to a total of " + sents.size() + " sentences ");
List<String> notAllowedClasses = new ArrayList<>();
List<String> sentids = CollectionUtils.toList(sents.keySet());
if (constVars.doNotExtractPhraseAnyWordLabeledOtherClass) {
for (String l : constVars.getAnswerClass().keySet()) {
if (!l.equals(label)) {
notAllowedClasses.add(l);
}
}
notAllowedClasses.add("OTHERSEM");
}
Map<TokenSequencePattern, E> surfacePatternsLearnedThisIterConverted = null;
Map<SemgrexPattern, E> depPatternsLearnedThisIterConverted = null;
if (constVars.patternType.equals(PatternFactory.PatternType.SURFACE)) {
surfacePatternsLearnedThisIterConverted = new HashMap<>();
String patternStr = null;
try {
patternStr = pattern.toString(notAllowedClasses);
TokenSequencePattern pat = TokenSequencePattern.compile(constVars.env.get(label), patternStr);
surfacePatternsLearnedThisIterConverted.put(pat, pattern);
} catch (Exception e) {
log.info("Error applying patterrn " + patternStr + ". Probably an ill formed pattern (can be because of special symbols in label names). Contact the software developer.");
throw e;
}
} else if (constVars.patternType.equals(PatternFactory.PatternType.DEP)) {
depPatternsLearnedThisIterConverted = new HashMap<>();
SemgrexPattern pat = SemgrexPattern.compile(pattern.toString(notAllowedClasses), new edu.stanford.nlp.semgraph.semgrex.Env(constVars.env.get(label).getVariables()));
depPatternsLearnedThisIterConverted.put(pat, pattern);
} else
throw new UnsupportedOperationException();
//Apply the patterns and extract candidate phrases
int num;
int numThreads = constVars.numThreads;
//Use fewer threads when there are only a few sentences
if (sents.size() < 50)
numThreads = 1;
//num is the per-thread chunk size; the Math.min bound below clamps the
//final chunk to the number of sentences
if (numThreads == 1)
num = sents.size();
else
num = sents.size() / (numThreads - 1);
//Use the (possibly reduced) local numThreads rather than constVars.numThreads
ExecutorService executor = Executors.newFixedThreadPool(numThreads);
List<Future<Triple<TwoDimensionalCounter<CandidatePhrase, E>, CollectionValuedMap<E, Triple<String, Integer, Integer>>, Set<CandidatePhrase>>>> list = new ArrayList<>();
for (int i = 0; i < numThreads; i++) {
Callable<Triple<TwoDimensionalCounter<CandidatePhrase, E>, CollectionValuedMap<E, Triple<String, Integer, Integer>>, Set<CandidatePhrase>>> task = null;
if (pattern.type.equals(PatternFactory.PatternType.SURFACE))
//Redwood.log(Redwood.DBG, "Applying pats: assigning sentences " + i*num + " to " +Math.min(sentids.size(), (i + 1) * num) + " to thread " + (i+1));
task = new ApplyPatterns(sents, num == sents.size() ? sentids : sentids.subList(i * num, Math.min(sentids.size(), (i + 1) * num)), surfacePatternsLearnedThisIterConverted, label, constVars.removeStopWordsFromSelectedPhrases, constVars.removePhrasesWithStopWords, constVars);
else
task = new ApplyDepPatterns(sents, num == sents.size() ? sentids : sentids.subList(i * num, Math.min(sentids.size(), (i + 1) * num)), depPatternsLearnedThisIterConverted, label, constVars.removeStopWordsFromSelectedPhrases, constVars.removePhrasesWithStopWords, constVars);
Future<Triple<TwoDimensionalCounter<CandidatePhrase, E>, CollectionValuedMap<E, Triple<String, Integer, Integer>>, Set<CandidatePhrase>>> submit = executor.submit(task);
list.add(submit);
}
// Now retrieve the result
for (Future<Triple<TwoDimensionalCounter<CandidatePhrase, E>, CollectionValuedMap<E, Triple<String, Integer, Integer>>, Set<CandidatePhrase>>> future : list) {
try {
Triple<TwoDimensionalCounter<CandidatePhrase, E>, CollectionValuedMap<E, Triple<String, Integer, Integer>>, Set<CandidatePhrase>> result = future.get();
Redwood.log(ConstantsAndVariables.extremedebug, "Pattern " + pattern + " extracted phrases " + result.first());
wordsandLemmaPatExtracted.addAll(result.first());
matchedTokensByPat.addAll(result.second());
alreadyLabeledWords.addAll(result.third());
} catch (Exception e) {
executor.shutdownNow();
throw new RuntimeException(e);
}
}
executor.shutdown();
}
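The fan-out above hinges on each worker returning its own counters, which the caller folds together with addAll; nothing is shared while the threads run. A stripped-down sketch of that merge pattern (hypothetical task and keys, not the ApplyPatterns class itself):

import edu.stanford.nlp.stats.TwoDimensionalCounter;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MergeSketch {
  public static void main(String[] args) throws Exception {
    ExecutorService executor = Executors.newFixedThreadPool(2);
    List<Future<TwoDimensionalCounter<String, String>>> futures = new ArrayList<>();
    for (int i = 0; i < 2; i++) {
      final int shard = i;
      // Each worker counts its shard into a private TwoDimensionalCounter.
      Callable<TwoDimensionalCounter<String, String>> task = () -> {
        TwoDimensionalCounter<String, String> local = new TwoDimensionalCounter<>();
        local.incrementCount("phrase" + shard, "pattern" + shard);
        return local;
      };
      futures.add(executor.submit(task));
    }
    // Fan-in: addAll folds each worker's counts into the shared result,
    // the role wordsandLemmaPatExtracted.addAll plays above.
    TwoDimensionalCounter<String, String> merged = new TwoDimensionalCounter<>();
    for (Future<TwoDimensionalCounter<String, String>> f : futures) {
      merged.addAll(f.get());
    }
    executor.shutdown();
    System.out.println(merged.totalCount()); // 2.0
  }
}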
Use of edu.stanford.nlp.stats.TwoDimensionalCounter in project CoreNLP by stanfordnlp: the class ScorePhrases, method learnNewPhrasesPrivate.
private Counter<CandidatePhrase> learnNewPhrasesPrivate(String label, PatternsForEachToken patternsForEachToken, Counter<E> patternsLearnedThisIter, Counter<E> allSelectedPatterns, Set<CandidatePhrase> alreadyIdentifiedWords, CollectionValuedMap<E, Triple<String, Integer, Integer>> matchedTokensByPat, Counter<CandidatePhrase> scoreForAllWordsThisIteration, TwoDimensionalCounter<CandidatePhrase, E> terms, TwoDimensionalCounter<CandidatePhrase, E> wordsPatExtracted, TwoDimensionalCounter<E, CandidatePhrase> patternsAndWords4Label, String identifier, Set<CandidatePhrase> ignoreWords, boolean computeProcDataFreq) throws IOException, ClassNotFoundException {
Set<CandidatePhrase> alreadyLabeledWords = new HashSet<>();
if (constVars.doNotApplyPatterns) {
// if we want the stats the lossy way: by just counting, without
// applying the patterns
ConstantsAndVariables.DataSentsIterator sentsIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
while (sentsIter.hasNext()) {
Pair<Map<String, DataInstance>, File> sentsf = sentsIter.next();
this.statsWithoutApplyingPatterns(sentsf.first(), patternsForEachToken, patternsLearnedThisIter, wordsPatExtracted);
}
} else {
if (patternsLearnedThisIter.size() > 0) {
this.applyPats(patternsLearnedThisIter, label, wordsPatExtracted, matchedTokensByPat, alreadyLabeledWords);
}
}
if (computeProcDataFreq) {
if (!phraseScorer.wordFreqNorm.equals(Normalization.NONE)) {
Redwood.log(Redwood.DBG, "computing processed freq");
for (Entry<CandidatePhrase, Double> fq : Data.rawFreq.entrySet()) {
Double in = fq.getValue();
if (phraseScorer.wordFreqNorm.equals(Normalization.SQRT))
in = Math.sqrt(in);
else if (phraseScorer.wordFreqNorm.equals(Normalization.LOG))
in = 1 + Math.log(in);
else
throw new RuntimeException("can't understand the normalization");
assert !in.isNaN() : "Why is processed freq nan when rawfreq is " + in;
Data.processedDataFreq.setCount(fq.getKey(), in);
}
} else
Data.processedDataFreq = Data.rawFreq;
}
if (constVars.wordScoring.equals(WordScoring.WEIGHTEDNORM)) {
for (CandidatePhrase en : wordsPatExtracted.firstKeySet()) {
if (!constVars.getOtherSemanticClassesWords().contains(en) && (en.getPhraseLemma() == null || !constVars.getOtherSemanticClassesWords().contains(CandidatePhrase.createOrGet(en.getPhraseLemma()))) && !alreadyLabeledWords.contains(en)) {
terms.addAll(en, wordsPatExtracted.getCounter(en));
}
}
removeKeys(terms, constVars.getStopWords());
Counter<CandidatePhrase> phraseScores = phraseScorer.scorePhrases(label, terms, wordsPatExtracted, allSelectedPatterns, alreadyIdentifiedWords, false);
System.out.println("count for word U.S. is " + phraseScores.getCount(CandidatePhrase.createOrGet("U.S.")));
Set<CandidatePhrase> ignoreWordsAll;
if (ignoreWords != null && !ignoreWords.isEmpty()) {
ignoreWordsAll = CollectionUtils.unionAsSet(ignoreWords, constVars.getOtherSemanticClassesWords());
} else
ignoreWordsAll = new HashSet<>(constVars.getOtherSemanticClassesWords());
ignoreWordsAll.addAll(constVars.getSeedLabelDictionary().get(label));
ignoreWordsAll.addAll(constVars.getLearnedWords(label).keySet());
System.out.println("ignoreWordsAll contains word U.S. is " + ignoreWordsAll.contains(CandidatePhrase.createOrGet("U.S.")));
Counter<CandidatePhrase> finalwords = chooseTopWords(phraseScores, terms, phraseScores, ignoreWordsAll, constVars.thresholdWordExtract);
phraseScorer.printReasonForChoosing(finalwords);
scoreForAllWordsThisIteration.clear();
Counters.addInPlace(scoreForAllWordsThisIteration, phraseScores);
Redwood.log(ConstantsAndVariables.minimaldebug, "\n\n## Selected Words for " + label + " : " + Counters.toSortedString(finalwords, finalwords.size(), "%1$s:%2$.2f", "\t"));
if (constVars.goldEntities != null) {
Map<String, Boolean> goldEntities4Label = constVars.goldEntities.get(label);
if (goldEntities4Label != null) {
StringBuffer s = new StringBuffer();
finalwords.keySet().stream().forEach(x -> s.append(x.getPhrase() + (goldEntities4Label.containsKey(x.getPhrase()) ? ":" + goldEntities4Label.get(x.getPhrase()) : ":UNKNOWN") + "\n"));
Redwood.log(ConstantsAndVariables.minimaldebug, "\n\n## Gold labels for selected words for label " + label + " : " + s.toString());
} else
Redwood.log(Redwood.DBG, "No gold entities provided for label " + label);
}
if (constVars.outDir != null && !constVars.outDir.isEmpty()) {
String outputdir = constVars.outDir + "/" + identifier + "/" + label;
IOUtils.ensureDir(new File(outputdir));
TwoDimensionalCounter<CandidatePhrase, CandidatePhrase> reasonForWords = new TwoDimensionalCounter<>();
for (CandidatePhrase word : finalwords.keySet()) {
for (E l : wordsPatExtracted.getCounter(word).keySet()) {
for (CandidatePhrase w2 : patternsAndWords4Label.getCounter(l)) {
reasonForWords.incrementCount(word, w2);
}
}
}
Redwood.log(ConstantsAndVariables.minimaldebug, "Saving output in " + outputdir);
String filename = outputdir + "/words.json";
// the json object is an array with one element per iteration: a list of
// objects, each of which is a bean of entity and reasons
JsonArrayBuilder obj = Json.createArrayBuilder();
if (writtenInJustification.containsKey(label) && writtenInJustification.get(label)) {
JsonReader jsonReader = Json.createReader(new BufferedInputStream(new FileInputStream(filename)));
JsonArray objarr = jsonReader.readArray();
for (JsonValue o : objarr) obj.add(o);
jsonReader.close();
}
JsonArrayBuilder objThisIter = Json.createArrayBuilder();
for (CandidatePhrase w : reasonForWords.firstKeySet()) {
JsonObjectBuilder objinner = Json.createObjectBuilder();
JsonArrayBuilder l = Json.createArrayBuilder();
for (CandidatePhrase w2 : reasonForWords.getCounter(w).keySet()) {
l.add(w2.getPhrase());
}
JsonArrayBuilder pats = Json.createArrayBuilder();
for (E p : wordsPatExtracted.getCounter(w)) {
pats.add(p.toStringSimple());
}
objinner.add("reasonwords", l);
objinner.add("patterns", pats);
objinner.add("score", finalwords.getCount(w));
objinner.add("entity", w.getPhrase());
objThisIter.add(objinner.build());
}
obj.add(objThisIter);
// Redwood.log(ConstantsAndVariables.minimaldebug, channelNameLogger,
// "Writing justification at " + filename);
IOUtils.writeStringToFile(StringUtils.normalize(StringUtils.toAscii(obj.build().toString())), filename, "ASCII");
writtenInJustification.put(label, true);
}
if (constVars.justify) {
Redwood.log(Redwood.DBG, "\nJustification for phrases:\n");
for (CandidatePhrase word : finalwords.keySet()) {
Redwood.log(Redwood.DBG, "Phrase " + word + " extracted because of patterns: \t" + Counters.toSortedString(wordsPatExtracted.getCounter(word), wordsPatExtracted.getCounter(word).size(), "%1$s:%2$f", "\n"));
}
}
return finalwords;
} else if (constVars.wordScoring.equals(WordScoring.BPB)) {
Counters.addInPlace(terms, wordsPatExtracted);
Counter<CandidatePhrase> maxPatWeightTerms = new ClassicCounter<>();
Map<CandidatePhrase, E> wordMaxPat = new HashMap<>();
for (Entry<CandidatePhrase, ClassicCounter<E>> en : terms.entrySet()) {
Counter<E> weights = new ClassicCounter<>();
for (E k : en.getValue().keySet()) weights.setCount(k, patternsLearnedThisIter.getCount(k));
maxPatWeightTerms.setCount(en.getKey(), Counters.max(weights));
wordMaxPat.put(en.getKey(), Counters.argmax(weights));
}
Counters.removeKeys(maxPatWeightTerms, alreadyIdentifiedWords);
double maxvalue = Counters.max(maxPatWeightTerms);
Set<CandidatePhrase> words = Counters.keysAbove(maxPatWeightTerms, maxvalue - 1e-10);
CandidatePhrase bestw = null;
if (words.size() > 1) {
double max = Double.NEGATIVE_INFINITY;
for (CandidatePhrase w : words) {
if (terms.getCount(w, wordMaxPat.get(w)) > max) {
max = terms.getCount(w, wordMaxPat.get(w));
bestw = w;
}
}
} else if (words.size() == 1)
bestw = words.iterator().next();
else
return new ClassicCounter<>();
Redwood.log(ConstantsAndVariables.minimaldebug, "Selected Words: " + bestw);
return Counters.asCounter(Arrays.asList(bestw));
} else
throw new RuntimeException("wordscoring " + constVars.wordScoring + " not identified");
}
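The BPB branch above reduces the terms table to a single best word: each term is scored by the weight of its best-supporting pattern, then the maximum is taken. A self-contained sketch of that per-row argmax (made-up terms and weights, not the CoreNLP method):

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalCounter;

public class ArgmaxSketch {
  public static void main(String[] args) {
    TwoDimensionalCounter<String, String> terms = new TwoDimensionalCounter<>();
    terms.setCount("acid", "X_is_a", 3.0);
    terms.setCount("acid", "such_as_X", 1.0);
    terms.setCount("base", "such_as_X", 2.0);
    Counter<String> patternWeights = new ClassicCounter<>();
    patternWeights.setCount("X_is_a", 0.9);
    patternWeights.setCount("such_as_X", 0.4);
    // Score each term by the weight of its best-supporting pattern.
    Counter<String> maxPatWeightTerms = new ClassicCounter<>();
    for (String term : terms.firstKeySet()) {
      Counter<String> weights = new ClassicCounter<>();
      for (String pat : terms.getCounter(term).keySet()) {
        weights.setCount(pat, patternWeights.getCount(pat));
      }
      maxPatWeightTerms.setCount(term, Counters.max(weights));
    }
    // The highest-scoring term wins, as bestw does in the method above.
    System.out.println(Counters.argmax(maxPatWeightTerms)); // acid
  }
}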
Use of edu.stanford.nlp.stats.TwoDimensionalCounter in project CoreNLP by stanfordnlp: the class GetPatternsFromDataMultiClass, method getPatterns.
@SuppressWarnings({ "unchecked" })
public Counter<E> getPatterns(String label, Set<E> alreadyIdentifiedPatterns, E p0, Counter<CandidatePhrase> p0Set, Set<E> ignorePatterns) throws IOException, ClassNotFoundException {
TwoDimensionalCounter<E, CandidatePhrase> patternsandWords4Label = new TwoDimensionalCounter<>();
TwoDimensionalCounter<E, CandidatePhrase> negPatternsandWords4Label = new TwoDimensionalCounter<>();
//TwoDimensionalCounter<E, String> posnegPatternsandWords4Label = new TwoDimensionalCounter<E, String>();
TwoDimensionalCounter<E, CandidatePhrase> unLabeledPatternsandWords4Label = new TwoDimensionalCounter<>();
//TwoDimensionalCounter<E, String> negandUnLabeledPatternsandWords4Label = new TwoDimensionalCounter<E, String>();
//TwoDimensionalCounter<E, String> allPatternsandWords4Label = new TwoDimensionalCounter<E, String>();
Set<String> allCandidatePhrases = new HashSet<>();
ConstantsAndVariables.DataSentsIterator sentsIter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
boolean firstCallToProcessSents = true;
while (sentsIter.hasNext()) {
Pair<Map<String, DataInstance>, File> sentsPair = sentsIter.next();
if (notComputedAllPatternsYet) {
//in the first iteration
processSents(sentsPair.first(), firstCallToProcessSents);
firstCallToProcessSents = false;
if (patsForEachToken == null) {
//in the first iteration, for the first file
patsForEachToken = PatternsForEachToken.getPatternsInstance(props, constVars.storePatsForEachToken);
readSavedPatternsAndIndex();
}
}
this.calculateSufficientStats(sentsPair.first(), patsForEachToken, label, patternsandWords4Label, negPatternsandWords4Label, unLabeledPatternsandWords4Label, allCandidatePhrases);
}
notComputedAllPatternsYet = false;
if (constVars.computeAllPatterns) {
if (constVars.storePatsForEachToken.equals(ConstantsAndVariables.PatternForEachTokenWay.DB))
patsForEachToken.createIndexIfUsingDBAndNotExists();
if (constVars.allPatternsDir != null) {
IOUtils.ensureDir(new File(constVars.allPatternsDir));
patsForEachToken.save(constVars.allPatternsDir);
}
//savePatternIndex(constVars.allPatternsDir);
}
patsForEachToken.close();
//This is important. It makes sure that we don't recompute patterns in every iteration!
constVars.computeAllPatterns = false;
if (patternsandWords == null)
patternsandWords = new HashMap<>();
if (currentPatternWeights == null)
currentPatternWeights = new HashMap<>();
Counter<E> currentPatternWeights4Label = new ClassicCounter<>();
Set<E> removePats = enforceMinSupportRequirements(patternsandWords4Label, unLabeledPatternsandWords4Label);
Counters.removeKeys(patternsandWords4Label, removePats);
Counters.removeKeys(unLabeledPatternsandWords4Label, removePats);
Counters.removeKeys(negPatternsandWords4Label, removePats);
ScorePatterns scorePatterns;
Class<?> patternscoringclass = getPatternScoringClass(constVars.patternScoring);
if (patternscoringclass != null && patternscoringclass.equals(ScorePatternsF1.class)) {
scorePatterns = new ScorePatternsF1(constVars, constVars.patternScoring, label, allCandidatePhrases, patternsandWords4Label, negPatternsandWords4Label, unLabeledPatternsandWords4Label, props, p0Set, p0);
Counter<E> finalPat = scorePatterns.score();
Counters.removeKeys(finalPat, alreadyIdentifiedPatterns);
Counters.retainNonZeros(finalPat);
Counters.retainTop(finalPat, constVars.numPatterns);
if (Double.isNaN(Counters.max(finalPat)))
throw new RuntimeException("how is the value NaN");
Redwood.log(ConstantsAndVariables.minimaldebug, "Selected Patterns: " + finalPat);
return finalPat;
} else if (patternscoringclass != null && patternscoringclass.equals(ScorePatternsRatioModifiedFreq.class)) {
scorePatterns = new ScorePatternsRatioModifiedFreq(constVars, constVars.patternScoring, label, allCandidatePhrases, patternsandWords4Label, negPatternsandWords4Label, unLabeledPatternsandWords4Label, phInPatScoresCache, scorePhrases, props);
} else if (patternscoringclass != null && patternscoringclass.equals(ScorePatternsFreqBased.class)) {
scorePatterns = new ScorePatternsFreqBased(constVars, constVars.patternScoring, label, allCandidatePhrases, patternsandWords4Label, negPatternsandWords4Label, unLabeledPatternsandWords4Label, props);
} else if (constVars.patternScoring.equals(PatternScoring.kNN)) {
try {
Class<? extends ScorePatterns> clazz = (Class<? extends ScorePatterns>) Class.forName("edu.stanford.nlp.patterns.ScorePatternsKNN");
Constructor<? extends ScorePatterns> ctor = clazz.getConstructor(ConstantsAndVariables.class, PatternScoring.class, String.class, Set.class, TwoDimensionalCounter.class, TwoDimensionalCounter.class, TwoDimensionalCounter.class, ScorePhrases.class, Properties.class);
scorePatterns = ctor.newInstance(constVars, constVars.patternScoring, label, allCandidatePhrases, patternsandWords4Label, negPatternsandWords4Label, unLabeledPatternsandWords4Label, scorePhrases, props);
} catch (ClassNotFoundException e) {
throw new RuntimeException("kNN pattern scoring is not released yet. Stay tuned.");
} catch (NoSuchMethodException | InvocationTargetException | InstantiationException | IllegalAccessException e) {
throw new RuntimeException("newinstance of kNN not created", e);
}
} else {
throw new RuntimeException(constVars.patternScoring + " is not implemented (check spelling?). ");
}
scorePatterns.setUp(props);
currentPatternWeights4Label = scorePatterns.score();
Redwood.log(ConstantsAndVariables.extremedebug, "patterns counter size is " + currentPatternWeights4Label.size());
if (ignorePatterns != null && !ignorePatterns.isEmpty()) {
Counters.removeKeys(currentPatternWeights4Label, ignorePatterns);
Redwood.log(ConstantsAndVariables.extremedebug, "Removing patterns from ignorePatterns of size " + ignorePatterns.size() + ". New patterns size " + currentPatternWeights4Label.size());
}
if (alreadyIdentifiedPatterns != null && !alreadyIdentifiedPatterns.isEmpty()) {
Redwood.log(ConstantsAndVariables.extremedebug, "Patterns size is " + currentPatternWeights4Label.size());
Counters.removeKeys(currentPatternWeights4Label, alreadyIdentifiedPatterns);
Redwood.log(ConstantsAndVariables.extremedebug, "Removing already identified patterns of size " + alreadyIdentifiedPatterns.size() + ". New patterns size " + currentPatternWeights4Label.size());
}
PriorityQueue<E> q = Counters.toPriorityQueue(currentPatternWeights4Label);
int num = 0;
Counter<E> chosenPat = new ClassicCounter<>();
Set<E> removePatterns = new HashSet<>();
Set<E> removeIdentifiedPatterns = null;
while (num < constVars.numPatterns && !q.isEmpty()) {
E pat = q.removeFirst();
if (currentPatternWeights4Label.getCount(pat) < constVars.thresholdSelectPattern) {
Redwood.log(Redwood.DBG, "The max weight of candidate patterns is " + df.format(currentPatternWeights4Label.getCount(pat)) + " so not adding anymore patterns");
break;
}
boolean notchoose = false;
if (!unLabeledPatternsandWords4Label.containsFirstKey(pat) || unLabeledPatternsandWords4Label.getCounter(pat).isEmpty()) {
Redwood.log(ConstantsAndVariables.extremedebug, "Removing pattern " + pat + " because it has no unlab support; pos words: " + patternsandWords4Label.getCounter(pat));
notchoose = true;
continue;
}
Set<E> removeChosenPats = null;
if (!notchoose) {
if (alreadyIdentifiedPatterns != null) {
for (E p : alreadyIdentifiedPatterns) {
if (Pattern.subsumes(constVars.patternType, pat, p)) {
// if (pat.getNextContextStr().contains(p.getNextContextStr()) &&
// pat.getPrevContextStr().contains(p.getPrevContextStr())) {
Redwood.log(ConstantsAndVariables.extremedebug, "Not choosing pattern " + pat + " because it is contained in or contains the already chosen pattern " + p);
notchoose = true;
break;
}
int rest = pat.equalContext(p);
// the contexts don't match
if (rest == Integer.MAX_VALUE)
continue;
// if pat is less restrictive, remove p and add pat!
if (rest < 0) {
if (removeIdentifiedPatterns == null)
removeIdentifiedPatterns = new HashSet<>();
removeIdentifiedPatterns.add(p);
} else {
notchoose = true;
break;
}
}
}
}
// Check against patterns already chosen in this iteration:
if (!notchoose) {
for (Pattern p : chosenPat.keySet()) {
//E p = constVars.getPatternIndex().get(pindex);
boolean removeChosenPatFlag = false;
if (Pattern.sameGenre(constVars.patternType, pat, p)) {
if (Pattern.subsumes(constVars.patternType, pat, p)) {
Redwood.log(ConstantsAndVariables.extremedebug, "Not choosing pattern " + pat + " because it is contained in or contains the already chosen pattern " + p);
notchoose = true;
break;
} else if (E.subsumes(constVars.patternType, p, pat)) {
//subsumes holds even when the contexts are equal,
//so check whether the contexts actually match
int rest = pat.equalContext(p);
// the contexts do not match
if (rest == Integer.MAX_VALUE) {
Redwood.log(ConstantsAndVariables.extremedebug, "Not choosing pattern " + p + " because it is contained in or contains another chosen pattern in this iteration " + pat);
removeChosenPatFlag = true;
} else // add pat!
if (rest < 0) {
removeChosenPatFlag = true;
} else {
notchoose = true;
break;
}
}
if (removeChosenPatFlag) {
if (removeChosenPats == null)
removeChosenPats = new HashSet<>();
removeChosenPats.add(pat);
num--;
}
}
}
}
if (notchoose) {
Redwood.log(Redwood.DBG, "Not choosing " + pat + " for whatever reason!");
continue;
}
if (removeChosenPats != null) {
Redwood.log(ConstantsAndVariables.extremedebug, "Removing already chosen patterns in this iteration " + removeChosenPats + " in favor of " + pat);
Counters.removeKeys(chosenPat, removeChosenPats);
}
if (removeIdentifiedPatterns != null) {
Redwood.log(ConstantsAndVariables.extremedebug, "Removing already identified patterns " + removeIdentifiedPatterns + " in favor of " + pat);
removePatterns.addAll(removeIdentifiedPatterns);
}
chosenPat.setCount(pat, currentPatternWeights4Label.getCount(pat));
num++;
}
this.removeLearnedPatterns(label, removePatterns);
Redwood.log(Redwood.DBG, "final size of the patterns is " + chosenPat.size());
Redwood.log(ConstantsAndVariables.minimaldebug, "\n\n## Selected Patterns for " + label + "##\n");
List<Pair<E, Double>> chosenPatSorted = Counters.toSortedListWithCounts(chosenPat);
for (Pair<E, Double> en : chosenPatSorted) Redwood.log(ConstantsAndVariables.minimaldebug, en.first() + ":" + df.format(en.second) + "\n");
if (constVars.outDir != null && !constVars.outDir.isEmpty()) {
CollectionValuedMap<E, CandidatePhrase> posWords = new CollectionValuedMap<>();
for (Entry<E, ClassicCounter<CandidatePhrase>> en : patternsandWords4Label.entrySet()) {
posWords.addAll(en.getKey(), en.getValue().keySet());
}
CollectionValuedMap<E, CandidatePhrase> negWords = new CollectionValuedMap<>();
for (Entry<E, ClassicCounter<CandidatePhrase>> en : negPatternsandWords4Label.entrySet()) {
negWords.addAll(en.getKey(), en.getValue().keySet());
}
CollectionValuedMap<E, CandidatePhrase> unlabWords = new CollectionValuedMap<>();
for (Entry<E, ClassicCounter<CandidatePhrase>> en : unLabeledPatternsandWords4Label.entrySet()) {
unlabWords.addAll(en.getKey(), en.getValue().keySet());
}
if (constVars.outDir != null) {
String outputdir = constVars.outDir + "/" + constVars.identifier + "/" + label;
Redwood.log(ConstantsAndVariables.minimaldebug, "Saving output in " + outputdir);
IOUtils.ensureDir(new File(outputdir));
String filename = outputdir + "/patterns" + ".json";
JsonArrayBuilder obj = Json.createArrayBuilder();
if (writtenPatInJustification.containsKey(label) && writtenPatInJustification.get(label)) {
JsonReader jsonReader = Json.createReader(new BufferedInputStream(new FileInputStream(filename)));
JsonArray objarr = jsonReader.readArray();
jsonReader.close();
for (JsonValue o : objarr) obj.add(o);
} else
obj = Json.createArrayBuilder();
JsonObjectBuilder objThisIter = Json.createObjectBuilder();
for (Pair<E, Double> pat : chosenPatSorted) {
JsonObjectBuilder o = Json.createObjectBuilder();
JsonArrayBuilder pos = Json.createArrayBuilder();
JsonArrayBuilder neg = Json.createArrayBuilder();
JsonArrayBuilder unlab = Json.createArrayBuilder();
for (CandidatePhrase w : posWords.get(pat.first())) pos.add(w.getPhrase());
for (CandidatePhrase w : negWords.get(pat.first())) neg.add(w.getPhrase());
for (CandidatePhrase w : unlabWords.get(pat.first())) unlab.add(w.getPhrase());
o.add("Positive", pos);
o.add("Negative", neg);
o.add("Unlabeled", unlab);
o.add("Score", pat.second());
objThisIter.add(pat.first().toStringSimple(), o);
}
obj.add(objThisIter.build());
IOUtils.ensureDir(new File(filename).getParentFile());
IOUtils.writeStringToFile(StringUtils.normalize(StringUtils.toAscii(obj.build().toString())), filename, "ASCII");
writtenPatInJustification.put(label, true);
}
}
if (constVars.justify) {
Redwood.log(Redwood.DBG, "Justification for Patterns:");
for (E key : chosenPat.keySet()) {
Redwood.log(Redwood.DBG, "\nPattern: " + key);
Redwood.log(Redwood.DBG, "Positive Words:" + Counters.toSortedString(patternsandWords4Label.getCounter(key), patternsandWords4Label.getCounter(key).size(), "%1$s:%2$f", ";"));
Redwood.log(Redwood.DBG, "Negative Words:" + Counters.toSortedString(negPatternsandWords4Label.getCounter(key), negPatternsandWords4Label.getCounter(key).size(), "%1$s:%2$f", ";"));
Redwood.log(Redwood.DBG, "Unlabeled Words:" + Counters.toSortedString(unLabeledPatternsandWords4Label.getCounter(key), unLabeledPatternsandWords4Label.getCounter(key).size(), "%1$s:%2$f", ";"));
}
}
//allPatternsandWords.put(label, allPatternsandWords4Label);
patternsandWords.put(label, patternsandWords4Label);
currentPatternWeights.put(label, currentPatternWeights4Label);
return chosenPat;
}
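One step above is easy to miss: after enforceMinSupportRequirements, whole rows are dropped from the three pattern tables with Counters.removeKeys. A small sketch of that row-level filtering (hypothetical support threshold of 2 distinct words):

import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import java.util.HashSet;
import java.util.Set;

public class MinSupportSketch {
  public static void main(String[] args) {
    TwoDimensionalCounter<String, String> patternsAndWords = new TwoDimensionalCounter<>();
    patternsAndWords.incrementCount("pat1", "w1");
    patternsAndWords.incrementCount("pat1", "w2");
    patternsAndWords.incrementCount("pat2", "w1");
    // Collect patterns supported by fewer than 2 distinct words...
    Set<String> removePats = new HashSet<>();
    for (String pat : patternsAndWords.firstKeySet()) {
      if (patternsAndWords.getCounter(pat).size() < 2) {
        removePats.add(pat);
      }
    }
    // ...and drop their rows wholesale, as getPatterns does for
    // patternsandWords4Label and its negative/unlabeled counterparts.
    Counters.removeKeys(patternsAndWords, removePats);
    System.out.println(patternsAndWords.firstKeySet()); // [pat1]
  }
}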
Use of edu.stanford.nlp.stats.TwoDimensionalCounter in project CoreNLP by stanfordnlp: the class GetPatternsFromDataMultiClass, method setUpConstructor.
@SuppressWarnings("rawtypes")
private void setUpConstructor(Map<String, DataInstance> sents, Map<String, Set<CandidatePhrase>> seedSets, boolean labelUsingSeedSets, Map<String, Class<? extends TypesafeMap.Key<String>>> answerClass, Map<String, Class> generalizeClasses, Map<String, Map<Class, Object>> ignoreClasses) throws IOException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException, InterruptedException, ExecutionException, ClassNotFoundException {
Data.sents = sents;
ArgumentParser.fillOptions(Data.class, props);
ArgumentParser.fillOptions(ConstantsAndVariables.class, props);
PatternFactory.setUp(props, PatternFactory.PatternType.valueOf(props.getProperty(Flags.patternType)), seedSets.keySet());
constVars = new ConstantsAndVariables(props, seedSets, answerClass, generalizeClasses, ignoreClasses);
if (constVars.writeMatchedTokensFiles && constVars.batchProcessSents) {
throw new RuntimeException("writeMatchedTokensFiles and batchProcessSents cannot be true at the same time (not implemented; also doesn't make sense to save a large sentences json file)");
}
if (constVars.debug < 1) {
Redwood.hideChannelsEverywhere(ConstantsAndVariables.minimaldebug);
}
if (constVars.debug < 2) {
Redwood.hideChannelsEverywhere(Redwood.DBG);
}
constVars.justify = true;
if (constVars.debug < 3) {
constVars.justify = false;
}
if (constVars.debug < 4) {
Redwood.hideChannelsEverywhere(ConstantsAndVariables.extremedebug);
}
Redwood.log(Redwood.DBG, "Running with debug output");
Redwood.log(ConstantsAndVariables.extremedebug, "Running with extreme debug output");
wordsPatExtracted = new HashMap<>();
for (String label : answerClass.keySet()) {
wordsPatExtracted.put(label, new TwoDimensionalCounter<>());
}
scorePhrases = new ScorePhrases(props, constVars);
createPats = new CreatePatterns(props, constVars);
assert !(constVars.doNotApplyPatterns && (PatternFactory.useStopWordsBeforeTerm || PatternFactory.numWordsCompoundMax > 1)) : " Cannot have both doNotApplyPatterns and (useStopWordsBeforeTerm true or numWordsCompound > 1)!";
if (constVars.invertedIndexDirectory == null) {
File f = File.createTempFile("inv", "index");
f.deleteOnExit();
f.mkdir();
constVars.invertedIndexDirectory = f.getAbsolutePath();
}
Set<String> extremelySmallStopWordsList = CollectionUtils.asSet(".", ",", "in", "on", "of", "a", "the", "an");
//Function specifying how to add CoreLabels to the index
Function<CoreLabel, Map<String, String>> transformCoreLabelToString = l -> {
Map<String, String> add = new HashMap<>();
for (Class gn : constVars.getGeneralizeClasses().values()) {
Object b = l.get(gn);
if (b != null && !b.toString().equals(constVars.backgroundSymbol)) {
add.put(Token.getKeyForClass(gn), b.toString());
}
}
return add;
};
boolean createIndex = false;
if (constVars.loadInvertedIndex)
constVars.invertedIndex = SentenceIndex.loadIndex(constVars.invertedIndexClass, props, extremelySmallStopWordsList, constVars.invertedIndexDirectory, transformCoreLabelToString);
else {
constVars.invertedIndex = SentenceIndex.createIndex(constVars.invertedIndexClass, null, props, extremelySmallStopWordsList, constVars.invertedIndexDirectory, transformCoreLabelToString);
createIndex = true;
}
int totalNumSents = 0;
boolean computeDataFreq = false;
if (Data.rawFreq == null) {
Data.rawFreq = new ClassicCounter<>();
computeDataFreq = true;
}
ConstantsAndVariables.DataSentsIterator iter = new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents);
while (iter.hasNext()) {
Pair<Map<String, DataInstance>, File> sentsIter = iter.next();
Map<String, DataInstance> sentsf = sentsIter.first();
if (constVars.batchProcessSents) {
for (Entry<String, DataInstance> en : sentsf.entrySet()) {
Data.sentId2File.put(en.getKey(), sentsIter.second());
}
}
totalNumSents += sentsf.size();
if (computeDataFreq) {
Data.computeRawFreqIfNull(sentsf, PatternFactory.numWordsCompoundMax);
}
Redwood.log(Redwood.DBG, "Initializing sents size " + sentsf.size() + " sentences, either by labeling with the seed set or just setting the right classes");
for (String l : constVars.getAnswerClass().keySet()) {
Redwood.log(Redwood.DBG, "labelUsingSeedSets is " + labelUsingSeedSets + " and seed set size for " + l + " is " + (seedSets == null ? "null" : seedSets.get(l).size()));
Set<CandidatePhrase> seed = seedSets == null || !labelUsingSeedSets ? new HashSet<>() : (seedSets.containsKey(l) ? seedSets.get(l) : new HashSet<>());
if (!matchedSeedWords.containsKey(l)) {
matchedSeedWords.put(l, new ClassicCounter<>());
}
Counter<CandidatePhrase> matched = runLabelSeedWords(sentsf, constVars.getAnswerClass().get(l), l, seed, constVars, labelUsingSeedSets);
System.out.println("matched phrases for " + l + " is " + matched);
matchedSeedWords.get(l).addAll(matched);
if (constVars.addIndvWordsFromPhrasesExceptLastAsNeg) {
Redwood.log(ConstantsAndVariables.minimaldebug, "adding indv words from phrases except last as neg");
Set<CandidatePhrase> otherseed = new HashSet<>();
if (labelUsingSeedSets) {
for (CandidatePhrase s : seed) {
String[] t = s.getPhrase().split("\\s+");
for (int i = 0; i < t.length - 1; i++) {
if (!seed.contains(t[i])) {
otherseed.add(CandidatePhrase.createOrGet(t[i]));
}
}
}
}
runLabelSeedWords(sentsf, PatternsAnnotations.OtherSemanticLabel.class, "OTHERSEM", otherseed, constVars, labelUsingSeedSets);
}
}
if (labelUsingSeedSets && constVars.getOtherSemanticClassesWords() != null) {
String l = "OTHERSEM";
if (!matchedSeedWords.containsKey(l)) {
matchedSeedWords.put(l, new ClassicCounter<>());
}
matchedSeedWords.get(l).addAll(runLabelSeedWords(sentsf, PatternsAnnotations.OtherSemanticLabel.class, l, constVars.getOtherSemanticClassesWords(), constVars, labelUsingSeedSets));
}
if (constVars.removeOverLappingLabelsFromSeed) {
removeOverLappingLabels(sentsf);
}
if (createIndex)
constVars.invertedIndex.add(sentsf, true);
if (sentsIter.second().exists()) {
Redwood.log(Redwood.DBG, "Saving the labeled seed sents (if given the option) to the same file " + sentsIter.second());
IOUtils.writeObjectToFile(sentsf, sentsIter.second());
}
}
Redwood.log(Redwood.DBG, "Done loading/creating inverted index of tokens and labeling data with total of " + constVars.invertedIndex.size() + " sentences");
//If the scorer class is ScorePhrasesAverageFeatures, the individual word class is added as a feature
if (scorePhrases.phraseScorerClass.equals(ScorePhrasesAverageFeatures.class) && (constVars.usePatternEvalWordClass || constVars.usePhraseEvalWordClass)) {
if (constVars.externalFeatureWeightsDir == null) {
File f = File.createTempFile("tempfeat", ".txt");
f.delete();
f.deleteOnExit();
constVars.externalFeatureWeightsDir = f.getAbsolutePath();
}
IOUtils.ensureDir(new File(constVars.externalFeatureWeightsDir));
for (String label : seedSets.keySet()) {
String externalFeatureWeightsFileLabel = constVars.externalFeatureWeightsDir + "/" + label;
File f = new File(externalFeatureWeightsFileLabel);
if (!f.exists()) {
Redwood.log(Redwood.DBG, "externalweightsfile for the label " + label + " does not exist: learning weights!");
LearnImportantFeatures lmf = new LearnImportantFeatures();
ArgumentParser.fillOptions(lmf, props);
lmf.answerClass = answerClass.get(label);
lmf.answerLabel = label;
lmf.setUp();
lmf.getTopFeatures(new ConstantsAndVariables.DataSentsIterator(constVars.batchProcessSents), constVars.perSelectRand, constVars.perSelectNeg, externalFeatureWeightsFileLabel);
}
Counter<Integer> distSimWeightsLabel = new ClassicCounter<>();
for (String line : IOUtils.readLines(externalFeatureWeightsFileLabel)) {
String[] t = line.split(":");
if (!t[0].startsWith("Cluster"))
continue;
String s = t[0].replace("Cluster-", "");
Integer clusterNum = Integer.parseInt(s);
distSimWeightsLabel.setCount(clusterNum, Double.parseDouble(t[1]));
}
constVars.distSimWeights.put(label, distSimWeightsLabel);
}
}
// computing semantic odds values
if (constVars.usePatternEvalSemanticOdds || constVars.usePhraseEvalSemanticOdds) {
Counter<CandidatePhrase> dictOddsWeightsLabel = new ClassicCounter<>();
Counter<CandidatePhrase> otherSemanticClassFreq = new ClassicCounter<>();
for (CandidatePhrase s : constVars.getOtherSemanticClassesWords()) {
for (String s1 : StringUtils.getNgrams(Arrays.asList(s.getPhrase().split("\\s+")), 1, PatternFactory.numWordsCompoundMax)) otherSemanticClassFreq.incrementCount(CandidatePhrase.createOrGet(s1));
}
otherSemanticClassFreq = Counters.add(otherSemanticClassFreq, 1.0);
// otherSemanticClassFreq.setDefaultReturnValue(1.0);
Map<String, Counter<CandidatePhrase>> labelDictNgram = new HashMap<>();
for (String label : seedSets.keySet()) {
Counter<CandidatePhrase> classFreq = new ClassicCounter<>();
for (CandidatePhrase s : seedSets.get(label)) {
for (String s1 : StringUtils.getNgrams(Arrays.asList(s.getPhrase().split("\\s+")), 1, PatternFactory.numWordsCompoundMax)) classFreq.incrementCount(CandidatePhrase.createOrGet(s1));
}
classFreq = Counters.add(classFreq, 1.0);
labelDictNgram.put(label, classFreq);
// classFreq.setDefaultReturnValue(1.0);
}
for (String label : seedSets.keySet()) {
Counter<CandidatePhrase> otherLabelFreq = new ClassicCounter<>();
for (String label2 : seedSets.keySet()) {
if (label.equals(label2))
continue;
otherLabelFreq.addAll(labelDictNgram.get(label2));
}
otherLabelFreq.addAll(otherSemanticClassFreq);
dictOddsWeightsLabel = Counters.divisionNonNaN(labelDictNgram.get(label), otherLabelFreq);
constVars.dictOddsWeights.put(label, dictOddsWeightsLabel);
}
}
//Redwood.log(Redwood.DBG, "All options are:" + "\n" + Maps.toString(getAllOptions(), "","","\t","\n"));
}
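The semantic-odds computation at the end is a plain smoothed ratio of n-gram frequencies: add-one smoothing on both counters, then element-wise division. A compact sketch with made-up frequencies:

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;

public class DictOddsSketch {
  public static void main(String[] args) {
    // n-gram frequencies inside one label's seed dictionary...
    Counter<String> classFreq = new ClassicCounter<>();
    classFreq.incrementCount("acid", 5.0);
    classFreq.incrementCount("salt", 1.0);
    // ...and in the competing dictionaries (other labels plus OTHERSEM).
    Counter<String> otherFreq = new ClassicCounter<>();
    otherFreq.incrementCount("salt", 4.0);
    otherFreq.incrementCount("acid", 1.0);
    // Add-one smoothing, then element-wise division, mirroring the
    // dictOddsWeights computation in setUpConstructor.
    classFreq = Counters.add(classFreq, 1.0);
    otherFreq = Counters.add(otherFreq, 1.0);
    Counter<String> odds = Counters.divisionNonNaN(classFreq, otherFreq);
    System.out.println(odds); // acid scores high, salt scores low
  }
}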