Search in sources :

Example 1 with TrecTopicReader

Use of io.anserini.search.topicreader.TrecTopicReader in the Anserini project by castorini.

The main method of the class ApproximateNearestNeighborEval.

/**
 * Evaluates approximate nearest-neighbor (ANN) retrieval quality over a Lucene index of
 * encoded word vectors: for each topic-title word with a known vector, compares the
 * documents retrieved at {@code depth} against the exact top-N neighbors computed by
 * brute force, and reports average recall and average query latency in milliseconds.
 *
 * @param args command-line arguments parsed into {@link ApproximateNearestNeighborEval.Args}
 * @throws Exception on unrecoverable I/O or parsing failure
 */
public static void main(String[] args) throws Exception {
    ApproximateNearestNeighborEval.Args indexArgs = new ApproximateNearestNeighborEval.Args();
    CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));
    try {
        parser.parseArgument(args);
    } catch (CmdLineException e) {
        System.err.println(e.getMessage());
        parser.printUsage(System.err);
        System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    // Choose the analyzer matching the encoding used at indexing time; bail out on
    // an unrecognized encoding.
    Analyzer vectorAnalyzer;
    if (indexArgs.encoding.equalsIgnoreCase(FW)) {
        vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
    } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
        vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams, indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
    } else {
        parser.printUsage(System.err);
        System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    System.out.println(String.format("Loading model %s", indexArgs.input));
    Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);
    Path indexDir = indexArgs.path;
    if (!Files.exists(indexDir)) {
        Files.createDirectories(indexDir);
    }
    System.out.println(String.format("Reading index at %s", indexArgs.path));
    // try-with-resources: the original closed the reader and directory only on the
    // happy path, leaking them if evaluation threw; the StandardAnalyzer was never
    // closed at all.
    try (Directory d = FSDirectory.open(indexDir);
         DirectoryReader reader = DirectoryReader.open(d);
         StandardAnalyzer standardAnalyzer = new StandardAnalyzer()) {
        IndexSearcher searcher = new IndexSearcher(reader);
        if (indexArgs.encoding.equalsIgnoreCase(FW)) {
            // NOTE(review): FW encoding appears to rely on classic TF-IDF scoring — confirm
            // this matches the similarity used when the index was built.
            searcher.setSimilarity(new ClassicSimilarity());
        }
        double recall = 0;
        double time = 0d;
        System.out.println("Evaluating at retrieval depth: " + indexArgs.depth);
        TrecTopicReader trecTopicReader = new TrecTopicReader(indexArgs.topicsPath);
        // Query vocabulary: every analyzed token from the topic titles.
        Collection<String> words = new LinkedList<>();
        trecTopicReader.read().values().forEach(e -> words.addAll(AnalyzerUtils.analyze(standardAnalyzer, e.get("title"))));
        int queryCount = 0;
        for (String word : words) {
            if (wordVectors.containsKey(word)) {
                // Ground truth: exact top-N neighbors by brute-force vector comparison.
                Set<String> truth = nearestVector(wordVectors, word, indexArgs.topN);
                try {
                    List<float[]> vectors = wordVectors.get(word);
                    for (float[] vector : vectors) {
                        // Serialize the vector as space-separated components; the float is
                        // widened to double here — assumed to match the stringification used
                        // at indexing time (TODO confirm against IndexVectors).
                        StringBuilder sb = new StringBuilder();
                        for (double fv : vector) {
                            if (sb.length() > 0) {
                                sb.append(' ');
                            }
                            sb.append(fv);
                        }
                        String fvString = sb.toString();
                        CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff);
                        if (indexArgs.msm > 0) {
                            simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm);
                        }
                        for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) {
                            simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
                        }
                        // Time only the Lucene search, not query construction.
                        long start = System.currentTimeMillis();
                        TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
                        searcher.search(simQuery, results);
                        time += System.currentTimeMillis() - start;
                        Set<String> observations = new HashSet<>();
                        for (ScoreDoc sd : results.topDocs().scoreDocs) {
                            Document document = reader.document(sd.doc);
                            String wordValue = document.get(IndexVectors.FIELD_ID);
                            observations.add(wordValue);
                        }
                        // Per-query recall: fraction of true neighbors found at this depth.
                        double intersection = Sets.intersection(truth, observations).size();
                        double localRecall = intersection / (double) truth.size();
                        recall += localRecall;
                        queryCount++;
                    }
                } catch (IOException e) {
                    // Best-effort: a failed query is reported but does not abort the run.
                    System.err.println("search for '" + word + "' failed " + e.getLocalizedMessage());
                }
            }
            if (queryCount >= indexArgs.samples) {
                break;
            }
        }
        // Guard the averages: with zero evaluated queries the original printed NaN.
        if (queryCount > 0) {
            recall /= queryCount;
            time /= queryCount;
        }
        System.out.println(String.format("R@%d: %.4f", indexArgs.depth, recall));
        System.out.println(String.format("avg query time: %s ms", time));
    }
}
Also used : IndexSearcher(org.apache.lucene.search.IndexSearcher) ClassicSimilarity(org.apache.lucene.search.similarities.ClassicSimilarity) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) Analyzer(org.apache.lucene.analysis.Analyzer) StandardAnalyzer(org.apache.lucene.analysis.standard.StandardAnalyzer) Document(org.apache.lucene.document.Document) CommonTermsQuery(org.apache.lucene.queries.CommonTermsQuery) ScoreDoc(org.apache.lucene.search.ScoreDoc) LinkedList(java.util.LinkedList) List(java.util.List) Directory(org.apache.lucene.store.Directory) FSDirectory(org.apache.lucene.store.FSDirectory) HashSet(java.util.HashSet) Path(java.nio.file.Path) CmdLineParser(org.kohsuke.args4j.CmdLineParser) DirectoryReader(org.apache.lucene.index.DirectoryReader) TopScoreDocCollector(org.apache.lucene.search.TopScoreDocCollector) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) Term(org.apache.lucene.index.Term) IOException(java.io.IOException) LinkedList(java.util.LinkedList) StandardAnalyzer(org.apache.lucene.analysis.standard.StandardAnalyzer) CmdLineException(org.kohsuke.args4j.CmdLineException) TrecTopicReader(io.anserini.search.topicreader.TrecTopicReader)

Aggregations

FakeWordsEncoderAnalyzer (io.anserini.ann.fw.FakeWordsEncoderAnalyzer)1 LexicalLshAnalyzer (io.anserini.ann.lexlsh.LexicalLshAnalyzer)1 TrecTopicReader (io.anserini.search.topicreader.TrecTopicReader)1 IOException (java.io.IOException)1 Path (java.nio.file.Path)1 HashSet (java.util.HashSet)1 LinkedList (java.util.LinkedList)1 List (java.util.List)1 Analyzer (org.apache.lucene.analysis.Analyzer)1 StandardAnalyzer (org.apache.lucene.analysis.standard.StandardAnalyzer)1 Document (org.apache.lucene.document.Document)1 DirectoryReader (org.apache.lucene.index.DirectoryReader)1 Term (org.apache.lucene.index.Term)1 CommonTermsQuery (org.apache.lucene.queries.CommonTermsQuery)1 IndexSearcher (org.apache.lucene.search.IndexSearcher)1 ScoreDoc (org.apache.lucene.search.ScoreDoc)1 TopScoreDocCollector (org.apache.lucene.search.TopScoreDocCollector)1 ClassicSimilarity (org.apache.lucene.search.similarities.ClassicSimilarity)1 Directory (org.apache.lucene.store.Directory)1 FSDirectory (org.apache.lucene.store.FSDirectory)1