Usage of io.anserini.search.query.DisjunctionMaxQueryGenerator in the Anserini project (castorini).
Class SearchMsmarco, method main.
/**
 * Command-line entry point for this (deprecated) retrieval driver.
 *
 * <p>Parses the arguments into {@link Args}, builds an analyzer (whitespace for pretokenized
 * input, otherwise {@code DefaultEnglishAnalyzer}), configures BM25 (plus RM3 when requested) on
 * a {@code SimpleSearcher}, and picks a dismax or bag-of-words query generator. Every non-blank
 * line of the queries file must look like {@code qid<TAB>query}; blank lines are skipped. Hits
 * are written to {@code retrieveArgs.output} as {@code qid<TAB>docid<TAB>rank} lines (UTF-8),
 * with ranks starting at 1. With {@code threads == 1} queries run one at a time; otherwise they
 * go through the searcher's batch API.
 *
 * @param args command-line arguments
 * @throws IllegalArgumentException if a non-blank query line has no tab-separated query field
 * @throws IllegalStateException if batch retrieval returns no result entry for some qid
 * @throws java.io.IOException if the run file cannot be written completely
 * @throws Exception propagated from index access, file reading, or searching
 */
public static void main(String[] args) throws Exception {
  Args retrieveArgs = new Args();
  CmdLineParser parser = new CmdLineParser(retrieveArgs, ParserProperties.defaults().withUsageWidth(90));
  try {
    parser.parseArgument(args);
  } catch (CmdLineException e) {
    System.err.println(e.getMessage());
    parser.printUsage(System.err);
    System.err.println("Example: SearchMsmarco " + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  System.out.println("###############################################################################");
  System.out.println("WARNING: This class has been deprecated and may be removed in a future release!");
  System.out.println("###############################################################################\n");

  long totalStartTime = System.nanoTime();

  Analyzer analyzer;
  if (retrieveArgs.pretokenized) {
    analyzer = new WhitespaceAnalyzer();
    System.out.println("Initializing white space analyzer");
  } else {
    analyzer = DefaultEnglishAnalyzer.fromArguments(retrieveArgs.stemmer, retrieveArgs.keepstop, retrieveArgs.stopwords);
    System.out.println("Initializing analyzer with stemmer=" + retrieveArgs.stemmer + ", keepstop=" + retrieveArgs.keepstop + ", stopwords=" + retrieveArgs.stopwords);
  }

  SimpleSearcher searcher = new SimpleSearcher(retrieveArgs.index, analyzer);
  // try/finally so the searcher is released even when retrieval or writing fails.
  try {
    searcher.setBM25(retrieveArgs.k1, retrieveArgs.b);
    System.out.println("Initializing BM25, setting k1=" + retrieveArgs.k1 + " and b=" + retrieveArgs.b + "");

    if (retrieveArgs.rm3) {
      searcher.setRM3(retrieveArgs.fbTerms, retrieveArgs.fbDocs, retrieveArgs.originalQueryWeight);
      System.out.println("Initializing RM3, setting fbTerms=" + retrieveArgs.fbTerms + ", fbDocs=" + retrieveArgs.fbDocs + " and originalQueryWeight=" + retrieveArgs.originalQueryWeight);
    }

    // Field weights arrive as strings on the command line; convert them once up front.
    Map<String, Float> fields = new HashMap<>();
    retrieveArgs.fields.forEach((key, value) -> fields.put(key, Float.valueOf(value)));
    boolean fieldSearch = !fields.isEmpty();
    if (fieldSearch) {
      System.out.println("Performing weighted field search with fields=" + retrieveArgs.fields);
    }

    QueryGenerator queryGenerator;
    if (retrieveArgs.dismax) {
      queryGenerator = new DisjunctionMaxQueryGenerator(retrieveArgs.dismax_tiebreaker);
      System.out.println("Initializing dismax query generator, with tiebreaker=" + retrieveArgs.dismax_tiebreaker);
    } else {
      queryGenerator = new BagOfWordsQueryGenerator();
      System.out.println("Initializing bag-of-words query generator.");
    }

    // Read the queries before creating the run file so a bad queries path does not leave an
    // empty output file behind.
    List<String> lines = FileUtils.readLines(new File(retrieveArgs.qid_queries), "utf-8");

    // UTF-8 matches the encoding the queries are read with; US-ASCII would silently replace any
    // non-ASCII character in a qid with '?'.
    try (PrintWriter out = new PrintWriter(
        Files.newBufferedWriter(Paths.get(retrieveArgs.output), StandardCharsets.UTF_8))) {
      if (retrieveArgs.threads == 1) {
        // single-threaded retrieval
        long startTime = System.nanoTime();
        int queryCount = 0;
        for (int lineNumber = 0; lineNumber < lines.size(); ++lineNumber) {
          String line = lines.get(lineNumber);
          if (line.trim().isEmpty()) {
            continue;
          }
          String[] qidAndQuery = parseQidQueryLine(line, lineNumber);
          String qid = qidAndQuery[0];
          String query = qidAndQuery[1];

          SimpleSearcher.Result[] hits;
          if (fieldSearch) {
            hits = searcher.searchFields(queryGenerator, query, fields, retrieveArgs.hits);
          } else {
            hits = searcher.search(queryGenerator, query, retrieveArgs.hits);
          }

          if (queryCount % 100 == 0) {
            double timePerQuery = (double) (System.nanoTime() - startTime) / (queryCount + 1) / 1e9;
            System.out.format("Retrieving query " + queryCount + " (%.3f s/query)\n", timePerQuery);
          }
          queryCount++;

          writeHits(out, qid, hits);
        }
      } else {
        // multithreaded batch retrieval: parse every line once, keeping qids and queries aligned
        List<String> qids = new java.util.ArrayList<>(lines.size());
        List<String> queries = new java.util.ArrayList<>(lines.size());
        for (int lineNumber = 0; lineNumber < lines.size(); ++lineNumber) {
          String line = lines.get(lineNumber);
          if (line.trim().isEmpty()) {
            continue;
          }
          String[] qidAndQuery = parseQidQueryLine(line, lineNumber);
          qids.add(qidAndQuery[0]);
          queries.add(qidAndQuery[1]);
        }

        Map<String, SimpleSearcher.Result[]> results;
        if (fieldSearch) {
          results = searcher.batchSearchFields(queryGenerator, queries, qids, retrieveArgs.hits, retrieveArgs.threads, fields);
        } else {
          results = searcher.batchSearch(queryGenerator, queries, qids, retrieveArgs.hits, retrieveArgs.threads);
        }

        for (String qid : qids) {
          SimpleSearcher.Result[] hits = results.get(qid);
          if (hits == null) {
            throw new IllegalStateException("Batch retrieval returned no results for qid " + qid);
          }
          writeHits(out, qid, hits);
        }
      }

      // PrintWriter swallows IOExceptions; checkError() flushes and reports whether any write
      // failed, so a truncated run file is not mistaken for success.
      if (out.checkError()) {
        throw new java.io.IOException("Error writing run file " + retrieveArgs.output);
      }
    }
  } finally {
    searcher.close();
  }

  double totalTime = (double) (System.nanoTime() - totalStartTime) / 1e9;
  System.out.format("Total retrieval time: %.3f s\n", totalTime);
  System.out.println("Done!");
}

/**
 * Splits one queries-file line into {@code {qid, query}}.
 *
 * <p>The line is trimmed and split on tab characters; the first field is the qid and the second
 * is the query. Any additional tab-separated fields are ignored.
 *
 * @param line a non-blank line from the queries file
 * @param lineNumber zero-based index of the line, used only in the error message
 * @return a two-element array {@code {qid, query}}
 * @throws IllegalArgumentException if the line has no second (query) field
 */
private static String[] parseQidQueryLine(String line, int lineNumber) {
  String[] split = line.trim().split("\t");
  if (split.length < 2) {
    throw new IllegalArgumentException(
        "Malformed query line " + (lineNumber + 1) + " (expected qid<TAB>query): " + line);
  }
  return new String[] {split[0], split[1]};
}

/**
 * Writes one {@code qid<TAB>docid<TAB>rank} line per hit, with ranks starting at 1.
 *
 * @param out destination writer
 * @param qid query identifier to prefix every line with
 * @param hits hits in rank order
 */
private static void writeHits(PrintWriter out, String qid, SimpleSearcher.Result[] hits) {
  for (int rank = 0; rank < hits.length; ++rank) {
    out.println(qid + "\t" + hits[rank].docid + "\t" + (rank + 1));
  }
}
Aggregations