Use of edu.stanford.nlp.sequences.BeamBestSequenceFinder in project CoreNLP by stanfordnlp.
The example below is the method classifySeq of the class CMMClassifier.
/**
 * Classify a List of {@link CoreLabel}s using sequence information
 * (i.e. Viterbi or Beam Search).
 *
 * <p>The best tag sequence is decoded from the model returned by
 * {@code getSequenceModel(document)} and the corresponding class label is
 * stored on every token under {@link CoreAnnotations.AnswerAnnotation}.
 * An empty document is returned untouched.
 *
 * <p>Optional steps, all driven by {@code flags}:
 * <ul>
 *   <li>{@code lowerNewgeneThreshold}: gazette NEWGENE runs are rescored and
 *       qualifying spans are printed to standard output. This step only
 *       reports; it never changes the decoded tags.</li>
 *   <li>{@code justify} (only when the classifier is a
 *       {@link LinearClassifier}): a justification is produced per token,
 *       preceded by a dump of the classifier when {@code dump} is set.</li>
 *   <li>{@code useReverse}: the document list is reversed in place at the
 *       end (presumably undoing a reversal done while building the sequence
 *       model; confirm against {@code getSequenceModel}).</li>
 * </ul>
 *
 * @param document A List of {@link CoreLabel}s to be classified; modified in place
 */
private void classifySeq(List<IN> document) {
  if (document.isEmpty()) {
    return;
  }

  SequenceModel ts = getSequenceModel(document);
  int[] tags = findBestTagSequence(ts, document.size());

  // used to improve recall in task 1b
  if (flags.lowerNewgeneThreshold) {
    reportNewgeneSpans(document, ts, tags);
  }

  for (int i = 0, docSize = document.size(); i < docSize; i++) {
    CoreLabel lineInfo = document.get(i);
    String answer = classIndex.get(tags[i]);
    lineInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
  }

  if (flags.justify && classifier instanceof LinearClassifier) {
    LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
    if (flags.dump) {
      lc.dump();
    }
    for (int i = 0, docSize = document.size(); i < docSize; i++) {
      CoreLabel lineInfo = document.get(i);
      log.info("@@ Position is: " + i + ": ");
      log.info(lineInfo.word() + ' ' + lineInfo.get(CoreAnnotations.AnswerAnnotation.class));
      lc.justificationOf(makeDatum(document, i, featureFactories));
    }
  }

  if (flags.useReverse) {
    Collections.reverse(document);
  }
}

/** Decodes the best tag sequence, exactly when {@code flags.useViterbi} is set and by beam search otherwise. */
private int[] findBestTagSequence(SequenceModel ts, int size) {
  if (flags.useViterbi) {
    return new ExactBestSequenceFinder().bestSequence(ts);
  }
  return new BeamBestSequenceFinder(flags.beamSize, true, true).bestSequence(ts, size);
}

/**
 * Finds every maximal run of tokens marked NEWGENE in the gazette, scores all
 * sub-spans of the run as gene ("G") against an all-background baseline, and
 * prints the spans that pass the threshold and beat their neighbouring spans.
 * Only a private copy of {@code tags} is modified.
 */
private void reportNewgeneSpans(List<IN> document, SequenceModel ts, int[] tags) {
  log.info("Using NEWGENE threshold: " + flags.newgeneThreshold);

  // hypothetical taggings are tried out on a copy so that tags stays intact
  int[] copy = tags.clone();
  int ngTag = classIndex.indexOf("G");
  int bgTag = classIndex.indexOf(flags.backgroundSymbol);

  int dSize = document.size();
  int i = 0;
  while (i < dSize) {
    if (!isNewgene(document.get(i))) {
      i++;
      continue;
    }

    // [start, end) is the maximal NEWGENE run beginning at i
    int start = i;
    int end = start;
    while (end < dSize && isNewgene(document.get(end))) {
      end++;
    }

    // clear a window of up to 4 tokens on either side of the run
    int winStart = Math.max(0, start - 4);
    int winEnd = Math.min(tags.length, end + 4);
    for (int w = winStart; w < winEnd; w++) {
      copy[w] = bgTag;
    }

    ClassicCounter<Pair<Integer, Integer>> spanScores =
        scoreNewgeneSpans(ts, tags, copy, start, end, ngTag, bgTag);
    printLocallyBestSpans(document, spanScores, start, end);

    // restore the original tags
    for (int w = winStart; w < winEnd; w++) {
      copy[w] = tags[w];
    }

    // position end is either past the document or known not to be NEWGENE
    i = end + 1;
  }
}

/**
 * Scores every span (j, k) with start <= j <= k < end, tagged as gene inside an
 * otherwise background run, relative to the all-background score of the run.
 * Expects the run to already be set to background in {@code copy}; leaves
 * {@code copy} in a modified state that the caller must restore.
 */
private ClassicCounter<Pair<Integer, Integer>> scoreNewgeneSpans(
    SequenceModel ts, int[] tags, int[] copy, int start, int end, int ngTag, int bgTag) {
  // score as nongene
  double bgScore = 0.0;
  for (int m = start; m < end; m++) {
    bgScore += Scorer.recenter(ts.scoresOf(copy, m))[bgTag];
  }

  ClassicCounter<Pair<Integer, Integer>> spanScores = new ClassicCounter<>();
  for (int j = start; j < end; j++) {
    // clear the sequence
    for (int k = start; k < end; k++) {
      copy[k] = bgTag;
    }
    // grow the sequence from j until the end
    for (int k = j; k < end; k++) {
      copy[k] = ngTag;
      // score the sequence
      // NOTE(review): the score is read at the originally decoded tag (tags[m]),
      // not at the hypothesised tag (copy[m]); confirm this is intended.
      double ngScore = 0.0;
      for (int m = start; m < end; m++) {
        ngScore += Scorer.recenter(ts.scoresOf(copy, m))[tags[m]];
      }
      Pair<Integer, Integer> span = new Pair<>(j, k);
      spanScores.incrementCount(span, ngScore - bgScore);
    }
  }
  return spanScores;
}

/**
 * Prints, one per line as {@code docId|startPos endPos|words}, each span whose
 * score reaches {@code flags.newgeneThreshold} and is strictly greater than the
 * score of every scored span that differs by one token at either edge.
 */
private void printLocallyBestSpans(
    List<IN> document, ClassicCounter<Pair<Integer, Integer>> spanScores, int start, int end) {
  for (int j = start; j < end; j++) {
    for (int k = j; k < end; k++) {
      Pair<Integer, Integer> span = new Pair<>(j, k);
      double score = spanScores.getCount(span);

      // make sure the score is greater than all its neighbors (one add or subtract)
      boolean accepted = score >= flags.newgeneThreshold
          && beatsNeighbor(spanScores, score, j - 1, k)   // adding a word to the left
          && beatsNeighbor(spanScores, score, j, k + 1)   // adding a word to the right
          && beatsNeighbor(spanScores, score, j + 1, k)   // subtracting word from left
          && beatsNeighbor(spanScores, score, j, k - 1);  // subtracting word from right
      if (!accepted) {
        continue;
      }

      CoreLabel first = document.get(j);
      CoreLabel last = document.get(k);
      String docId = first.get(CoreAnnotations.IDAnnotation.class);
      String startIndex = first.get(CoreAnnotations.PositionAnnotation.class);
      String endIndex = last.get(CoreAnnotations.PositionAnnotation.class);

      StringBuilder sb = new StringBuilder();
      for (int m = j; m <= k; m++) {
        CoreLabel token = document.get(m);
        sb.append(token.word());
        sb.append(' ');
      }
      System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + sb.toString().trim());
    }
  }
}

/** Returns true when the span (from, to) was never scored or scored strictly below {@code score}. */
private static boolean beatsNeighbor(
    ClassicCounter<Pair<Integer, Integer>> spanScores, double score, int from, int to) {
  Pair<Integer, Integer> neighbor = new Pair<>(from, to);
  return !spanScores.containsKey(neighbor) || score > spanScores.getCount(neighbor);
}

/** Returns true when the token's gazette annotation is exactly "NEWGENE". */
private static boolean isNewgene(CoreLabel token) {
  return "NEWGENE".equals(token.get(CoreAnnotations.GazAnnotation.class));
}
Aggregations