Use of edu.stanford.nlp.sequences.KBestSequenceFinder in the project CoreNLP by stanfordnlp.
From the class CRFClassifierITest, method runKBestTest.
/**
 * Checks k-best decoding of {@code crf} on {@code str} for every k from 1 to 12, in two ways:
 * first by running {@link KBestSequenceFinder} directly on the classifier's sequence model,
 * then through {@code crf.classifyKBest}.
 *
 * <p>For each k the result must contain exactly k entries. For k = 1 the single entry must
 * agree with the plain best-sequence answer ({@link ExactBestSequenceFinder} in the first pass,
 * {@code crf.classify} in the second). For k &gt; 1 the result must extend the result obtained
 * for k - 1 (see {@link #assertExtendsPrevious}).
 *
 * @param crf the classifier under test
 * @param str the input sentence; it is split into tokens on single spaces
 * @param isStoredAnswer if true, the k-th entry returned by {@code classifyKBest} is also
 *     compared against the stored {@code iobesAnswers} and {@code scores} arrays
 */
private static void runKBestTest(CRFClassifier<CoreLabel> crf, String str, boolean isStoredAnswer) {
  final int K_BEST = 12;
  String[] txt = str.split(" ");
  List<CoreLabel> input = SentenceUtils.toCoreLabelList(txt);

  // do the ugliness that the CRFClassifier routines do to augment the input
  ObjectBankWrapper<CoreLabel> obw = new ObjectBankWrapper<>(crf.flags, null, crf.getKnownLCWords());
  List<CoreLabel> input2 = obw.processDocument(input);
  SequenceModel sequenceModel = crf.getSequenceModel(input2);

  // Pass 1: query the k-best finder directly on the sequence model.
  List<Pair<CRFLabel, Double>> kBestSequencesLast = null;
  for (int k = 1; k <= K_BEST; k++) {
    Counter<int[]> kBest = new KBestSequenceFinder().kBestSequences(sequenceModel, k);
    List<Pair<CRFLabel, Double>> kBestSequences = adapt(kBest);
    assertEquals(k, kBestSequences.size());
    if (kBestSequencesLast != null) {
      assertExtendsPrevious(kBestSequencesLast, kBestSequences, k);
    } else {
      // k == 1: the only sequence returned must be the exact best sequence
      int[] bestSequence = new ExactBestSequenceFinder().bestSequence(sequenceModel);
      int[] best1 = kBest.keySet().iterator().next();
      assertTrue("1-best sequence differs from the exact best sequence",
          Arrays.equals(bestSequence, best1));
    }
    kBestSequencesLast = kBestSequences;
  }

  // Pass 2: the same checks through the classifier's public k-best API.
  List<Pair<List<String>, Double>> lastAnswer = null;
  for (int k = 1; k <= K_BEST; k++) {
    Counter<List<CoreLabel>> out = crf.classifyKBest(input, CoreAnnotations.AnswerAnnotation.class, k);
    assertEquals(k, out.size());
    List<Pair<List<CoreLabel>, Double>> beam = Counters.toSortedListWithCounts(out);
    List<Pair<List<String>, Double>> beam2 = adapt2(beam);
    if (isStoredAnswer) {
      // done for a particular sequence model at one point
      // (expected value first, so a failure reports the stored answer as "expected")
      assertEquals("k=" + k, iobesAnswers[k - 1], beam2.get(k - 1).first().toString());
      assertEquals("k=" + k, scores[k - 1], beam2.get(k - 1).second(), 1e-8);
    }
    if (lastAnswer != null) {
      assertExtendsPrevious(lastAnswer, beam2, k);
    } else {
      // k == 1: the top entry must be what plain classification returns
      List<CoreLabel> best = crf.classify(input);
      assertEquals(best, beam.get(0).first());
    }
    lastAnswer = beam2;
  }
}

/**
 * Asserts that the k-best list {@code current} extends the (k-1)-best list {@code previous}:
 * its first k - 1 entries equal {@code previous}, its last entry does not score higher than
 * the entry before it, and the label sequence of its last entry differs from every earlier one.
 *
 * @param previous the list obtained for k - 1
 * @param current the list obtained for k; callers have already checked it has k entries
 * @param k the current list size; must be at least 2
 */
private static <T> void assertExtendsPrevious(List<Pair<T, Double>> previous,
                                              List<Pair<T, Double>> current, int k) {
  // The rest of the list is the same
  assertEquals("k=" + k, previous, current.subList(0, k - 1));

  // New item is lower score
  Pair<T, Double> newest = current.get(k - 1);
  assertTrue("k=" + k + ": new item scores higher than its predecessor",
      newest.second() <= current.get(k - 2).second());

  // New item is different
  for (int m = 0; m < k - 1; m++) {
    assertFalse("k=" + k + ": new item duplicates item " + m,
        newest.first().equals(current.get(m).first()));
  }
}
Aggregations