Use of io.anserini.search.topicreader.TrecTopicReader in the Anserini project by castorini.
The main method of the class ApproximateNearestNeighborEval.
/**
 * Evaluates approximate nearest-neighbor retrieval over a previously built vector index.
 *
 * <p>Reads word vectors from a GloVe model, draws query words from the titles of TREC topics,
 * computes exact nearest-neighbor ground truth in vector space, then measures recall@depth and
 * average query latency of the token-based (FW or lexical-LSH) index against that ground truth.
 *
 * @param args command-line arguments parsed into {@code ApproximateNearestNeighborEval.Args}
 * @throws Exception on unrecoverable I/O or index errors
 */
public static void main(String[] args) throws Exception {
  ApproximateNearestNeighborEval.Args indexArgs = new ApproximateNearestNeighborEval.Args();
  CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));
  try {
    parser.parseArgument(args);
  } catch (CmdLineException e) {
    System.err.println(e.getMessage());
    parser.printUsage(System.err);
    System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName()
        + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  // Pick the vector-to-token analyzer; it must match the encoding used at indexing time.
  Analyzer vectorAnalyzer;
  if (indexArgs.encoding.equalsIgnoreCase(FW)) {
    vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
  } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
    vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams,
        indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
  } else {
    // Unknown encoding: show usage and bail out, mirroring the argument-parse failure path.
    parser.printUsage(System.err);
    System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName()
        + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  System.out.println(String.format("Loading model %s", indexArgs.input));
  Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);

  Path indexDir = indexArgs.path;
  if (!Files.exists(indexDir)) {
    Files.createDirectories(indexDir);
  }

  System.out.println(String.format("Reading index at %s", indexArgs.path));
  // try-with-resources: the original closed these only on the happy path, leaking the
  // Directory/reader/analyzer if evaluation threw partway through.
  try (Directory d = FSDirectory.open(indexDir);
       DirectoryReader reader = DirectoryReader.open(d);
       StandardAnalyzer standardAnalyzer = new StandardAnalyzer()) {
    IndexSearcher searcher = new IndexSearcher(reader);
    if (indexArgs.encoding.equalsIgnoreCase(FW)) {
      // FW encoding relies on TF-style scoring, hence ClassicSimilarity.
      searcher.setSimilarity(new ClassicSimilarity());
    }

    double recall = 0;
    double time = 0d;
    System.out.println("Evaluating at retrieval depth: " + indexArgs.depth);

    // Query words come from the analyzed titles of the TREC topics.
    TrecTopicReader trecTopicReader = new TrecTopicReader(indexArgs.topicsPath);
    Collection<String> words = new LinkedList<>();
    trecTopicReader.read().values()
        .forEach(e -> words.addAll(AnalyzerUtils.analyze(standardAnalyzer, e.get("title"))));

    int queryCount = 0;
    for (String word : words) {
      if (wordVectors.containsKey(word)) {
        // Exact nearest neighbors in vector space = ground truth for recall.
        Set<String> truth = nearestVector(wordVectors, word, indexArgs.topN);
        try {
          List<float[]> vectors = wordVectors.get(word);
          for (float[] vector : vectors) {
            // Serialize the vector as a space-separated string for the analyzer.
            // NOTE(review): each float is widened to double before stringification, which
            // changes its textual form (e.g. extra digits). Kept as-is because the index
            // was presumably built with the same conversion — confirm against IndexVectors.
            StringBuilder sb = new StringBuilder();
            for (double fv : vector) {
              if (sb.length() > 0) {
                sb.append(' ');
              }
              sb.append(fv);
            }
            String fvString = sb.toString();

            CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff);
            if (indexArgs.msm > 0) {
              simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm);
            }
            for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) {
              simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
            }

            // Time only the search itself, not query construction.
            long start = System.currentTimeMillis();
            TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
            searcher.search(simQuery, results);
            time += System.currentTimeMillis() - start;

            Set<String> observations = new HashSet<>();
            for (ScoreDoc sd : results.topDocs().scoreDocs) {
              Document document = reader.document(sd.doc);
              observations.add(document.get(IndexVectors.FIELD_ID));
            }

            // Per-query recall = |truth ∩ retrieved| / |truth|, accumulated for macro-averaging.
            double intersection = Sets.intersection(truth, observations).size();
            recall += intersection / (double) truth.size();
            queryCount++;
          }
        } catch (IOException e) {
          // Best-effort: a single failed query should not abort the whole evaluation.
          System.err.println("search for '" + word + "' failed " + e.getLocalizedMessage());
        }
      }
      if (queryCount >= indexArgs.samples) {
        break;
      }
    }

    // Guard: with zero evaluated queries the original divided by zero and printed NaN.
    if (queryCount == 0) {
      System.err.println("No topic words had vectors in the model; nothing to evaluate.");
      return;
    }

    recall /= queryCount;
    time /= queryCount;
    System.out.println(String.format("R@%d: %.4f", indexArgs.depth, recall));
    System.out.println(String.format("avg query time: %s ms", time));
  }
}
Aggregations