Search in sources:

Example 1 with BeamBestSequenceFinder

Use of edu.stanford.nlp.sequences.BeamBestSequenceFinder in project CoreNLP by stanfordnlp.

From class CMMClassifier, method classifySeq.

/**
 * Classify a List of {@link CoreLabel}s using sequence information
 * (i.e. Viterbi or Beam Search).
 *
 * <p>The best tag sequence is decoded with an {@link ExactBestSequenceFinder} when
 * {@code flags.useViterbi} is set, and otherwise with a {@link BeamBestSequenceFinder}
 * constructed with {@code flags.beamSize}. For each token {@code i}, the label
 * {@code classIndex.get(tags[i])} is stored as its {@code CoreAnnotations.AnswerAnnotation}.
 *
 * <p>When {@code flags.lowerNewgeneThreshold} is set, each maximal run of tokens whose
 * {@code CoreAnnotations.GazAnnotation} is {@code "NEWGENE"} is rescored:
 * <ul>
 *   <li>Every sub-span [j, k] (inclusive) of the run is labelled with class {@code "G"},
 *       and the rest of the run is labelled background.</li>
 *   <li>The score of that labelling is compared against an all-background labelling.</li>
 *   <li>A sub-span is printed to standard output as
 *       {@code docId|startPosition endPosition|words} when its score difference is at
 *       least {@code flags.newgeneThreshold} and strictly exceeds that of every
 *       neighbouring sub-span that was also scored (one token added or removed at
 *       either edge).</li>
 * </ul>
 * This pass only relabels a scratch copy of the tags. It therefore does not change
 * the answers written to the document.
 *
 * <p>When {@code flags.justify} is set and {@code classifier} is a {@link LinearClassifier},
 * the following happens:
 * <ul>
 *   <li>The model is dumped first if {@code flags.dump} is set.</li>
 *   <li>For every token, its position, word and answer are logged.</li>
 *   <li>{@code justificationOf} is called on the datum built for that position.</li>
 * </ul>
 * Finally, when {@code flags.useReverse} is set, the document list is reversed in place.
 *
 * @param document A List of {@link CoreLabel}s to be classified. It is modified in
 *     place: answers are set and, with {@code flags.useReverse}, its order is reversed.
 *     An empty document is returned immediately without any processing.
 */
private void classifySeq(List<IN> document) {
    // Nothing to decode or annotate for an empty document.
    if (document.isEmpty()) {
        return;
    }
    // The sequence model is used both for decoding and for rescoring NEWGENE spans below.
    SequenceModel ts = getSequenceModel(document);
    // TagScorer ts = new PrevOnlyScorer(document, tagIndex, this, (!flags.useTaggySequences ? (flags.usePrevSequences ? 1 : 0) : flags.maxLeft), 0, answerArrays);
    // Decoded tag indices. Token i's answer is read from tags[i] at the end of the method.
    int[] tags;
    // log.info("***begin test***");
    if (flags.useViterbi) {
        // Exact decoding over the whole sequence model.
        ExactBestSequenceFinder ti = new ExactBestSequenceFinder();
        tags = ti.bestSequence(ts);
    } else {
        // Beam decoding. NOTE(review): the meaning of the two boolean constructor
        // arguments is defined in BeamBestSequenceFinder and is not visible here.
        BeamBestSequenceFinder ti = new BeamBestSequenceFinder(flags.beamSize, true, true);
        tags = ti.bestSequence(ts, document.size());
    }
    // used to improve recall in task 1b
    // Gazetteer-driven NEWGENE rescoring. It only reports spans on stdout; `tags` is
    // never modified here, so the answers written below are unaffected.
    if (flags.lowerNewgeneThreshold) {
        log.info("Using NEWGENE threshold: " + flags.newgeneThreshold);
        // Scratch copy of the decoded tags that is relabelled while scoring candidates.
        int[] copy = new int[tags.length];
        System.arraycopy(tags, 0, copy, 0, tags.length);
        // for each sequence marked as NEWGENE in the gazette
        // tag the entire sequence as NEWGENE and sum the score
        // if the score is greater than newgeneThreshold, accept
        // Class indices for the gene label "G" and the background symbol.
        // NOTE(review): "G" is hard-coded. If it is absent from classIndex, indexOf
        // presumably returns -1 and the scoring below would misbehave — confirm.
        int ngTag = classIndex.indexOf("G");
        // int bgTag = classIndex.indexOf(BACKGROUND);
        int bgTag = classIndex.indexOf(flags.backgroundSymbol);
        for (int i = 0, dSize = document.size(); i < dSize; i++) {
            CoreLabel wordInfo = document.get(i);
            if ("NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
                // Extend to the end of the run: [start, end) is a maximal run of NEWGENE tokens.
                int start = i;
                int j;
                for (j = i; j < document.size(); j++) {
                    wordInfo = document.get(j);
                    if (!"NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
                        break;
                    }
                }
                int end = j;
                // int end = i + 1;
                // Window of up to 4 extra positions on each side of the run,
                // clipped to the bounds of `tags`.
                int winStart = Math.max(0, start - 4);
                int winEnd = Math.min(tags.length, end + 4);
                // clear a window around the sequences
                for (j = winStart; j < winEnd; j++) {
                    copy[j] = bgTag;
                }
                // score as nongene
                // Baseline: sum over the run of the recentered background-tag score,
                // with the whole window labelled background.
                double bgScore = 0.0;
                for (j = start; j < end; j++) {
                    double[] scores = ts.scoresOf(copy, j);
                    scores = Scorer.recenter(scores);
                    bgScore += scores[bgTag];
                }
                // first pass, compute all of the scores
                // For each sub-span [j, k] (inclusive), the rest of the run is labelled
                // background. The value stored under key (j, k) is
                // (score of that labelling - bgScore).
                ClassicCounter<Pair<Integer, Integer>> prevScores = new ClassicCounter<>();
                for (j = start; j < end; j++) {
                    // clear the sequence
                    for (int k = start; k < end; k++) {
                        copy[k] = bgTag;
                    }
                    // grow the sequence from j until the end
                    // Labels are cumulative: after iteration k, positions j..k are ngTag.
                    for (int k = j; k < end; k++) {
                        copy[k] = ngTag;
                        // score the sequence
                        double ngScore = 0.0;
                        for (int m = start; m < end; m++) {
                            double[] scores = ts.scoresOf(copy, m);
                            scores = Scorer.recenter(scores);
                            // NOTE(review): this indexes the scores by the originally
                            // decoded tag tags[m], not by the candidate label copy[m].
                            // bgScore above uses the candidate (background) label.
                            // Confirm whether copy[m] was intended.
                            ngScore += scores[tags[m]];
                        }
                        // Each (j, k) key is visited exactly once, so this acts as a set.
                        prevScores.incrementCount(new Pair<>(Integer.valueOf(j), Integer.valueOf(k)), ngScore - bgScore);
                    }
                }
                // Second pass: report every sub-span that clears the threshold and is a
                // strict local maximum among its scored neighbours. Neighbours whose key
                // was never scored are ignored.
                for (j = start; j < end; j++) {
                    // grow the sequence from j until the end
                    for (int k = j; k < end; k++) {
                        double score = prevScores.getCount(new Pair<>(Integer.valueOf(j), Integer.valueOf(k)));
                        // adding a word to the left
                        Pair<Integer, Integer> al = new Pair<>(Integer.valueOf(j - 1), Integer.valueOf(k));
                        // adding a word to the right
                        Pair<Integer, Integer> ar = new Pair<>(Integer.valueOf(j), Integer.valueOf(k + 1));
                        // subtracting word from left
                        Pair<Integer, Integer> sl = new Pair<>(Integer.valueOf(j + 1), Integer.valueOf(k));
                        // subtracting word from right
                        Pair<Integer, Integer> sr = new Pair<>(Integer.valueOf(j), Integer.valueOf(k - 1));
                        // make sure the score is greater than all its neighbors (one add or subtract)
                        if (score >= flags.newgeneThreshold && (!prevScores.containsKey(al) || score > prevScores.getCount(al)) && (!prevScores.containsKey(ar) || score > prevScores.getCount(ar)) && (!prevScores.containsKey(sl) || score > prevScores.getCount(sl)) && (!prevScores.containsKey(sr) || score > prevScores.getCount(sr))) {
                            StringBuilder sb = new StringBuilder();
                            // Document id and position of the span's first token.
                            wordInfo = document.get(j);
                            String docId = wordInfo.get(CoreAnnotations.IDAnnotation.class);
                            String startIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class);
                            // Position of the span's last token (k is inclusive).
                            wordInfo = document.get(k);
                            String endIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class);
                            // Space-separated words of the span; the trailing space is trimmed below.
                            for (int m = j; m <= k; m++) {
                                wordInfo = document.get(m);
                                sb.append(wordInfo.word());
                                sb.append(' ');
                            }
                            /*log.info(sb.toString()+"score:"+score+
                  " al:"+prevScores.getCount(al)+
                  " ar:"+prevScores.getCount(ar)+
                  "  sl:"+prevScores.getCount(sl)+" sr:"+ prevScores.getCount(sr));*/
                            // Emit the accepted span on stdout as "docId|start end|words".
                            System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + sb.toString().trim());
                        }
                    }
                }
                // restore the original tags
                for (j = winStart; j < winEnd; j++) {
                    copy[j] = tags[j];
                }
                // Resume scanning after the run. The loop's i++ also skips token `end`,
                // which is known not to be NEWGENE (or lies past the document end).
                i = end;
            }
        }
    }
    // Write the decoded label of each token as its answer.
    for (int i = 0, docSize = document.size(); i < docSize; i++) {
        CoreLabel lineInfo = document.get(i);
        String answer = classIndex.get(tags[i]);
        lineInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
    }
    // Optional per-token justification output for linear models.
    if (flags.justify && classifier instanceof LinearClassifier) {
        // Unchecked cast. NOTE(review): assumes String type parameters for the
        // classifier's labels and features — confirm against the classifier's declaration.
        LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
        if (flags.dump) {
            lc.dump();
        }
        for (int i = 0, docSize = document.size(); i < docSize; i++) {
            CoreLabel lineInfo = document.get(i);
            log.info("@@ Position is: " + i + ": ");
            log.info(lineInfo.word() + ' ' + lineInfo.get(CoreAnnotations.AnswerAnnotation.class));
            lc.justificationOf(makeDatum(document, i, featureFactories));
        }
    }
    // NOTE(review): presumably undoes a reversal applied when the document was prepared
    // under flags.useReverse; that earlier reversal is not visible in this method.
    if (flags.useReverse) {
        Collections.reverse(document);
    }
}
Also used: SequenceModel (edu.stanford.nlp.sequences.SequenceModel), ExactBestSequenceFinder (edu.stanford.nlp.sequences.ExactBestSequenceFinder), CoreLabel (edu.stanford.nlp.ling.CoreLabel), CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations), ClassicCounter (edu.stanford.nlp.stats.ClassicCounter), LinearClassifier (edu.stanford.nlp.classify.LinearClassifier), BeamBestSequenceFinder (edu.stanford.nlp.sequences.BeamBestSequenceFinder)

Aggregations

LinearClassifier (edu.stanford.nlp.classify.LinearClassifier): 1, CoreAnnotations (edu.stanford.nlp.ling.CoreAnnotations): 1, CoreLabel (edu.stanford.nlp.ling.CoreLabel): 1, BeamBestSequenceFinder (edu.stanford.nlp.sequences.BeamBestSequenceFinder): 1, ExactBestSequenceFinder (edu.stanford.nlp.sequences.ExactBestSequenceFinder): 1, SequenceModel (edu.stanford.nlp.sequences.SequenceModel): 1, ClassicCounter (edu.stanford.nlp.stats.ClassicCounter): 1