Search in sources :

Example 1 with SequenceModel

use of edu.stanford.nlp.sequences.SequenceModel in project CoreNLP by stanfordnlp.

In the class CMMClassifier, the method classifySeq.

/**
 * Labels every token of {@code document} by decoding over the whole sequence
 * (exact search when {@code flags.useViterbi} is set, beam search otherwise)
 * and stores the chosen class in each token's
 * {@link CoreAnnotations.AnswerAnnotation}.
 *
 * <p>When {@code flags.lowerNewgeneThreshold} is set, every run of tokens whose
 * gazette annotation is {@code "NEWGENE"} is additionally rescored, and each
 * sub-span that clears {@code flags.newgeneThreshold} and outscores its
 * one-token-larger/smaller neighbours is printed to standard output as
 * {@code docId|startPosition endPosition|words}. This reporting does not change
 * the labels that are written back.
 *
 * @param document the tokens to classify; an empty list is left untouched
 */
private void classifySeq(List<IN> document) {
    if (document.isEmpty()) {
        return;
    }

    SequenceModel model = getSequenceModel(document);

    // Decode the single best label sequence.
    int[] bestTags;
    if (flags.useViterbi) {
        bestTags = new ExactBestSequenceFinder().bestSequence(model);
    } else {
        bestTags = new BeamBestSequenceFinder(flags.beamSize, true, true).bestSequence(model, document.size());
    }

    // used to improve recall in task 1b
    if (flags.lowerNewgeneThreshold) {
        log.info("Using NEWGENE threshold: " + flags.newgeneThreshold);

        // Scratch labelling that is repeatedly overwritten and rescored; bestTags itself is never modified.
        int[] scratch = bestTags.clone();
        int geneTag = classIndex.indexOf("G");
        int backgroundTag = classIndex.indexOf(flags.backgroundSymbol);

        int tokenCount = document.size();
        int pos = 0;
        while (pos < tokenCount) {
            if (!"NEWGENE".equals(document.get(pos).get(CoreAnnotations.GazAnnotation.class))) {
                pos++;
                continue;
            }

            // [start, end) is the maximal run of gazette-marked NEWGENE tokens beginning at pos.
            int start = pos;
            int end = start;
            while (end < document.size()
                    && "NEWGENE".equals(document.get(end).get(CoreAnnotations.GazAnnotation.class))) {
                end++;
            }

            // Blank out a window of up to four tokens on either side of the run.
            int winStart = Math.max(0, start - 4);
            int winEnd = Math.min(bestTags.length, end + 4);
            for (int w = winStart; w < winEnd; w++) {
                scratch[w] = backgroundTag;
            }

            // Baseline: score of the run when it is labelled entirely as background.
            double backgroundScore = 0.0;
            for (int p = start; p < end; p++) {
                double[] centered = Scorer.recenter(model.scoresOf(scratch, p));
                backgroundScore += centered[backgroundTag];
            }

            // First pass: score every sub-span [from, to] of the run labelled as gene,
            // relative to the background baseline.
            ClassicCounter<Pair<Integer, Integer>> spanScores = new ClassicCounter<>();
            for (int from = start; from < end; from++) {
                // reset the run to background before growing a new span from 'from'
                for (int p = start; p < end; p++) {
                    scratch[p] = backgroundTag;
                }
                for (int to = from; to < end; to++) {
                    scratch[to] = geneTag;
                    double geneScore = 0.0;
                    for (int p = start; p < end; p++) {
                        double[] centered = Scorer.recenter(model.scoresOf(scratch, p));
                        geneScore += centered[bestTags[p]];
                    }
                    Pair<Integer, Integer> span = new Pair<>(from, to);
                    spanScores.incrementCount(span, geneScore - backgroundScore);
                }
            }

            // Second pass: report spans above the threshold that strictly beat every
            // scored neighbour reachable by adding or removing one token at either edge.
            for (int from = start; from < end; from++) {
                for (int to = from; to < end; to++) {
                    Pair<Integer, Integer> span = new Pair<>(from, to);
                    double score = spanScores.getCount(span);
                    if (!(score >= flags.newgeneThreshold)) {
                        continue;
                    }

                    Pair<Integer, Integer> grownLeft = new Pair<>(from - 1, to);
                    if (spanScores.containsKey(grownLeft) && !(score > spanScores.getCount(grownLeft))) {
                        continue;
                    }
                    Pair<Integer, Integer> grownRight = new Pair<>(from, to + 1);
                    if (spanScores.containsKey(grownRight) && !(score > spanScores.getCount(grownRight))) {
                        continue;
                    }
                    Pair<Integer, Integer> shrunkLeft = new Pair<>(from + 1, to);
                    if (spanScores.containsKey(shrunkLeft) && !(score > spanScores.getCount(shrunkLeft))) {
                        continue;
                    }
                    Pair<Integer, Integer> shrunkRight = new Pair<>(from, to - 1);
                    if (spanScores.containsKey(shrunkRight) && !(score > spanScores.getCount(shrunkRight))) {
                        continue;
                    }

                    CoreLabel firstToken = document.get(from);
                    String docId = firstToken.get(CoreAnnotations.IDAnnotation.class);
                    String startIndex = firstToken.get(CoreAnnotations.PositionAnnotation.class);
                    String endIndex = document.get(to).get(CoreAnnotations.PositionAnnotation.class);

                    StringBuilder phrase = new StringBuilder();
                    for (int p = from; p <= to; p++) {
                        phrase.append(document.get(p).word()).append(' ');
                    }
                    System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + phrase.toString().trim());
                }
            }

            // Put the decoded labels back into the window before handling the next run.
            for (int w = winStart; w < winEnd; w++) {
                scratch[w] = bestTags[w];
            }

            // document.get(end), if present, is known not to be NEWGENE, so resume just past it.
            pos = end + 1;
        }
    }

    // Write the decoded label onto every token.
    int answerCount = document.size();
    for (int pos = 0; pos < answerCount; pos++) {
        CoreLabel token = document.get(pos);
        token.set(CoreAnnotations.AnswerAnnotation.class, classIndex.get(bestTags[pos]));
    }

    // Optional per-token justification output; only available for a LinearClassifier.
    if (flags.justify && classifier instanceof LinearClassifier) {
        LinearClassifier<String, String> linear = (LinearClassifier<String, String>) classifier;
        if (flags.dump) {
            linear.dump();
        }
        int justifyCount = document.size();
        for (int pos = 0; pos < justifyCount; pos++) {
            CoreLabel token = document.get(pos);
            log.info("@@ Position is: " + pos + ": ");
            log.info(token.word() + ' ' + token.get(CoreAnnotations.AnswerAnnotation.class));
            linear.justificationOf(makeDatum(document, pos, featureFactories));
        }
    }

    // NOTE(review): presumably undoes a reversal applied earlier (e.g. while building
    // the sequence model) so the caller sees the original token order — confirm.
    if (flags.useReverse) {
        Collections.reverse(document);
    }
}
Also used : SequenceModel(edu.stanford.nlp.sequences.SequenceModel) ExactBestSequenceFinder(edu.stanford.nlp.sequences.ExactBestSequenceFinder) CoreLabel(edu.stanford.nlp.ling.CoreLabel) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) ClassicCounter(edu.stanford.nlp.stats.ClassicCounter) LinearClassifier(edu.stanford.nlp.classify.LinearClassifier) BeamBestSequenceFinder(edu.stanford.nlp.sequences.BeamBestSequenceFinder)

Example 2 with SequenceModel

use of edu.stanford.nlp.sequences.SequenceModel in project CoreNLP by stanfordnlp.

In the class CRFClassifierITest, the method runKBestTest.

/**
 * Checks the k-best decoding of {@code crf} on the space-separated sentence
 * {@code str} for every k from 1 to 12, first directly through
 * {@link KBestSequenceFinder} and then through {@code crf.classifyKBest}.
 *
 * <p>For each k the result must contain exactly k entries, its first k-1 entries
 * must equal the result obtained for k-1, the newly added entry must not score
 * higher than the one before it, and it must differ from every earlier entry.
 * For k = 1 the single result must match the ordinary best-sequence decoding.
 *
 * @param crf the classifier under test
 * @param str the input sentence; tokens are obtained by splitting on single spaces
 * @param isStoredAnswer if true, the k-th answer and its score are also compared
 *     with the stored {@code iobesAnswers} and {@code scores} expectations
 */
private static void runKBestTest(CRFClassifier<CoreLabel> crf, String str, boolean isStoredAnswer) {
    final int K_BEST = 12;
    String[] txt = str.split(" ");
    List<CoreLabel> input = SentenceUtils.toCoreLabelList(txt);
    // do the ugliness that the CRFClassifier routines do to augment the input
    ObjectBankWrapper<CoreLabel> obw = new ObjectBankWrapper<>(crf.flags, null, crf.getKnownLCWords());
    List<CoreLabel> input2 = obw.processDocument(input);
    SequenceModel sequenceModel = crf.getSequenceModel(input2);

    // Part 1: k-best search run directly on the sequence model.
    List<Pair<CRFLabel, Double>> kBestSequencesLast = null;
    for (int k = 1; k <= K_BEST; k++) {
        Counter<int[]> kBest = new KBestSequenceFinder().kBestSequences(sequenceModel, k);
        // NOTE(review): adapt(...) is assumed to return the sequences ordered best-first — confirm.
        List<Pair<CRFLabel, Double>> kBestSequences = adapt(kBest);
        assertEquals(k, kBestSequences.size());
        // System.out.printf("k=%2d %s%n", k, kBestSequences);
        if (kBestSequencesLast != null) {
            // The rest of the list is the same
            assertEquals("k=" + k, kBestSequencesLast, kBestSequences.subList(0, k - 1));
            // New item is lower score
            assertTrue(kBestSequences.get(k - 1).second() <= kBestSequences.get(k - 2).second());
            for (int m = 0; m < (k - 1); m++) {
                // New item is different
                assertFalse(kBestSequences.get(k - 1).first().equals(kBestSequences.get(m).first()));
            }
        } else {
            // k == 1: the only sequence returned must be the exact best sequence
            int[] bestSequence = new ExactBestSequenceFinder().bestSequence(sequenceModel);
            int[] best1 = new ArrayList<>(kBest.keySet()).get(0);
            assertTrue(Arrays.equals(bestSequence, best1));
        }
        kBestSequencesLast = kBestSequences;
    }

    // Part 2: the same invariants through the classifier's public k-best API.
    List<Pair<List<String>, Double>> lastAnswer = null;
    for (int k = 1; k <= K_BEST; k++) {
        Counter<List<CoreLabel>> out = crf.classifyKBest(input, CoreAnnotations.AnswerAnnotation.class, k);
        assertEquals(k, out.size());
        List<Pair<List<CoreLabel>, Double>> beam = Counters.toSortedListWithCounts(out);
        List<Pair<List<String>, Double>> beam2 = adapt2(beam);
        // System.out.printf("k=%2d %s%n", k, beam2);
        if (isStoredAnswer) {
            // done for a particular sequence model at one point
            // Stored values go first: assertEquals takes (expected, actual), and the
            // failure message labels the two values accordingly.
            assertEquals(iobesAnswers[k - 1], beam2.get(k - 1).first().toString());
            assertEquals(scores[k - 1], beam2.get(k - 1).second(), 1e-8);
        }
        if (lastAnswer != null) {
            // The rest of the list is the same
            assertEquals("k=" + k, lastAnswer, beam2.subList(0, k - 1));
            // New item is lower score
            assertTrue(beam2.get(k - 1).second() <= beam2.get(k - 2).second());
            for (int m = 0; m < (k - 1); m++) {
                // New item is different
                assertFalse(beam2.get(k - 1).first().equals(beam2.get(m).first()));
            }
        } else {
            // k == 1: the top k-best answer must equal the ordinary classification
            List<CoreLabel> best = crf.classify(input);
            assertEquals(best, beam.get(0).first());
        }
        lastAnswer = beam2;
    }
}
Also used : SequenceModel(edu.stanford.nlp.sequences.SequenceModel) ExactBestSequenceFinder(edu.stanford.nlp.sequences.ExactBestSequenceFinder) ObjectBankWrapper(edu.stanford.nlp.sequences.ObjectBankWrapper) CoreLabel(edu.stanford.nlp.ling.CoreLabel) KBestSequenceFinder(edu.stanford.nlp.sequences.KBestSequenceFinder) CoreAnnotations(edu.stanford.nlp.ling.CoreAnnotations) ArrayList(java.util.ArrayList) List(java.util.List) Pair(edu.stanford.nlp.util.Pair)

Aggregations

CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 2, CoreLabel (edu.stanford.nlp.ling.CoreLabel): 2, ExactBestSequenceFinder (edu.stanford.nlp.sequences.ExactBestSequenceFinder): 2, SequenceModel (edu.stanford.nlp.sequences.SequenceModel): 2, LinearClassifier (edu.stanford.nlp.classify.LinearClassifier): 1, BeamBestSequenceFinder (edu.stanford.nlp.sequences.BeamBestSequenceFinder): 1, KBestSequenceFinder (edu.stanford.nlp.sequences.KBestSequenceFinder): 1, ObjectBankWrapper (edu.stanford.nlp.sequences.ObjectBankWrapper): 1, ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 1, Pair (edu.stanford.nlp.util.Pair): 1, ArrayList (java.util.ArrayList): 1, List (java.util.List): 1