Use of edu.stanford.nlp.sequences.ExactBestSequenceFinder in project CoreNLP by stanfordnlp.
The class CMMClassifier, method classifySeq.
/**
 * Tags a document with the best label sequence found by sequence inference
 * (exact search when {@code flags.useViterbi} is set, beam search otherwise)
 * and stores each label in the token's {@link CoreAnnotations.AnswerAnnotation}.
 *
 * <p>When {@code flags.lowerNewgeneThreshold} is set, every maximal run of
 * tokens whose {@link CoreAnnotations.GazAnnotation} is {@code "NEWGENE"} is
 * additionally rescored, and each sub-span whose score reaches
 * {@code flags.newgeneThreshold} and strictly beats its one-token-larger and
 * one-token-smaller neighbours is printed to standard output as
 * {@code docId|startPos endPos|words}. This pass does not alter the tags that
 * are written back to the document.
 *
 * <p>When {@code flags.justify} is set and the classifier is a
 * {@link LinearClassifier}, a justification is requested for every position.
 * Finally, when {@code flags.useReverse} is set, the document list is reversed
 * in place before returning.
 *
 * @param document A List of {@link CoreLabel}s to be classified; an empty
 *                 list is left untouched
 */
private void classifySeq(List<IN> document) {
  if (document.isEmpty()) {
    return;
  }

  SequenceModel model = getSequenceModel(document);

  // Pick the inference procedure: exact search or beam search.
  final int[] tags;
  if (!flags.useViterbi) {
    BeamBestSequenceFinder beamFinder = new BeamBestSequenceFinder(flags.beamSize, true, true);
    tags = beamFinder.bestSequence(model, document.size());
  } else {
    ExactBestSequenceFinder exactFinder = new ExactBestSequenceFinder();
    tags = exactFinder.bestSequence(model);
  }

  // used to improve recall in task 1b
  if (flags.lowerNewgeneThreshold) {
    log.info("Using NEWGENE threshold: " + flags.newgeneThreshold);

    // Scratch copy of the tag sequence that gets overwritten while rescoring.
    int[] scratch = tags.clone();

    int geneTag = classIndex.indexOf("G");
    int backgroundTag = classIndex.indexOf(flags.backgroundSymbol);

    for (int i = 0, docLen = document.size(); i < docLen; i++) {
      CoreLabel token = document.get(i);
      if (!"NEWGENE".equals(token.get(CoreAnnotations.GazAnnotation.class))) {
        continue;
      }

      // [start, end) is the maximal run of gazette-NEWGENE tokens beginning at i.
      int start = i;
      int end = start;
      while (end < document.size()) {
        CoreLabel candidate = document.get(end);
        if (!"NEWGENE".equals(candidate.get(CoreAnnotations.GazAnnotation.class))) {
          break;
        }
        end++;
      }

      // Blank out a window of up to four tokens on either side of the run.
      int winStart = Math.max(0, start - 4);
      int winEnd = Math.min(tags.length, end + 4);
      for (int w = winStart; w < winEnd; w++) {
        scratch[w] = backgroundTag;
      }

      // Baseline: summed recentered background score with the whole run tagged as background.
      double bgScore = 0.0;
      for (int p = start; p < end; p++) {
        double[] recentered = Scorer.recenter(model.scoresOf(scratch, p));
        bgScore += recentered[backgroundTag];
      }

      // First pass: score every sub-span [left, right] of the run, relative to the baseline.
      ClassicCounter<Pair<Integer, Integer>> spanScores = new ClassicCounter<>();
      for (int left = start; left < end; left++) {
        // reset the run to background before growing a new span from 'left'
        for (int p = start; p < end; p++) {
          scratch[p] = backgroundTag;
        }
        for (int right = left; right < end; right++) {
          scratch[right] = geneTag;
          double ngScore = 0.0;
          for (int p = start; p < end; p++) {
            double[] recentered = Scorer.recenter(model.scoresOf(scratch, p));
            // NOTE(review): indexed by the originally decoded tag, not by scratch[p] - confirm this is intended
            ngScore += recentered[tags[p]];
          }
          Pair<Integer, Integer> span = new Pair<>(left, right);
          spanScores.incrementCount(span, ngScore - bgScore);
        }
      }

      // Second pass: report spans that clear the threshold and beat all four neighbouring spans.
      for (int left = start; left < end; left++) {
        for (int right = left; right < end; right++) {
          Pair<Integer, Integer> span = new Pair<>(left, right);
          double score = spanScores.getCount(span);

          // neighbours reached by adding or removing one word on either side
          Pair<Integer, Integer> growLeft = new Pair<>(left - 1, right);
          Pair<Integer, Integer> growRight = new Pair<>(left, right + 1);
          Pair<Integer, Integer> shrinkLeft = new Pair<>(left + 1, right);
          Pair<Integer, Integer> shrinkRight = new Pair<>(left, right - 1);

          if (!(score >= flags.newgeneThreshold)) {
            continue;
          }
          // a neighbour that was never scored cannot disqualify the span
          if (spanScores.containsKey(growLeft) && !(score > spanScores.getCount(growLeft))) {
            continue;
          }
          if (spanScores.containsKey(growRight) && !(score > spanScores.getCount(growRight))) {
            continue;
          }
          if (spanScores.containsKey(shrinkLeft) && !(score > spanScores.getCount(shrinkLeft))) {
            continue;
          }
          if (spanScores.containsKey(shrinkRight) && !(score > spanScores.getCount(shrinkRight))) {
            continue;
          }

          StringBuilder words = new StringBuilder();
          CoreLabel first = document.get(left);
          String docId = first.get(CoreAnnotations.IDAnnotation.class);
          String startIndex = first.get(CoreAnnotations.PositionAnnotation.class);
          CoreLabel last = document.get(right);
          String endIndex = last.get(CoreAnnotations.PositionAnnotation.class);
          for (int p = left; p <= right; p++) {
            CoreLabel member = document.get(p);
            words.append(member.word());
            words.append(' ');
          }
          System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + words.toString().trim());
        }
      }

      // restore the original tags inside the window
      for (int w = winStart; w < winEnd; w++) {
        scratch[w] = tags[w];
      }

      // resume after the run; the loop increment then steps past index 'end'
      i = end;
    }
  }

  // Write the decoded label of every position back onto its token.
  for (int pos = 0, docLen = document.size(); pos < docLen; pos++) {
    CoreLabel token = document.get(pos);
    String label = classIndex.get(tags[pos]);
    token.set(CoreAnnotations.AnswerAnnotation.class, label);
  }

  if (flags.justify && classifier instanceof LinearClassifier) {
    LinearClassifier<String, String> linear = (LinearClassifier<String, String>) classifier;
    if (flags.dump) {
      linear.dump();
    }
    for (int pos = 0, docLen = document.size(); pos < docLen; pos++) {
      CoreLabel token = document.get(pos);
      log.info("@@ Position is: " + pos + ": ");
      log.info(token.word() + ' ' + token.get(CoreAnnotations.AnswerAnnotation.class));
      linear.justificationOf(makeDatum(document, pos, featureFactories));
    }
  }

  if (flags.useReverse) {
    Collections.reverse(document);
  }
}
Use of edu.stanford.nlp.sequences.ExactBestSequenceFinder in project CoreNLP by stanfordnlp.
The class TestSentence, method runTagInference.
/**
 * Runs exact best-sequence inference with this object as the sequence model
 * and stores the resulting tag strings in {@code finalTags}.
 *
 * <p>The scorer is initialized first and cleaned up last. The thread's
 * interrupt status is polled (and cleared) once before and once after
 * inference; if it was set, a {@link RuntimeInterruptedException} is thrown
 * and the remaining steps, including {@code cleanUpScorer()}, are not run.
 *
 * <p>{@code finalTags} is allocated with the length of the decoded sequence,
 * but only its first {@code size} entries are filled; each is looked up from
 * the decoded tag at the same index shifted by {@code leftWindow()}.
 */
private void runTagInference() {
  this.initializeScorer();

  // Allow interrupting
  if (Thread.interrupted()) {
    throw new RuntimeInterruptedException();
  }

  BestSequenceFinder finder = new ExactBestSequenceFinder();
  int[] decoded = finder.bestSequence(this);

  finalTags = new String[decoded.length];
  int pos = 0;
  while (pos < size) {
    // decoded positions are offset by the left window
    int tagId = decoded[pos + leftWindow()];
    finalTags[pos] = maxentTagger.tags.getTag(tagId);
    pos++;
  }

  // Allow interrupting
  if (Thread.interrupted()) {
    throw new RuntimeInterruptedException();
  }

  cleanUpScorer();
}
Use of edu.stanford.nlp.sequences.ExactBestSequenceFinder in project CoreNLP by stanfordnlp.
The class CRFClassifierITest, method runKBestTest.
/**
 * Checks that k-best decoding is self-consistent as k grows from 1 to 12,
 * first at the sequence-model level ({@link KBestSequenceFinder}) and then
 * through {@code crf.classifyKBest}.
 *
 * <p>For every k the result must contain exactly k entries. For k &gt; 1 the
 * first k-1 entries must equal the previous (k-1)-best list, the new entry's
 * score must not exceed the score of the entry before it, and the new entry
 * must differ from every earlier one. For k = 1 the single result must match
 * the 1-best answer ({@link ExactBestSequenceFinder} and {@code crf.classify}
 * respectively).
 *
 * @param crf            the classifier under test
 * @param str            the input sentence; tokens are obtained by splitting on single spaces
 * @param isStoredAnswer when true, the k-th labelling and its score are also
 *                       compared against the {@code iobesAnswers} and
 *                       {@code scores} arrays (scores within 1e-8)
 */
private static void runKBestTest(CRFClassifier<CoreLabel> crf, String str, boolean isStoredAnswer) {
  final int K_BEST = 12;
  String[] txt = str.split(" ");
  List<CoreLabel> input = SentenceUtils.toCoreLabelList(txt);

  // do the ugliness that the CRFClassifier routines do to augment the input
  ObjectBankWrapper<CoreLabel> obw = new ObjectBankWrapper<>(crf.flags, null, crf.getKnownLCWords());
  List<CoreLabel> input2 = obw.processDocument(input);
  SequenceModel sequenceModel = crf.getSequenceModel(input2);

  // Part 1: k-best directly on the sequence model.
  List<Pair<CRFLabel, Double>> kBestSequencesLast = null;
  for (int k = 1; k <= K_BEST; k++) {
    Counter<int[]> kBest = new KBestSequenceFinder().kBestSequences(sequenceModel, k);
    List<Pair<CRFLabel, Double>> kBestSequences = adapt(kBest);
    assertEquals(k, kBestSequences.size());
    // System.out.printf("k=%2d %s%n", k, kBestSequences);
    if (kBestSequencesLast != null) {
      // The rest of the list is the same
      assertEquals("k=" + k, kBestSequencesLast, kBestSequences.subList(0, k - 1));
      // New item is lower score
      assertTrue(kBestSequences.get(k - 1).second() <= kBestSequences.get(k - 2).second());
      for (int m = 0; m < (k - 1); m++) {
        // New item is different
        assertFalse(kBestSequences.get(k - 1).first().equals(kBestSequences.get(m).first()));
      }
    } else {
      // k == 1: the only sequence must be the exact best sequence
      int[] bestSequence = new ExactBestSequenceFinder().bestSequence(sequenceModel);
      int[] best1 = new ArrayList<>(kBest.keySet()).get(0);
      assertTrue(Arrays.equals(bestSequence, best1));
    }
    kBestSequencesLast = kBestSequences;
  }

  // Part 2: the same properties through the classifier's public k-best API.
  List<Pair<List<String>, Double>> lastAnswer = null;
  for (int k = 1; k <= K_BEST; k++) {
    Counter<List<CoreLabel>> out = crf.classifyKBest(input, CoreAnnotations.AnswerAnnotation.class, k);
    assertEquals(k, out.size());
    List<Pair<List<CoreLabel>, Double>> beam = Counters.toSortedListWithCounts(out);
    List<Pair<List<String>, Double>> beam2 = adapt2(beam);
    // System.out.printf("k=%2d %s%n", k, beam2);
    if (isStoredAnswer) {
      // done for a particular sequence model at one point
      // Stored values go first: JUnit's assertEquals takes (expected, actual),
      // so a failure message now names the stored answer as the expectation.
      assertEquals(iobesAnswers[k - 1], beam2.get(k - 1).first().toString());
      assertEquals(scores[k - 1], beam2.get(k - 1).second(), 1e-8);
    }
    if (lastAnswer != null) {
      // The rest of the list is the same
      assertEquals("k=" + k, lastAnswer, beam2.subList(0, k - 1));
      // New item is lower score
      assertTrue(beam2.get(k - 1).second() <= beam2.get(k - 2).second());
      for (int m = 0; m < (k - 1); m++) {
        // New item is different
        assertFalse(beam2.get(k - 1).first().equals(beam2.get(m).first()));
      }
    } else {
      // k == 1: the top entry must equal the plain classify() output
      List<CoreLabel> best = crf.classify(input);
      assertEquals(best, beam.get(0).first());
    }
    lastAnswer = beam2;
  }
}
Aggregations