Use of io.anserini.ann.fw.FakeWordsEncoderAnalyzer in project Anserini by castorini.
The class IndexVectors, method main:
public static void main(String[] args) throws Exception {
  IndexVectors.Args indexArgs = new IndexVectors.Args();
  CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));

  try {
    parser.parseArgument(args);
  } catch (CmdLineException e) {
    System.err.println(e.getMessage());
    parser.printUsage(System.err);
    System.err.println("Example: " + IndexVectors.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  // Select the analyzer that implements the requested vector encoding: fake words (FW) or lexical LSH (LEXLSH).
  Analyzer vectorAnalyzer;
  if (indexArgs.encoding.equalsIgnoreCase(FW)) {
    vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
  } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
    vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams, indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
  } else {
    parser.printUsage(System.err);
    System.err.println("Example: " + IndexVectors.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  final long start = System.nanoTime();
  System.out.println(String.format("Loading model %s", indexArgs.input));
  Map<String, List<float[]>> vectors = readGloVe(indexArgs.input);

  Path indexDir = indexArgs.path;
  if (!Files.exists(indexDir)) {
    Files.createDirectories(indexDir);
  }
  System.out.println(String.format("Creating index at %s...", indexArgs.path));

  // Only the vector field uses the vector analyzer; everything else falls back to StandardAnalyzer.
  Directory d = FSDirectory.open(indexDir);
  Map<String, Analyzer> map = new HashMap<>();
  map.put(FIELD_VECTOR, vectorAnalyzer);
  Analyzer analyzer = new PerFieldAnalyzerWrapper(new StandardAnalyzer(), map);
  IndexWriterConfig conf = new IndexWriterConfig(analyzer);
  IndexWriter indexWriter = new IndexWriter(d, conf);
  final AtomicInteger cnt = new AtomicInteger();

  // One Lucene document per vector: the word id plus the vector serialized as a space-separated string.
  for (Map.Entry<String, List<float[]>> entry : vectors.entrySet()) {
    for (float[] vector : entry.getValue()) {
      Document doc = new Document();
      doc.add(new StringField(FIELD_ID, entry.getKey(), Field.Store.YES));
      StringBuilder sb = new StringBuilder();
      for (double fv : vector) {
        if (sb.length() > 0) {
          sb.append(' ');
        }
        sb.append(fv);
      }
      doc.add(new TextField(FIELD_VECTOR, sb.toString(), indexArgs.stored ? Field.Store.YES : Field.Store.NO));
      try {
        indexWriter.addDocument(doc);
        int cur = cnt.incrementAndGet();
        if (cur % 100000 == 0) {
          System.out.println(String.format("%s docs added", cnt));
        }
      } catch (IOException e) {
        System.err.println("Error while indexing: " + e.getLocalizedMessage());
      }
    }
  }

  indexWriter.commit();
  System.out.println(String.format("%s docs indexed", cnt.get()));
  long space = FileUtils.sizeOfDirectory(indexDir.toFile()) / (1024L * 1024L);
  System.out.println(String.format("Index size: %dMB", space));
  indexWriter.close();
  d.close();

  final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS);
  System.out.println(String.format("Total time: %s", DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")));
}
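To make the encoding concrete, the per-vector step above can be exercised on its own. The following is a minimal sketch, assuming the io.anserini.analysis.AnalyzerUtils import path and an illustrative quantization factor of 60 (neither appears in the snippet above): it serializes a small float array exactly as the indexing loop does and prints the fake-word tokens that would be indexed in FIELD_VECTOR.

import org.apache.lucene.analysis.Analyzer;
import io.anserini.analysis.AnalyzerUtils;  // import path assumed
import io.anserini.ann.fw.FakeWordsEncoderAnalyzer;

public class EncodeVectorSketch {
  public static void main(String[] args) {
    float[] vector = {0.12f, -0.34f, 0.56f};

    // Serialize the vector the same way IndexVectors.main does before indexing.
    StringBuilder sb = new StringBuilder();
    for (double fv : vector) {
      if (sb.length() > 0) {
        sb.append(' ');
      }
      sb.append(fv);
    }

    // Tokenize with the fake-words encoder; these tokens are what Lucene actually indexes.
    Analyzer vectorAnalyzer = new FakeWordsEncoderAnalyzer(60);  // 60 is an illustrative q, not a project default
    for (String token : AnalyzerUtils.analyze(vectorAnalyzer, sb.toString())) {
      System.out.println(token);
    }
  }
}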
Use of io.anserini.ann.fw.FakeWordsEncoderAnalyzer in project Anserini by castorini.
The class ApproximateNearestNeighborSearch, method main:
public static void main(String[] args) throws Exception {
  ApproximateNearestNeighborSearch.Args indexArgs = new ApproximateNearestNeighborSearch.Args();
  CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));

  try {
    parser.parseArgument(args);
  } catch (CmdLineException e) {
    System.err.println(e.getMessage());
    parser.printUsage(System.err);
    System.err.println("Example: " + ApproximateNearestNeighborSearch.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  // The analyzer must match the encoding that was used at indexing time.
  Analyzer vectorAnalyzer;
  if (indexArgs.encoding.equalsIgnoreCase(FW)) {
    vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
  } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
    vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams, indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
  } else {
    parser.printUsage(System.err);
    System.err.println("Example: " + ApproximateNearestNeighborSearch.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  if (!indexArgs.stored && indexArgs.input == null) {
    System.err.println("Either -input or -stored args must be set");
    return;
  }

  Path indexDir = indexArgs.path;
  if (!Files.exists(indexDir)) {
    Files.createDirectories(indexDir);
  }
  System.out.println(String.format("Reading index at %s", indexArgs.path));

  Directory d = FSDirectory.open(indexDir);
  DirectoryReader reader = DirectoryReader.open(d);
  IndexSearcher searcher = new IndexSearcher(reader);
  if (indexArgs.encoding.equalsIgnoreCase(FW)) {
    searcher.setSimilarity(new ClassicSimilarity());
  }

  // The query vector(s) come either from the stored vector field or from the original model file.
  Collection<String> vectorStrings = new LinkedList<>();
  if (indexArgs.stored) {
    TopDocs topDocs = searcher.search(new TermQuery(new Term(IndexVectors.FIELD_ID, indexArgs.word)), indexArgs.depth);
    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
      vectorStrings.add(reader.document(scoreDoc.doc).get(IndexVectors.FIELD_VECTOR));
    }
  } else {
    System.out.println(String.format("Loading model %s", indexArgs.input));
    Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);
    if (wordVectors.containsKey(indexArgs.word)) {
      List<float[]> vectors = wordVectors.get(indexArgs.word);
      for (float[] vector : vectors) {
        StringBuilder sb = new StringBuilder();
        for (double fv : vector) {
          if (sb.length() > 0) {
            sb.append(' ');
          }
          sb.append(fv);
        }
        String vectorString = sb.toString();
        vectorStrings.add(vectorString);
      }
    }
  }

  for (String vectorString : vectorStrings) {
    float msm = indexArgs.msm;
    float cutoff = indexArgs.cutoff;
    // Turn the encoded vector tokens into a CommonTermsQuery against the vector field.
    CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, cutoff);
    for (String token : AnalyzerUtils.analyze(vectorAnalyzer, vectorString)) {
      simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
    }
    if (msm > 0) {
      simQuery.setHighFreqMinimumNumberShouldMatch(msm);
      simQuery.setLowFreqMinimumNumberShouldMatch(msm);
    }

    long start = System.currentTimeMillis();
    TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
    searcher.search(simQuery, results);
    long time = System.currentTimeMillis() - start;

    System.out.println(String.format("%d nearest neighbors of '%s':", indexArgs.depth, indexArgs.word));
    int rank = 1;
    for (ScoreDoc sd : results.topDocs().scoreDocs) {
      Document document = reader.document(sd.doc);
      String word = document.get(IndexVectors.FIELD_ID);
      System.out.println(String.format("%d. %s (%.3f)", rank, word, sd.score));
      rank++;
    }
    System.out.println(String.format("Search time: %dms", time));
  }

  reader.close();
  d.close();
}
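For a quick sanity check of the query side in isolation, the sketch below opens an existing index, analyzes one serialized vector, and prints its nearest neighbors. It only uses calls shown in the snippet above, but the class name, the helper's signature, the quantization factor 60, and the io.anserini.ann.IndexVectors / io.anserini.analysis.AnalyzerUtils import paths are illustrative assumptions.

import java.io.IOException;
import java.nio.file.Paths;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.CommonTermsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.store.FSDirectory;
import io.anserini.analysis.AnalyzerUtils;  // import path assumed
import io.anserini.ann.IndexVectors;        // import path assumed
import io.anserini.ann.fw.FakeWordsEncoderAnalyzer;
import static org.apache.lucene.search.BooleanClause.Occur.SHOULD;

public class VectorQuerySketch {
  // Prints the top-k neighbors of one serialized vector against an FW-encoded index.
  public static void printNeighbors(String indexPath, String vectorString, int depth, float cutoff)
      throws IOException {
    try (DirectoryReader reader = DirectoryReader.open(FSDirectory.open(Paths.get(indexPath)))) {
      IndexSearcher searcher = new IndexSearcher(reader);
      // Pair the FW encoding with ClassicSimilarity, mirroring the main method above.
      searcher.setSimilarity(new ClassicSimilarity());

      Analyzer vectorAnalyzer = new FakeWordsEncoderAnalyzer(60);  // 60 is an illustrative q
      CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, cutoff);
      for (String token : AnalyzerUtils.analyze(vectorAnalyzer, vectorString)) {
        simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
      }

      TopScoreDocCollector results = TopScoreDocCollector.create(depth, Integer.MAX_VALUE);
      searcher.search(simQuery, results);

      int rank = 1;
      for (ScoreDoc sd : results.topDocs().scoreDocs) {
        Document document = reader.document(sd.doc);
        System.out.println(String.format("%d. %s (%.3f)", rank++, document.get(IndexVectors.FIELD_ID), sd.score));
      }
    }
  }
}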
Use of io.anserini.ann.fw.FakeWordsEncoderAnalyzer in project Anserini by castorini.
The class ApproximateNearestNeighborEval, method main:
public static void main(String[] args) throws Exception {
  ApproximateNearestNeighborEval.Args indexArgs = new ApproximateNearestNeighborEval.Args();
  CmdLineParser parser = new CmdLineParser(indexArgs, ParserProperties.defaults().withUsageWidth(90));

  try {
    parser.parseArgument(args);
  } catch (CmdLineException e) {
    System.err.println(e.getMessage());
    parser.printUsage(System.err);
    System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  Analyzer vectorAnalyzer;
  if (indexArgs.encoding.equalsIgnoreCase(FW)) {
    vectorAnalyzer = new FakeWordsEncoderAnalyzer(indexArgs.q);
  } else if (indexArgs.encoding.equalsIgnoreCase(LEXLSH)) {
    vectorAnalyzer = new LexicalLshAnalyzer(indexArgs.decimals, indexArgs.ngrams, indexArgs.hashCount, indexArgs.bucketCount, indexArgs.hashSetSize);
  } else {
    parser.printUsage(System.err);
    System.err.println("Example: " + ApproximateNearestNeighborEval.class.getSimpleName() + parser.printExample(OptionHandlerFilter.REQUIRED));
    return;
  }

  System.out.println(String.format("Loading model %s", indexArgs.input));
  Map<String, List<float[]>> wordVectors = IndexVectors.readGloVe(indexArgs.input);

  Path indexDir = indexArgs.path;
  if (!Files.exists(indexDir)) {
    Files.createDirectories(indexDir);
  }
  System.out.println(String.format("Reading index at %s", indexArgs.path));

  Directory d = FSDirectory.open(indexDir);
  DirectoryReader reader = DirectoryReader.open(d);
  IndexSearcher searcher = new IndexSearcher(reader);
  if (indexArgs.encoding.equalsIgnoreCase(FW)) {
    searcher.setSimilarity(new ClassicSimilarity());
  }

  StandardAnalyzer standardAnalyzer = new StandardAnalyzer();
  double recall = 0;
  double time = 0d;
  System.out.println("Evaluating at retrieval depth: " + indexArgs.depth);

  // Build the query word list from the titles of the TREC topics.
  TrecTopicReader trecTopicReader = new TrecTopicReader(indexArgs.topicsPath);
  Collection<String> words = new LinkedList<>();
  trecTopicReader.read().values().forEach(e -> words.addAll(AnalyzerUtils.analyze(standardAnalyzer, e.get("title"))));

  int queryCount = 0;
  for (String word : words) {
    if (wordVectors.containsKey(word)) {
      // Exact nearest neighbors of the word in vector space serve as the ground truth.
      Set<String> truth = nearestVector(wordVectors, word, indexArgs.topN);
      try {
        List<float[]> vectors = wordVectors.get(word);
        for (float[] vector : vectors) {
          StringBuilder sb = new StringBuilder();
          for (double fv : vector) {
            if (sb.length() > 0) {
              sb.append(' ');
            }
            sb.append(fv);
          }
          String fvString = sb.toString();

          CommonTermsQuery simQuery = new CommonTermsQuery(SHOULD, SHOULD, indexArgs.cutoff);
          if (indexArgs.msm > 0) {
            simQuery.setLowFreqMinimumNumberShouldMatch(indexArgs.msm);
          }
          for (String token : AnalyzerUtils.analyze(vectorAnalyzer, fvString)) {
            simQuery.add(new Term(IndexVectors.FIELD_VECTOR, token));
          }

          long start = System.currentTimeMillis();
          TopScoreDocCollector results = TopScoreDocCollector.create(indexArgs.depth, Integer.MAX_VALUE);
          searcher.search(simQuery, results);
          time += System.currentTimeMillis() - start;

          Set<String> observations = new HashSet<>();
          for (ScoreDoc sd : results.topDocs().scoreDocs) {
            Document document = reader.document(sd.doc);
            String wordValue = document.get(IndexVectors.FIELD_ID);
            observations.add(wordValue);
          }

          // Recall of the approximate results against the exact neighbor set.
          double intersection = Sets.intersection(truth, observations).size();
          double localRecall = intersection / (double) truth.size();
          recall += localRecall;
          queryCount++;
        }
      } catch (IOException e) {
        System.err.println("search for '" + word + "' failed " + e.getLocalizedMessage());
      }
    }
    if (queryCount >= indexArgs.samples) {
      break;
    }
  }

  recall /= queryCount;
  time /= queryCount;

  System.out.println(String.format("R@%d: %.4f", indexArgs.depth, recall));
  System.out.println(String.format("avg query time: %s ms", time));

  reader.close();
  d.close();
}
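The evaluation relies on a nearestVector(wordVectors, word, topN) helper whose body is not shown here; it supplies the exact neighbors that the approximate results are scored against. Purely as an illustration, a brute-force version based on cosine similarity could look like the sketch below; the real Anserini implementation may differ (for example in the similarity measure it uses or whether it keeps the query word itself).

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class ExactNeighborsSketch {
  // Hypothetical ground-truth helper: brute-force top-N words by cosine similarity.
  static Set<String> nearestVector(Map<String, List<float[]>> vectors, String word, int topN) {
    float[] query = vectors.get(word).get(0);
    List<Map.Entry<String, Double>> scored = new ArrayList<>();
    for (Map.Entry<String, List<float[]>> e : vectors.entrySet()) {
      if (e.getKey().equals(word)) {
        continue;  // assumption: the query word itself is excluded from its own neighbor set
      }
      double best = Double.NEGATIVE_INFINITY;
      for (float[] v : e.getValue()) {
        best = Math.max(best, cosine(query, v));
      }
      scored.add(new AbstractMap.SimpleEntry<>(e.getKey(), best));
    }
    scored.sort((a, b) -> Double.compare(b.getValue(), a.getValue()));  // most similar first
    Set<String> result = new HashSet<>();
    for (Map.Entry<String, Double> e : scored.subList(0, Math.min(topN, scored.size()))) {
      result.add(e.getKey());
    }
    return result;
  }

  static double cosine(float[] a, float[] b) {
    double dot = 0, na = 0, nb = 0;
    for (int i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      na += a[i] * a[i];
      nb += b[i] * b[i];
    }
    return dot / (Math.sqrt(na) * Math.sqrt(nb) + 1e-12);
  }
}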