Use of edu.stanford.nlp.util.Triple in project CoreNLP by stanfordnlp.
The class SplittingGrammarExtractor, method mergeStates.
public void mergeStates() {
  if (op.trainOptions.splitRecombineRate <= 0.0) {
    return;
  }
  // we go through the machinery to sum up the temporary betas,
  // counting the total mass
  TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<>();
  ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<>();
  Map<String, double[]> totalStateMass = Generics.newHashMap();
  recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);
  // Next, for each tree we count the effect of merging its
  // annotations. We only consider the most recently split
  // annotations as candidates for merging.
  Map<String, double[]> deltaAnnotations = Generics.newHashMap();
  for (Tree tree : trees) {
    countMergeEffects(tree, totalStateMass, deltaAnnotations);
  }
  // Now we have a map of the (approximate) likelihood loss from
  // merging each state. We merge the ones that provide the least
  // benefit, up to the splitRecombineRate.
  List<Triple<String, Integer, Double>> sortedDeltas = new ArrayList<>();
  for (String state : deltaAnnotations.keySet()) {
    double[] scores = deltaAnnotations.get(state);
    for (int i = 0; i < scores.length; ++i) {
      sortedDeltas.add(new Triple<>(state, i * 2, scores[i]));
    }
  }
  Collections.sort(sortedDeltas, new Comparator<Triple<String, Integer, Double>>() {
    public int compare(Triple<String, Integer, Double> first, Triple<String, Integer, Double> second) {
      // "backwards", sorting from high to low
      return Double.compare(second.third(), first.third());
    }
    public boolean equals(Object o) {
      return o == this;
    }
  });
  // for (Triple<String, Integer, Double> delta : sortedDeltas) {
  //   System.out.println(delta.first() + "-" + delta.second() + ": " + delta.third());
  // }
  // System.out.println("-------------");
  // Only merge a fraction of the splits based on what the user
  // originally asked for.
  int splitsToMerge = (int) (sortedDeltas.size() * op.trainOptions.splitRecombineRate);
  splitsToMerge = Math.max(0, splitsToMerge);
  splitsToMerge = Math.min(sortedDeltas.size() - 1, splitsToMerge);
  sortedDeltas = sortedDeltas.subList(0, splitsToMerge);
  System.out.println();
  System.out.println(sortedDeltas);
  Map<String, int[]> mergeCorrespondence = buildMergeCorrespondence(sortedDeltas);
  recalculateMergedBetas(mergeCorrespondence);
  for (Triple<String, Integer, Double> delta : sortedDeltas) {
    stateSplitCounts.decrementCount(delta.first(), 1);
  }
}
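Triple's role above is a sortable (state, substate, score) record: the class exposes first(), second(), and third() accessors, which is all the comparator needs. A minimal self-contained sketch of that idiom, with an illustrative class name and sample values that are not from CoreNLP:

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import edu.stanford.nlp.util.Triple;

public class TripleSortSketch {
  public static void main(String[] args) {
    // Hypothetical (state, substate, delta) entries, mirroring sortedDeltas above.
    List<Triple<String, Integer, Double>> deltas = new ArrayList<>();
    deltas.add(new Triple<>("NP", 0, 0.25));
    deltas.add(new Triple<>("VP", 2, 0.75));
    deltas.add(new Triple<>("PP", 0, 0.10));
    // Sort from high to low on the third element, as mergeStates does.
    deltas.sort(Comparator.comparing((Triple<String, Integer, Double> t) -> t.third()).reversed());
    for (Triple<String, Integer, Double> t : deltas) {
      System.out.println(t.first() + "-" + t.second() + ": " + t.third());
    }
  }
}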
Use of edu.stanford.nlp.util.Triple in project CoreNLP by stanfordnlp.
The class ApplyPatternsMulti, method call.
@Override
public Pair<TwoDimensionalCounter<Pair<String, String>, E>, CollectionValuedMap<E, Triple<String, Integer, Integer>>> call() throws Exception {
  CollectionValuedMap<E, Triple<String, Integer, Integer>> matchedTokensByPat = new CollectionValuedMap<>();
  TwoDimensionalCounter<Pair<String, String>, E> allFreq = new TwoDimensionalCounter<>();
  for (String sentid : sentids) {
    List<CoreLabel> sent = sents.get(sentid).getTokens();
    // FIND_ALL is faster than FIND_NONOVERLAP
    Iterable<SequenceMatchResult<CoreMap>> matched = multiPatternMatcher.find(sent, SequenceMatcher.FindType.FIND_ALL);
    for (SequenceMatchResult<CoreMap> m : matched) {
      int s = m.start("$term");
      int e = m.end("$term");
      E matchedPat = patterns.get(m.pattern());
      matchedTokensByPat.add(matchedPat, new Triple<>(sentid, s, e));
      String phrase = "";
      String phraseLemma = "";
      boolean useWordNotLabeled = false;
      boolean doNotUse = false;
      // If the neighboring words are also labeled, club them together into one phrase.
      if (constVars.clubNeighboringLabeledWords) {
        for (int i = s - 1; i >= 0; i--) {
          if (!sent.get(i).get(constVars.getAnswerClass().get(label)).equals(label)) {
            s = i + 1;
            break;
          }
        }
        for (int i = e; i < sent.size(); i++) {
          if (!sent.get(i).get(constVars.getAnswerClass().get(label)).equals(label)) {
            e = i;
            break;
          }
        }
      }
      // Discard phrases with stop words in between, but (when
      // removeStopWordsFromSelectedPhrases is true) keep phrases whose stop
      // words were removed only at the ends.
      boolean[] addedindices = new boolean[e - s];
      Arrays.fill(addedindices, false);
      for (int i = s; i < e; i++) {
        CoreLabel l = sent.get(i);
        l.set(PatternsAnnotations.MatchedPattern.class, true);
        if (!l.containsKey(PatternsAnnotations.MatchedPatterns.class))
          l.set(PatternsAnnotations.MatchedPatterns.class, new HashSet<>());
        l.get(PatternsAnnotations.MatchedPatterns.class).add(matchedPat);
        for (Entry<Class, Object> ig : constVars.getIgnoreWordswithClassesDuringSelection().get(label).entrySet()) {
          if (l.containsKey(ig.getKey()) && l.get(ig.getKey()).equals(ig.getValue())) {
            doNotUse = true;
          }
        }
        boolean containsStop = containsStopWord(l, constVars.getCommonEngWords(), PatternFactory.ignoreWordRegex);
        if (removePhrasesWithStopWords && containsStop) {
          doNotUse = true;
        } else {
          if (!containsStop || !removeStopWordsFromSelectedPhrases) {
            if (label == null || l.get(constVars.getAnswerClass().get(label)) == null || !l.get(constVars.getAnswerClass().get(label)).equals(label.toString())) {
              useWordNotLabeled = true;
            }
            phrase += " " + l.word();
            phraseLemma += " " + l.lemma();
            addedindices[i - s] = true;
          }
        }
      }
      // If a word in the middle was dropped, the phrase is discontiguous; skip it.
      for (int i = 1; i < addedindices.length - 1; i++) {
        if (addedindices[i - 1] && !addedindices[i] && addedindices[i + 1]) {
          doNotUse = true;
          break;
        }
      }
      if (!doNotUse && useWordNotLabeled) {
        phrase = phrase.trim();
        phraseLemma = phraseLemma.trim();
        allFreq.incrementCount(new Pair<>(phrase, phraseLemma), matchedPat, 1.0);
      }
    }
  }
  return new Pair<>(allFreq, matchedTokensByPat);
}
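For reference, the matchedTokensByPat bookkeeping above hinges on CollectionValuedMap collecting several Triple spans under one key. A minimal sketch under simplifying assumptions (string pattern ids stand in for the pattern objects E used in the real code, and the spans are made up):

import edu.stanford.nlp.util.CollectionValuedMap;
import edu.stanford.nlp.util.Triple;

public class MatchRecordSketch {
  public static void main(String[] args) {
    // One pattern can match several (sentence id, start, end) spans.
    CollectionValuedMap<String, Triple<String, Integer, Integer>> matches = new CollectionValuedMap<>();
    matches.add("PATTERN_1", new Triple<>("sent-42", 3, 5));
    matches.add("PATTERN_1", new Triple<>("sent-87", 0, 2));
    for (Triple<String, Integer, Integer> span : matches.get("PATTERN_1")) {
      System.out.println(span.first() + " [" + span.second() + ", " + span.third() + ")");
    }
  }
}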
Use of edu.stanford.nlp.util.Triple in project CoreNLP by stanfordnlp.
The class ChineseSimWordAvgDepGrammar, method getMap.
public Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> getMap(String filename) {
  Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> hashMap = Generics.newHashMap();
  // try-with-resources so the reader is closed even on early exit
  try (BufferedReader wordMapBReader = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "UTF-8"))) {
    String wordMapLine;
    Pattern linePattern = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)");
    while ((wordMapLine = wordMapBReader.readLine()) != null) {
      Matcher m = linePattern.matcher(wordMapLine);
      if (!m.matches()) {
        log.info("Ill-formed line in similar word map file: " + wordMapLine);
        continue;
      }
      Pair<Integer, String> iTW = new Pair<>(wordIndex.addToIndex(m.group(1)), m.group(2));
      double score = Double.parseDouble(m.group(5));
      List<Triple<Integer, String, Double>> tripleList = hashMap.get(iTW);
      if (tripleList == null) {
        tripleList = new ArrayList<>();
        hashMap.put(iTW, tripleList);
      }
      tripleList.add(new Triple<>(wordIndex.addToIndex(m.group(3)), m.group(4), score));
    }
  } catch (IOException e) {
    // chain the cause so the original I/O error is not lost
    throw new RuntimeException("Problem reading similar words file!", e);
  }
  return hashMap;
}
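To make the parsing concrete, here is a minimal sketch of one line in the sim(...)=score format that getMap reads. The sample line is hypothetical, plain strings stand in for the wordIndex integer ids, and which slash-delimited field is the word versus the tag is an assumption:

import java.util.regex.Matcher;
import java.util.regex.Pattern;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;

public class SimLineSketch {
  public static void main(String[] args) {
    String line = "sim(dog/NN:cat/NN)=0.83"; // hypothetical input line
    Pattern linePattern = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)");
    Matcher m = linePattern.matcher(line);
    if (m.matches()) {
      // Key: the first word/tag pair; value: the similar word/tag plus its score.
      Pair<String, String> key = new Pair<>(m.group(1), m.group(2));
      Triple<String, String, Double> similar = new Triple<>(m.group(3), m.group(4), Double.parseDouble(m.group(5)));
      System.out.println(key + " -> " + similar);
    }
  }
}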
Use of edu.stanford.nlp.util.Triple in project CoreNLP by stanfordnlp.
The class PerceptronModel, method trainBatch.
/**
 * Trains a batch of trees and returns the following: a list of
 * Update objects, the number of transitions correct, and the number
 * of transitions wrong.
 * <br>
 * If the model is trained with multiple threads, it is expected
 * that a valid MulticoreWrapper that does the processing is passed
 * in. In that case, the processing is done on all of the trees
 * without updating any weights, which allows the results of
 * multithreaded training to be reproduced.
 */
private Triple<List<Update>, Integer, Integer> trainBatch(List<Integer> indices, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle, MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper) {
  int numCorrect = 0;
  int numWrong = 0;
  if (op.trainOptions.trainingThreads == 1) {
    for (Integer index : indices) {
      Pair<Integer, Integer> count = trainTree(index, binarizedTrees, transitionLists, updates, oracle);
      numCorrect += count.first;
      numWrong += count.second;
    }
  } else {
    for (Integer index : indices) {
      wrapper.put(index);
    }
    wrapper.join(false);
    while (wrapper.peek()) {
      Pair<Integer, Integer> result = wrapper.poll();
      numCorrect += result.first;
      numWrong += result.second;
    }
  }
  return new Triple<>(updates, numCorrect, numWrong);
}
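Here Triple is simply a three-field return value: the updates plus the correct/wrong transition counts. A minimal self-contained sketch of that idiom; the names and numbers are illustrative, not part of the PerceptronModel API:

import edu.stanford.nlp.util.Triple;

public class BatchResultSketch {
  // Bundle a payload with two counters, as trainBatch does with (updates, numCorrect, numWrong).
  static Triple<String, Integer, Integer> score(String payload, int correct, int wrong) {
    return new Triple<>(payload, correct, wrong);
  }

  public static void main(String[] args) {
    Triple<String, Integer, Integer> result = score("batch-1", 90, 10);
    double accuracy = (double) result.second() / (result.second() + result.third());
    System.out.println(result.first() + " accuracy: " + accuracy);
  }
}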
Use of edu.stanford.nlp.util.Triple in project CoreNLP by stanfordnlp.
The class CRFClassifierITest, method runCRFTest.
private static void runCRFTest(CRFClassifier<CoreLabel> crf) {
  for (int i = 0; i < testTexts.length; i++) {
    String[] testText = testTexts[i];
    assertEquals(i + ": Wrong array size in test", 7, testText.length);
    // System.err.println("length of string is " + testText[0].length());
    String out;
    out = crf.classifyToString(testText[0]);
    assertEquals(i + ": CRF buggy on classifyToString", testText[1], out);
    out = crf.classifyWithInlineXML(testText[0]);
    assertEquals(i + ": CRF buggy on classifyWithInlineXML", testText[2], out);
    out = crf.classifyToString(testText[0], "xml", false).replaceAll("\r", "");
    assertEquals(i + ": CRF buggy on classifyToString(xml, false)", testText[3], out);
    out = crf.classifyToString(testText[0], "xml", true);
    assertEquals(i + ": CRF buggy on classifyToString(xml, true)", testText[4], out);
    out = crf.classifyToString(testText[0], "slashTags", false).replaceAll("\r", "");
    // System.out.println("Gold: |" + testText[5] + "|");
    // System.out.println("Guess: |" + out + "|");
    assertEquals(i + ": CRF buggy on classifyToString(slashTags, false)", testText[5], out);
    out = crf.classifyToString(testText[0], "inlineXML", false).replaceAll("\r", "");
    assertEquals(i + ": CRF buggy on classifyToString(inlineXML, false)", testText[6], out);
    List<Triple<String, Integer, Integer>> trip = crf.classifyToCharacterOffsets(testText[0]);
    // I couldn't work out how to avoid a type warning in the next line, sigh [cdm 2009]
    assertEquals(i + ": CRF buggy on classifyToCharacterOffsets", Arrays.asList(testTrip[i]), trip);
    if (i == 0) {
      // cdm 2013: I forget exactly what this was but something about the reduplicated period at the end of Jr.?
      Triple<String, Integer, Integer> x = trip.get(trip.size() - 1);
      assertEquals("CRF buggy on classifyToCharacterOffsets abbreviation period", 'r', testText[0].charAt(x.third() - 1));
    }
    if (i == 3) {
      // check that tokens have okay offsets
      List<List<CoreLabel>> doc = crf.classify(testText[0]);
      assertEquals("Wrong number of sentences", 1, doc.size());
      List<CoreLabel> tokens = doc.get(0);
      assertEquals("Wrong number of tokens", offsets.length, tokens.size());
      for (int j = 0, sz = tokens.size(); j < sz; j++) {
        CoreLabel token = tokens.get(j);
        assertEquals("Wrong begin offset", offsets[j][0], (int) token.get(CoreAnnotations.CharacterOffsetBeginAnnotation.class));
        assertEquals("Wrong end offset", offsets[j][1], (int) token.get(CoreAnnotations.CharacterOffsetEndAnnotation.class));
      }
    }
  }
}
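The offset checks above rely on classifyToCharacterOffsets returning one Triple of (entity label, begin offset, end offset) per recognized entity, so substring on the second and third elements recovers the entity text. A hedged usage sketch; the model path is an assumption and must point at a serialized classifier you actually have:

import java.util.List;
import edu.stanford.nlp.ie.crf.CRFClassifier;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.util.Triple;

public class OffsetsSketch {
  public static void main(String[] args) throws Exception {
    // Assumed model path; substitute a classifier available on your classpath.
    CRFClassifier<CoreLabel> crf = CRFClassifier.getClassifier("english.all.3class.distsim.crf.ser.gz");
    String text = "John Smith works at Stanford University.";
    for (Triple<String, Integer, Integer> t : crf.classifyToCharacterOffsets(text)) {
      System.out.println(t.first() + ": " + text.substring(t.second(), t.third()));
    }
  }
}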