Use of org.apache.solr.ltr.feature.Feature in the lucene-solr project by Apache.
Class TestLTRReRankingPipeline, method testDifferentTopN.
/**
 * Indexes five documents, checks the order returned by a plain boolean query,
 * and then reranks the first {@code topN} hits (for {@code topN} from 1 to 5)
 * with a linear model built on the "final-score" doc value.
 *
 * <p>After each rerank the first {@code topN} hits are expected in reversed id
 * order, each with a score of {@code id + 1}.
 *
 * @throws IOException if indexing or searching fails
 */
@Ignore
@Test
public void testDifferentTopN() throws IOException {
  // One text per document: the number of "oz" terms drops from document 0 to
  // document 4, while the "final-score" doc value rises from 1.0 to 5.0.
  final String[] bodies = {
      "wizard oz oz oz oz oz",
      "wizard oz oz oz oz the",
      "wizard oz oz oz the the ",
      "wizard oz oz the the the the ",
      "wizard oz the the the the the the"
  };

  final Directory directory = newDirectory();
  final RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
  for (int n = 0; n < bodies.length; n++) {
    final Document document = new Document();
    document.add(newStringField("id", Integer.toString(n), Field.Store.YES));
    document.add(newTextField("field", bodies[n], Field.Store.NO));
    document.add(new FloatDocValuesField("final-score", n + 1.0f));
    writer.addDocument(document);
  }
  final IndexReader reader = writer.getReader();
  writer.close();

  // Ordinary BooleanQuery over the two terms.
  final BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder();
  queryBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
  queryBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
  final IndexSearcher searcher = getSearcher(reader);

  // First run the standard query and verify the baseline order 0..4.
  TopDocs hits = searcher.search(queryBuilder.build(), 10);
  assertEquals(5, hits.totalHits);
  for (int pos = 0; pos < 5; pos++) {
    assertEquals(Integer.toString(pos), searcher.doc(hits.scoreDocs[pos].doc).get("id"));
  }

  final List<Feature> features = makeFieldValueFeatures(new int[] { 0, 1, 2 }, "final-score");
  final List<Normalizer> norms =
      new ArrayList<Normalizer>(Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
  final List<Feature> allFeatures =
      makeFieldValueFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }, "final-score");
  final LTRScoringModel ltrScoringModel =
      TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, null);
  final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));

  // Rerank @ 0 should not change the order.
  hits = rescorer.rescore(searcher, hits, 0);
  for (int pos = 0; pos < 5; pos++) {
    assertEquals(Integer.toString(pos), searcher.doc(hits.scoreDocs[pos].doc).get("id"));
  }

  for (int topN = 1; topN <= 5; topN++) {
    log.info("rerank {} documents ", topN);
    hits = searcher.search(queryBuilder.build(), 10);

    // Keep only the first topN hits before handing them to the rescorer.
    final ScoreDoc[] head = new ScoreDoc[topN];
    System.arraycopy(hits.scoreDocs, 0, head, 0, topN);
    hits = new TopDocs(hits.totalHits, head, hits.getMaxScore());
    hits = rescorer.rescore(searcher, hits, topN);

    // Position 0 must hold id topN-1, position 1 id topN-2, and so on.
    for (int pos = 0; pos < topN; pos++) {
      final int expectedId = topN - 1 - pos;
      final String actualId = searcher.doc(hits.scoreDocs[pos].doc).get("id");
      log.info("doc {} in pos {}", actualId, pos);
      assertEquals(expectedId, Integer.parseInt(actualId));
      assertEquals(expectedId + 1, hits.scoreDocs[pos].score, 0.00001);
    }
  }

  reader.close();
  directory.close();
}
Use of org.apache.solr.ltr.feature.Feature in the lucene-solr project by Apache.
Class TestLTRScoringQuery, method makeFeatureWeights.
/**
 * Builds model parameters that give every supplied feature a weight of 0.1.
 *
 * @param features features whose names become the keys of the weight map
 * @return a map with a single "weights" entry holding the name-to-weight map
 */
private static Map<String, Object> makeFeatureWeights(List<Feature> features) {
  final HashMap<String, Double> weightByName = new HashMap<>();
  for (final Feature feature : features) {
    weightByName.put(feature.getName(), 0.1);
  }

  final Map<String, Object> modelParams = new HashMap<>();
  modelParams.put("weights", weightByName);
  return modelParams;
}
Use of org.apache.solr.ltr.feature.Feature in the lucene-solr project by Apache.
Class TestLTRScoringQuery, method makeFilterFeatures.
/**
 * Creates one {@link ValueFeature} per id. Each feature is named {@code "f" + id},
 * receives the id as its "value" parameter, and has its index set to the id.
 *
 * @param featureIds ids to create features for, in the order they should appear
 * @return the created features, in the same order as {@code featureIds}
 */
private static List<Feature> makeFilterFeatures(int[] featureIds) {
  final List<Feature> created = new ArrayList<>();
  for (int pos = 0; pos < featureIds.length; pos++) {
    final int id = featureIds[pos];

    final Map<String, Object> valueParams = new HashMap<>();
    valueParams.put("value", id);

    final String className = ValueFeature.class.getCanonicalName();
    final Feature feature = Feature.getInstance(solrResourceLoader, className, "f" + id, valueParams);
    feature.setIndex(id);
    created.add(feature);
  }
  return created;
}
Use of org.apache.solr.ltr.feature.Feature in the lucene-solr project by Apache.
Class TestLTRScoringQuery, method makeFeatures.
/**
 * Builds a list of {@link ValueFeature} instances, one for each given id.
 * The id is used as the feature's "value" parameter, as the suffix of its
 * name ({@code "f" + id}), and as its index.
 *
 * @param featureIds ids of the features to build
 * @return the features, ordered like {@code featureIds}
 */
private static List<Feature> makeFeatures(int[] featureIds) {
  final String valueFeatureClass = ValueFeature.class.getCanonicalName();
  final List<Feature> built = new ArrayList<>();
  for (int k = 0; k < featureIds.length; k++) {
    final int featureId = featureIds[k];
    final Map<String, Object> featureParams = new HashMap<>();
    featureParams.put("value", featureId);
    final Feature feature =
        Feature.getInstance(solrResourceLoader, valueFeatureClass, "f" + featureId, featureParams);
    feature.setIndex(featureId);
    built.add(feature);
  }
  return built;
}
Use of org.apache.solr.ltr.feature.Feature in the lucene-solr project by Apache.
Class TestLTRScoringQuery, method testLTRScoringQueryEquality.
/**
 * Exercises {@code equals} and {@code hashCode} of {@link LTRScoringQuery}.
 *
 * <p>Two queries over the same model with the same external feature info are
 * expected to be equal even when the info was inserted in a different order and
 * only one of them was given a thread module. Queries that differ in external
 * feature info, model name, or feature store name are expected to be unequal.
 *
 * @throws ModelException if a linear model cannot be created
 */
@Test
public void testLTRScoringQueryEquality() throws ModelException {
  final List<Feature> modelFeatures = makeFeatures(new int[] { 0, 1, 2 });
  final List<Normalizer> normalizers =
      new ArrayList<Normalizer>(Collections.nCopies(modelFeatures.size(), IdentityNormalizer.INSTANCE));
  final List<Feature> storeFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
  final Map<String, Object> weights = makeFeatureWeights(modelFeatures);

  final LTRScoringModel baseModel = TestLinearModel.createLinearModel(
      "testModelName", modelFeatures, normalizers, "testStoreName", storeFeatures, weights);

  // Same two efi entries, inserted in opposite orders.
  final HashMap<String, String[]> efi = new HashMap<>();
  efi.put("queryIntent", new String[] { "company" });
  efi.put("user_query", new String[] { "abc" });
  final HashMap<String, String[]> efiReordered = new HashMap<>();
  efiReordered.put("user_query", new String[] { "abc" });
  efiReordered.put("queryIntent", new String[] { "company" });

  final int totalPoolThreads = 10;
  final int numThreadsPerRequest = 10;
  final LTRThreadModule threadModule = new LTRThreadModule(totalPoolThreads, numThreadsPerRequest);

  final LTRScoringQuery withoutEfi = new LTRScoringQuery(baseModel);
  final LTRScoringQuery withEfi = new LTRScoringQuery(baseModel, efi, false, null);
  final LTRScoringQuery withEfiReordered = new LTRScoringQuery(baseModel, efiReordered, false, threadModule);

  // Models with same algorithm and efis, just in different order should be the same
  assertEquals(withEfi, withEfiReordered);
  assertEquals(withEfi.hashCode(), withEfiReordered.hashCode());

  // Models with same algorithm, but different efi content should not match
  assertFalse(withEfi.equals(withoutEfi));
  assertFalse(withEfi.hashCode() == withoutEfi.hashCode());

  // A different model name must break equality.
  final LTRScoringModel renamedModel = TestLinearModel.createLinearModel(
      "testModelName2", modelFeatures, normalizers, "testStoreName", storeFeatures, weights);
  final LTRScoringQuery renamedModelQuery = new LTRScoringQuery(renamedModel);
  assertFalse(withEfi.equals(renamedModelQuery));
  assertFalse(withEfi.hashCode() == renamedModelQuery.hashCode());

  // A different feature store name must break equality.
  final LTRScoringModel otherStoreModel = TestLinearModel.createLinearModel(
      "testModelName", modelFeatures, normalizers, "testStoreName3", storeFeatures, weights);
  final LTRScoringQuery otherStoreQuery = new LTRScoringQuery(otherStoreModel);
  assertFalse(withEfi.equals(otherStoreQuery));
  assertFalse(withEfi.hashCode() == otherStoreQuery.hashCode());
}
Aggregations