Usage of io.anserini.embeddings.TermNotFoundException in the Anserini project by castorini.
Class WmdPassageScorer, method score.
/**
 * Scores every candidate passage against {@code query} by embedding distance and keeps the best
 * ones in {@code scoredPassageHeap}.
 *
 * <p>The query and each passage are tokenized with a {@link StandardAnalyzer} built from
 * {@code stopWords}. Repeated tokens are collapsed into sets. A passage's distance {@code wmd} is
 * the sum, over distinct query terms, of the smallest {@code distance(...)} between that term's
 * embedding and the embedding of any distinct passage term. Query terms for which no distance
 * could be computed contribute nothing. The passage score is
 * {@code -(wmd + 0.0001 * sentences.get(passage))}, so a smaller distance yields a higher score.
 *
 * <p>At most {@code topPassages} passages are retained. Once the heap is full, a passage is
 * inserted only if its score beats the entry returned by {@code peekLast()}, which is then
 * removed with {@code pollLast()}.
 *
 * @param query the question text
 * @param sentences candidate passage text mapped to its score (presumably a retrieval score)
 * @throws Exception if tokenization fails, or if computing the {@code "unk"} fallback distance
 *     raises an {@link IOException}
 */
@Override
public void score(String query, Map<String, Float> sentences) throws Exception {
  // A single analyzer serves the whole call. Each stream obtained from it is closed (inside
  // analyzeToTermSet) before the next tokenStream() request, which is what allows the analyzer
  // to hand out another stream. The analyzer itself is closed on exit.
  try (StandardAnalyzer analyzer = new StandardAnalyzer(StopFilter.makeStopSet(stopWords))) {
    Set<String> questionTerms = analyzeToTermSet(analyzer, query);
    // avoid duplicate passages (compared by passage text)
    Set<String> seenSentences = new HashSet<>();

    for (Map.Entry<String, Float> sent : sentences.entrySet()) {
      String passage = sent.getKey();
      Set<String> candidateTerms = analyzeToTermSet(analyzer, passage);

      double wmd = 0.0;
      for (String qTerm : questionTerms) {
        double minWMD = minTermDistance(qTerm, candidateTerms);
        // Double.MAX_VALUE signals that no distance could be computed for this query term.
        if (minWMD != Double.MAX_VALUE) {
          wmd += minWMD;
        }
      }

      // NOTE(review): the passage's map value sits inside the negation, so a higher value lowers
      // weightedScore. If it is meant as a tie-breaker favouring better-retrieved passages, this
      // should read -wmd + 0.0001 * value. Kept as-is to preserve ranking; confirm intent.
      double weightedScore = -1 * (wmd + 0.0001 * sent.getValue());
      ScoredPassage scoredPassage = new ScoredPassage(passage, weightedScore, sent.getValue());
      // Insert while the heap has room, or when this passage beats the entry at peekLast()
      // (presumably the lowest-scoring retained passage), which is evicted first.
      if ((scoredPassageHeap.size() < topPassages
              || weightedScore > scoredPassageHeap.peekLast().getScore())
          && !seenSentences.contains(passage)) {
        if (scoredPassageHeap.size() == topPassages) {
          scoredPassageHeap.pollLast();
        }
        scoredPassageHeap.add(scoredPassage);
        seenSentences.add(passage);
      }
    }
  }
}

/**
 * Tokenizes {@code text} with {@code analyzer} under the field name {@code "contents"} and
 * returns the distinct term strings produced.
 *
 * <p>Follows the token-stream workflow reset → incrementToken* → end → close. Closing the
 * stream lets the same analyzer be asked for another stream afterwards.
 */
private Set<String> analyzeToTermSet(StandardAnalyzer analyzer, String text) throws IOException {
  Set<String> terms = new HashSet<>();
  try (TokenStream stream = analyzer.tokenStream("contents", new StringReader(text))) {
    CharTermAttribute termAttribute = stream.addAttribute(CharTermAttribute.class);
    stream.reset();
    while (stream.incrementToken()) {
      terms.add(termAttribute.toString());
    }
    stream.end();
  }
  return terms;
}

/**
 * Returns the smallest {@code distance(...)} between the embedding of {@code qTerm} and the
 * embedding of any term in {@code candidateTerms}. Returns {@link Double#MAX_VALUE} if no
 * distance could be computed, for example when {@code candidateTerms} is empty or every lookup
 * failed.
 *
 * <p>Handling of a {@link TermNotFoundException}:
 * <ul>
 *   <li>If the reported missing term is not {@code qTerm}, the candidate is skipped.
 *   <li>If {@code qTerm} is missing and the candidate is the same string, the distance is 0.
 *   <li>Otherwise the embedding of {@code "unk"} stands in for {@code qTerm}. If that lookup
 *       fails too, the candidate is skipped.
 * </ul>
 * An {@link IOException} from the regular lookup is ignored and the candidate is skipped.
 *
 * @throws IOException if computing the {@code "unk"} fallback distance raises one
 */
private double minTermDistance(String qTerm, Set<String> candidateTerms) throws IOException {
  double minWMD = Double.MAX_VALUE;
  for (String candTerm : candidateTerms) {
    try {
      double thisWMD = distance(wmdDictionary.getEmbeddingVector(qTerm),
          wmdDictionary.getEmbeddingVector(candTerm));
      if (minWMD > thisWMD) {
        minWMD = thisWMD;
      }
    } catch (TermNotFoundException e) {
      // Assumes the exception message is exactly the term that was not found; confirm against
      // TermNotFoundException. qTerm is looked up first (left-to-right argument evaluation), so
      // qTerm is the one reported when both terms are missing.
      String missingTerm = e.getMessage();
      if (!qTerm.equals(missingTerm)) {
        // Presumably the candidate term has no embedding: skip it.
        continue;
      }
      if (qTerm.equals(candTerm)) {
        // Identical terms: mover's distance is 0 even without an embedding.
        minWMD = 0.0;
      } else {
        try {
          // The question term has no embedding; treat it as the unknown token "unk".
          double thisWMD = distance(wmdDictionary.getEmbeddingVector("unk"),
              wmdDictionary.getEmbeddingVector(candTerm));
          if (minWMD > thisWMD) {
            minWMD = thisWMD;
          }
        } catch (TermNotFoundException ignored) {
          // Either "unk" or candTerm has no embedding: skip this candidate.
        }
      }
    } catch (IOException ignored) {
      // thrown if the search fails; best effort: skip this candidate.
    }
  }
  return minWMD;
}
Aggregations