
Example 1 with TermNotFoundException

Use of io.anserini.embeddings.TermNotFoundException in project Anserini by castorini.

In class WmdPassageScorer, method score:

@Override
public void score(String query, Map<String, Float> sentences) throws Exception {
    Set<String> questionTerms = new HashSet<>();
    Set<String> candidateTerms = new HashSet<>();
    // avoid duplicate passages
    Set<String> seenSentences = new HashSet<>();
    // one analyzer is reused for the query and every candidate; Lucene requires each
    // TokenStream to be end()-ed and closed before the analyzer hands out the next one
    try (StandardAnalyzer sa = new StandardAnalyzer(StopFilter.makeStopSet(stopWords))) {
        try (TokenStream tokenStream = sa.tokenStream("contents", new StringReader(query))) {
            CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
            tokenStream.reset();
            while (tokenStream.incrementToken()) {
                questionTerms.add(charTermAttribute.toString());
            }
            tokenStream.end();
        }
        for (Map.Entry<String, Float> sent : sentences.entrySet()) {
            double wmd = 0.0;
            candidateTerms.clear();
            try (TokenStream candTokenStream = sa.tokenStream("contents", new StringReader(sent.getKey()))) {
                CharTermAttribute charTermAttribute = candTokenStream.addAttribute(CharTermAttribute.class);
                candTokenStream.reset();
                while (candTokenStream.incrementToken()) {
                    candidateTerms.add(charTermAttribute.toString());
                }
                candTokenStream.end();
            }
            // each question term "moves" to its nearest candidate term; the distances are summed
            for (String qTerm : questionTerms) {
                double minWMD = Double.MAX_VALUE;
                for (String candTerm : candidateTerms) {
                    try {
                        double thisWMD = distance(wmdDictionary.getEmbeddingVector(qTerm), wmdDictionary.getEmbeddingVector(candTerm));
                        if (minWMD > thisWMD) {
                            minWMD = thisWMD;
                        }
                    } catch (TermNotFoundException e) {
                        String missingTerm = e.getMessage();
                        // only handle the case where the question term itself has no embedding
                        if (!qTerm.equals(missingTerm)) {
                            continue;
                        }
                        if (qTerm.equals(candTerm)) {
                            // identical terms: mover's distance is 0
                            minWMD = 0.0;
                        } else {
                            try {
                                // if the embedding for the question term doesn't exist, consider
                                // it to be an unknown term
                                double thisWMD = distance(wmdDictionary.getEmbeddingVector("unk"), wmdDictionary.getEmbeddingVector(candTerm));
                                if (minWMD > thisWMD) {
                                    minWMD = thisWMD;
                                }
                            } catch (TermNotFoundException e1) {
                                // "unk" or the candidate term is OOV
                            }
                        }
                    } catch (IOException e) {
                        // thrown if the search fails
                    }
                }
                if (minWMD != Double.MAX_VALUE) {
                    wmd += minWMD;
                }
            }
            // negate so that smaller distances give higher scores; the retrieval score
            // enters only through a small 0.0001-weighted term
            double weightedScore = -1 * (wmd + 0.0001 * sent.getValue());
            ScoredPassage scoredPassage = new ScoredPassage(sent.getKey(), weightedScore, sent.getValue());
            // keep at most topPassages passages; the heap's last element is treated as the lowest score
            if ((scoredPassageHeap.size() < topPassages || weightedScore > scoredPassageHeap.peekLast().getScore())
                    && !seenSentences.contains(sent.getKey())) {
                if (scoredPassageHeap.size() == topPassages) {
                    scoredPassageHeap.pollLast();
                }
                scoredPassageHeap.add(scoredPassage);
                seenSentences.add(sent.getKey());
            }
        }
    }
}
Also used: TokenStream (org.apache.lucene.analysis.TokenStream), TermNotFoundException (io.anserini.embeddings.TermNotFoundException), CharTermAttribute (org.apache.lucene.analysis.tokenattributes.CharTermAttribute), StandardAnalyzer (org.apache.lucene.analysis.standard.StandardAnalyzer)
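The core of score() is a relaxed, unweighted variant of word mover's distance: each query term is matched to its closest candidate term and those distances are summed, with a fallback to an "unk" vector for query terms that have no embedding. The self-contained sketch below reproduces that logic with plain Java collections. It is an illustration, not Anserini code: the class name RelaxedWmdSketch, the in-memory embedding map, and the Euclidean metric are assumptions (the real method looks vectors up through wmdDictionary and uses its own distance()).

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch, not Anserini's API: embeddings live in a plain Map and the
// distance metric is assumed to be Euclidean.
class RelaxedWmdSketch {

    // Euclidean distance between two equal-length vectors (assumed metric)
    static double euclidean(float[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Sum over query terms of the distance to the nearest candidate term
    static double relaxedWmd(Set<String> queryTerms, Set<String> candidateTerms,
                             Map<String, float[]> embeddings) {
        double total = 0.0;
        for (String q : queryTerms) {
            double min = Double.MAX_VALUE;
            // out-of-vocabulary query terms fall back to the "unk" vector, as in score()
            float[] qVec = embeddings.getOrDefault(q, embeddings.get("unk"));
            for (String c : candidateTerms) {
                if (q.equals(c)) {
                    min = 0.0;              // identical terms cost nothing to move
                    break;
                }
                float[] cVec = embeddings.get(c);
                if (qVec == null || cVec == null) {
                    continue;               // OOV candidate (or no "unk" vector): skip the pair
                }
                min = Math.min(min, euclidean(qVec, cVec));
            }
            if (min != Double.MAX_VALUE) {  // terms with no comparable candidate add nothing
                total += min;
            }
        }
        return total;
    }

    public static void main(String[] args) {
        Map<String, float[]> emb = new HashMap<>();
        emb.put("cat", new float[] {0f, 0f});
        emb.put("kitten", new float[] {3f, 4f});
        emb.put("dog", new float[] {6f, 8f});
        emb.put("unk", new float[] {1f, 0f});
        System.out.println(relaxedWmd(Set.of("cat", "zebra"), Set.of("kitten", "dog"), emb));
    }
}
```

With the toy vectors in main, "cat" moves to "kitten" at distance 5 and the unknown "zebra" moves from the "unk" vector to "kitten" at distance √20, so the sum is about 9.47.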

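The end of score() keeps only the topPassages highest-scoring sentences and skips sentences already kept. A minimal sketch of that bounded selection follows, assuming a java.util.PriorityQueue min-heap in place of scoredPassageHeap, whose actual type the example does not show; TopKPassages and Passage are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

// Hypothetical sketch: a size-bounded min-heap stands in for scoredPassageHeap.
class TopKPassages {

    static final class Passage {
        final String text;
        final double score;
        Passage(String text, double score) { this.text = text; this.score = score; }
    }

    private final int k;
    // min-heap on score: peek() is the weakest passage currently kept
    private final PriorityQueue<Passage> heap =
            new PriorityQueue<>(Comparator.comparingDouble((Passage p) -> p.score));
    // like seenSentences in score(): records kept texts and is never shrunk
    private final Set<String> seen = new HashSet<>();

    TopKPassages(int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive");
        }
        this.k = k;
    }

    void offer(String text, double score) {
        if (seen.contains(text)) {
            return;                   // duplicate sentence: ignore
        }
        if (heap.size() < k || score > heap.peek().score) {
            if (heap.size() == k) {
                heap.poll();          // evict the current lowest score
            }
            heap.add(new Passage(text, score));
            seen.add(text);
        }
    }

    // Kept passages, best score first
    List<Passage> best() {
        List<Passage> out = new ArrayList<>(heap);
        out.sort(Comparator.comparingDouble((Passage p) -> p.score).reversed());
        return out;
    }

    public static void main(String[] args) {
        TopKPassages top = new TopKPassages(2);
        top.offer("a", -3.0);
        top.offer("b", -1.0);
        top.offer("a", -0.5);         // duplicate text: ignored even though it scores higher
        top.offer("c", -2.0);         // beats "a" (-3.0), which is evicted
        for (Passage p : top.best()) {
            System.out.println(p.text + " " + p.score);
        }
    }
}
```

Keeping the weakest entry at the head of the heap makes each admission check O(1) and each eviction O(log k).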