Search in sources :

Example 1 with FakeWordsEncoderAnalyzer

use of io.anserini.ann.fw.FakeWordsEncoderAnalyzer in project Anserini by castorini.

From the class IndexVectors, method main:

/**
 * Builds a Lucene index over GloVe word vectors, encoding each vector as text
 * via either the fake-words ("fw") or lexical-LSH ("lexlsh") analyzer.
 *
 * @param args command-line arguments parsed into {@link IndexVectors.Args}
 * @throws Exception on I/O or indexing failure
 */
public static void main(String[] args) throws Exception {
    IndexVectors.Args indexArgs = new IndexVectors.Args();
    CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));
    try {
        parser.parseArgument(args);
    } catch (CmdLineException e) {
        System.err.println(e.getMessage());
        parser.printUsage(System.err);
        System.err.println("Example: " + IndexVectors.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    // Select the vector-to-token encoding; bail out with usage on an unknown encoding.
    Analyzer vectorAnalyzer;
    if (indexArgs.encoding.equalsIgnoreCase(FW)) {
        vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
    } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
        vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams, indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
    } else {
        parser.printUsage(System.err);
        System.err.println("Example: " + IndexVectors.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    final long start = System.nanoTime();
    System.out.println(String.format("Loading model %s", indexArgs.input));
    Map<String, List<float[]>> vectors = readGloVe(indexArgs.input);
    Path indexDir = indexArgs.path;
    if (!Files.exists(indexDir)) {
        Files.createDirectories(indexDir);
    }
    System.out.println(String.format("Creating index at %s...", indexArgs.path));
    Map<String, Analyzer> map = new HashMap<>();
    map.put(FIELD_VECTOR, vectorAnalyzer);
    // The vector field gets the encoding analyzer; everything else uses StandardAnalyzer.
    Analyzer analyzer = new PerFieldAnalyzerWrapper(new StandardAnalyzer(), map);
    IndexWriterConfig conf = new IndexWriterConfig(analyzer);
    final AtomicInteger cnt = new AtomicInteger();
    // try-with-resources: the original leaked the IndexWriter and Directory if
    // indexing threw before the explicit close() calls were reached.
    try (Directory d = FSDirectory.open(indexDir);
         IndexWriter indexWriter = new IndexWriter(d, conf)) {
        for (Map.Entry<String, List<float[]>> entry : vectors.entrySet()) {
            // A word may map to several vectors; each becomes its own document.
            for (float[] vector : entry.getValue()) {
                Document doc = new Document();
                doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES));
                // Serialize the vector as space-separated components for the analyzer.
                // NOTE: widening float -> double changes the printed representation
                // (e.g. 0.1f -> "0.10000000149011612"); kept for compatibility with
                // the query-side code, which serializes the same way.
                StringBuilder sb = new StringBuilder();
                for (double fv : vector) {
                    if (sb.length() > 0) {
                        sb.append(' ');
                    }
                    sb.append(fv);
                }
                doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? Field.Store.YES : Field.Store.NO));
                try {
                    indexWriter.addDocument(doc);
                    int cur = cnt.incrementAndGet();
                    if (cur % 100000 == 0) {
                        // Report the snapshotted counter, not the live AtomicInteger.
                        System.out.println(String.format("%s docs added", cur));
                    }
                } catch (IOException e) {
                    // Best-effort: skip the failed document and keep indexing.
                    System.err.println("Error while indexing: " + e.getLocalizedMessage());
                }
            }
        }
        indexWriter.commit();
        System.out.println(String.format("%s docs indexed", cnt.get()));
        long space = FileUtils.sizeOfDirectory(indexDir.toFile()) / (1024L * 1024L);
        System.out.println(String.format("Index size: %dMB", space));
    }
    final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS);
    System.out.println(String.format("Total time: %s", DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")));
}
Also used : HashMap(java.util.HashMap) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) Analyzer(org.apache.lucene.analysis.Analyzer) StandardAnalyzer(org.apache.lucene.analysis.standard.StandardAnalyzer) Document(org.apache.lucene.document.Document) TextField(org.apache.lucene.document.TextField) LinkedList(java.util.LinkedList) List(java.util.List) Directory(org.apache.lucene.store.Directory) FSDirectory(org.apache.lucene.store.FSDirectory) Path(java.nio.file.Path) CmdLineParser(org.kohsuke.args4j.CmdLineParser) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) IOException(java.io.IOException) PerFieldAnalyzerWrapper(org.apache.lucene.analysis.miscellaneous.PerFieldAnalyzerWrapper) IndexWriter(org.apache.lucene.index.IndexWriter) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) StandardAnalyzer(org.apache.lucene.analysis.standard.StandardAnalyzer) StringField(org.apache.lucene.document.StringField) HashMap(java.util.HashMap) Map(java.util.Map) CmdLineException(org.kohsuke.args4j.CmdLineException) IndexWriterConfig(org.apache.lucene.index.IndexWriterConfig)

Example 2 with FakeWordsEncoderAnalyzer

use of io.anserini.ann.fw.FakeWordsEncoderAnalyzer in project Anserini by castorini.

From the class ApproximateNearestNeighborSearch, method main:

/**
 * Searches a vector index (built by {@code IndexVectors}) for the approximate
 * nearest neighbors of a query word, using CommonTermsQuery over the encoded
 * vector tokens.
 *
 * @param args command-line arguments parsed into {@link ApproximateNearestNeighborSearch.Args}
 * @throws Exception on I/O or search failure
 */
public static void main(String[] args) throws Exception {
    ApproximateNearestNeighborSearch.Args indexArgs = new ApproximateNearestNeighborSearch.Args();
    CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));
    try {
        parser.parseArgument(args);
    } catch (CmdLineException e) {
        System.err.println(e.getMessage());
        parser.printUsage(System.err);
        System.err.println("Example: " + ApproximateNearestNeighborSearch.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    // Must mirror the analyzer the index was built with, or tokens won't match.
    Analyzer vectorAnalyzer;
    if (indexArgs.encoding.equalsIgnoreCase(FW)) {
        vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
    } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
        vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams, indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
    } else {
        parser.printUsage(System.err);
        System.err.println("Example: " + ApproximateNearestNeighborSearch.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    if (!indexArgs.stored && indexArgs.input == null) {
        // Fixed message: the condition checks -input (the model path), not -path.
        System.err.println("Either -input or -stored args must be set");
        return;
    }
    Path indexDir = indexArgs.path;
    // NOTE(review): creating the directory when we are about to READ an index looks
    // suspect — DirectoryReader.open will fail on an empty directory anyway. Kept
    // for behavioral compatibility; confirm intent with the original authors.
    if (!Files.exists(indexDir)) {
        Files.createDirectories(indexDir);
    }
    System.out.println(String.format("Reading index at %s", indexArgs.path));
    // try-with-resources: the original leaked the reader/directory if anything threw
    // before the trailing close() calls. Closes in reverse order (reader, then d).
    try (Directory d = FSDirectory.open(indexDir);
         DirectoryReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        if (indexArgs.encoding.equalsIgnoreCase(FW)) {
            // Fake-words encoding relies on classic TF-IDF scoring.
            searcher.setSimilarity(new ClassicSimilarity());
        }
        // Obtain the query word's vector string(s): either from stored fields in the
        // index, or by re-reading the GloVe model from disk.
        Collection<String> vectorStrings = new LinkedList<>();
        if (indexArgs.stored) {
            TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, indexArgs.word)), indexArgs.depth);
            for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
                vectorStrings.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR));
            }
        } else {
            System.out.println(String.format("Loading model %s", indexArgs.input));
            Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);
            if (wordVectors.containsKey(indexArgs.word)) {
                List<float[]> vectors = wordVectors.get(indexArgs.word);
                for (float[] vector : vectors) {
                    // Same float->double serialization as the indexing side.
                    StringBuilder sb = new StringBuilder();
                    for (double fv : vector) {
                        if (sb.length() > 0) {
                            sb.append(' ');
                        }
                        sb.append(fv);
                    }
                    vectorStrings.add(sb.toString());
                }
            }
        }
        for (String vectorString : vectorStrings) {
            float msm = indexArgs.msm;
            float cutoff = indexArgs.cutoff;
            CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, cutoff);
            for (String token : AnalyzerUtils.analyze(vectorAnalyzer, vectorString)) {
                simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
            }
            if (msm > 0) {
                simQuery.setHighFreqMinimumNumberShouldMatch(msm);
                simQuery.setLowFreqMinimumNumberShouldMatch(msm);
            }
            long start = System.currentTimeMillis();
            TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
            searcher.search(simQuery, results);
            long time = System.currentTimeMillis() - start;
            System.out.println(String.format("%d nearest neighbors of '%s':", indexArgs.depth, indexArgs.word));
            int rank = 1;
            for (ScoreDoc sd : results.topDocs().scoreDocs) {
                Document document = reader.document(sd.doc);
                String word = document.get(IndexVectors.FIELD_ID);
                System.out.println(String.format("%d. %s (%.3f)", rank, word, sd.score));
                rank++;
            }
            System.out.println(String.format("Search time: %dms", time));
        }
    }
}
Also used : IndexSearcher(org.apache.lucene.search.IndexSearcher) ClassicSimilarity(org.apache.lucene.search.similarities.ClassicSimilarity) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) Analyzer(org.apache.lucene.analysis.Analyzer) Document(org.apache.lucene.document.Document) ScoreDoc(org.apache.lucene.search.ScoreDoc) CommonTermsQuery(org.apache.lucene.queries.CommonTermsQuery) TopDocs(org.apache.lucene.search.TopDocs) LinkedList(java.util.LinkedList) List(java.util.List) Directory(org.apache.lucene.store.Directory) FSDirectory(org.apache.lucene.store.FSDirectory) Path(java.nio.file.Path) TermQuery(org.apache.lucene.search.TermQuery) CmdLineParser(org.kohsuke.args4j.CmdLineParser) DirectoryReader(org.apache.lucene.index.DirectoryReader) TopScoreDocCollector(org.apache.lucene.search.TopScoreDocCollector) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) Term(org.apache.lucene.index.Term) LinkedList(java.util.LinkedList) CmdLineException(org.kohsuke.args4j.CmdLineException)

Example 3 with FakeWordsEncoderAnalyzer

use of io.anserini.ann.fw.FakeWordsEncoderAnalyzer in project Anserini by castorini.

From the class ApproximateNearestNeighborEval, method main:

/**
 * Evaluates approximate-nearest-neighbor recall: for sampled topic words, compares
 * the top-N neighbors retrieved from the index against brute-force ground truth
 * computed from the GloVe model, reporting mean recall and average query time.
 *
 * @param args command-line arguments parsed into {@link ApproximateNearestNeighborEval.Args}
 * @throws Exception on I/O or search failure
 */
public static void main(String[] args) throws Exception {
    ApproximateNearestNeighborEval.Args indexArgs = new ApproximateNearestNeighborEval.Args();
    CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));
    try {
        parser.parseArgument(args);
    } catch (CmdLineException e) {
        System.err.println(e.getMessage());
        parser.printUsage(System.err);
        System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    // Must mirror the analyzer the index was built with, or tokens won't match.
    Analyzer vectorAnalyzer;
    if (indexArgs.encoding.equalsIgnoreCase(FW)) {
        vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
    } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
        vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams, indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
    } else {
        parser.printUsage(System.err);
        System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
        return;
    }
    System.out.println(String.format("Loading model %s", indexArgs.input));
    Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);
    Path indexDir = indexArgs.path;
    // NOTE(review): creating the directory when we are about to READ an index looks
    // suspect; kept for behavioral compatibility.
    if (!Files.exists(indexDir)) {
        Files.createDirectories(indexDir);
    }
    System.out.println(String.format("Reading index at %s", indexArgs.path));
    // try-with-resources: the original leaked the reader/directory if evaluation threw.
    try (Directory d = FSDirectory.open(indexDir);
         DirectoryReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        if (indexArgs.encoding.equalsIgnoreCase(FW)) {
            // Fake-words encoding relies on classic TF-IDF scoring.
            searcher.setSimilarity(new ClassicSimilarity());
        }
        StandardAnalyzer standardAnalyzer = new StandardAnalyzer();
        double recall = 0;
        double time = 0d;
        System.out.println("Evaluating at retrieval depth: " + indexArgs.depth);
        // Query words come from the titles of TREC topics.
        TrecTopicReader trecTopicReader = new TrecTopicReader(indexArgs.topicsPath);
        Collection<String> words = new LinkedList<>();
        trecTopicReader.read().values().forEach(e -> words.addAll(AnalyzerUtils.analyze(standardAnalyzer, e.get("title"))));
        int queryCount = 0;
        for (String word : words) {
            if (wordVectors.containsKey(word)) {
                // Brute-force ground truth from the model itself.
                Set<String> truth = nearestVector(wordVectors, word, indexArgs.topN);
                try {
                    List<float[]> vectors = wordVectors.get(word);
                    for (float[] vector : vectors) {
                        // Same float->double serialization as the indexing side.
                        StringBuilder sb = new StringBuilder();
                        for (double fv : vector) {
                            if (sb.length() > 0) {
                                sb.append(' ');
                            }
                            sb.append(fv);
                        }
                        String fvString = sb.toString();
                        CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff);
                        if (indexArgs.msm > 0) {
                            simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm);
                        }
                        for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) {
                            simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
                        }
                        long start = System.currentTimeMillis();
                        TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
                        searcher.search(simQuery, results);
                        time += System.currentTimeMillis() - start;
                        Set<String> observations = new HashSet<>();
                        for (ScoreDoc sd : results.topDocs().scoreDocs) {
                            Document document = reader.document(sd.doc);
                            observations.add(document.get(IndexVectors.FIELD_ID));
                        }
                        double intersection = Sets.intersection(truth, observations).size();
                        recall += intersection / (double) truth.size();
                        queryCount++;
                    }
                } catch (IOException e) {
                    // Best-effort: skip the failed word and keep evaluating.
                    System.err.println("search for '" + word + "' failed " + e.getLocalizedMessage());
                }
            }
            if (queryCount >= indexArgs.samples) {
                break;
            }
        }
        // Guard against division by zero: the original printed "R@k: NaN" when no
        // topic word appeared in the model.
        if (queryCount == 0) {
            System.err.println("No topic words found in the model; nothing evaluated");
            return;
        }
        recall /= queryCount;
        time /= queryCount;
        System.out.println(String.format("R@%d: %.4f", indexArgs.depth, recall));
        System.out.println(String.format("avg query time: %s ms", time));
    }
}
Also used : IndexSearcher(org.apache.lucene.search.IndexSearcher) ClassicSimilarity(org.apache.lucene.search.similarities.ClassicSimilarity) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) Analyzer(org.apache.lucene.analysis.Analyzer) StandardAnalyzer(org.apache.lucene.analysis.standard.StandardAnalyzer) Document(org.apache.lucene.document.Document) CommonTermsQuery(org.apache.lucene.queries.CommonTermsQuery) ScoreDoc(org.apache.lucene.search.ScoreDoc) LinkedList(java.util.LinkedList) List(java.util.List) Directory(org.apache.lucene.store.Directory) FSDirectory(org.apache.lucene.store.FSDirectory) HashSet(java.util.HashSet) Path(java.nio.file.Path) CmdLineParser(org.kohsuke.args4j.CmdLineParser) DirectoryReader(org.apache.lucene.index.DirectoryReader) TopScoreDocCollector(org.apache.lucene.search.TopScoreDocCollector) FakeWordsEncoderAnalyzer(io.anserini.ann.fw.FakeWordsEncoderAnalyzer) LexicalLshAnalyzer(io.anserini.ann.lexlsh.LexicalLshAnalyzer) Term(org.apache.lucene.index.Term) IOException(java.io.IOException) LinkedList(java.util.LinkedList) StandardAnalyzer(org.apache.lucene.analysis.standard.StandardAnalyzer) CmdLineException(org.kohsuke.args4j.CmdLineException) TrecTopicReader(io.anserini.search.topicreader.TrecTopicReader)

Aggregations

FakeWordsEncoderAnalyzer (io.anserini.ann.fw.FakeWordsEncoderAnalyzer)3 LexicalLshAnalyzer (io.anserini.ann.lexlsh.LexicalLshAnalyzer)3 Path (java.nio.file.Path)3 LinkedList (java.util.LinkedList)3 List (java.util.List)3 Analyzer (org.apache.lucene.analysis.Analyzer)3 Document (org.apache.lucene.document.Document)3 Directory (org.apache.lucene.store.Directory)3 FSDirectory (org.apache.lucene.store.FSDirectory)3 CmdLineException (org.kohsuke.args4j.CmdLineException)3 CmdLineParser (org.kohsuke.args4j.CmdLineParser)3 IOException (java.io.IOException)2 StandardAnalyzer (org.apache.lucene.analysis.standard.StandardAnalyzer)2 DirectoryReader (org.apache.lucene.index.DirectoryReader)2 Term (org.apache.lucene.index.Term)2 CommonTermsQuery (org.apache.lucene.queries.CommonTermsQuery)2 IndexSearcher (org.apache.lucene.search.IndexSearcher)2 ScoreDoc (org.apache.lucene.search.ScoreDoc)2 TopScoreDocCollector (org.apache.lucene.search.TopScoreDocCollector)2 ClassicSimilarity (org.apache.lucene.search.similarities.ClassicSimilarity)2