Search in sources :

Example 1 with LTRScoringModel

use of org.apache.solr.ltr.model.LTRScoringModel in project lucene-solr by apache.

the class TestLTRScoringQuery method testLTRScoringQuery.

/**
 * Checks the feature values exposed by {@link LTRScoringQuery.ModelWeight} for linear models
 * built over different feature selections.
 *
 * <p>The test indexes two documents, confirms the order returned by a plain boolean query, and
 * then, always for the top hit, verifies:
 * <ul>
 *   <li>a model over features 0, 1, 2 reports normalized values 0, 1, 2 and feature names
 *       "f0", "f1", "f2";</li>
 *   <li>a model over features in mixed order (8, 2, 4, 9, 0) reports values in that same order;</li>
 *   <li>building a model with no features fails with a {@link ModelException};</li>
 *   <li>a normalizer that always returns 42.42 causes every normalized value to be 42.42.</li>
 * </ul>
 *
 * @throws IOException if the index cannot be written or read
 * @throws ModelException if one of the models that is expected to be valid cannot be created
 */
@Test
public void testLTRScoringQuery() throws IOException, ModelException {
    final Directory dir = newDirectory();
    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
    Document doc = new Document();
    doc.add(newStringField("id", "0", Field.Store.YES));
    doc.add(newTextField("field", "wizard the the the the the oz", Field.Store.NO));
    doc.add(new FloatDocValuesField("final-score", 1.0f));
    w.addDocument(doc);
    doc = new Document();
    doc.add(newStringField("id", "1", Field.Store.YES));
    // 1 extra token, but wizard and oz are close;
    doc.add(newTextField("field", "wizard oz the the the the the the", Field.Store.NO));
    doc.add(new FloatDocValuesField("final-score", 2.0f));
    w.addDocument(doc);
    final IndexReader r = w.getReader();
    w.close();
    // Do ordinary BooleanQuery:
    final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
    bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
    bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
    final IndexSearcher searcher = getSearcher(r);
    // first run the standard query
    final TopDocs hits = searcher.search(bqBuilder.build(), 10);
    assertEquals(2, hits.totalHits);
    assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
    assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));

    // model over the first three features: values and names must line up with positions 0..2
    List<Feature> features = makeFeatures(new int[] { 0, 1, 2 });
    final List<Feature> allFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
    List<Normalizer> norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, makeFeatureWeights(features));
    LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
    assertEquals(3, modelWeight.getModelFeatureValuesNormalized().length);
    for (int i = 0; i < 3; i++) {
        assertEquals(i, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
    }
    int[] posVals = new int[] { 0, 1, 2 };
    int pos = 0;
    for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getFeaturesInfo()) {
        // entries may be null; only the non-null ones are compared against posVals
        if (fInfo == null) {
            continue;
        }
        assertEquals(posVals[pos], fInfo.getValue(), 0.0001);
        assertEquals("f" + posVals[pos], fInfo.getName());
        pos++;
    }

    // model over features given in mixed order: values must follow the model's order
    final int[] mixPositions = new int[] { 8, 2, 4, 9, 0 };
    features = makeFeatures(mixPositions);
    norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    ltrScoringModel = TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, makeFeatureWeights(features));
    modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
    assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);
    for (int i = 0; i < mixPositions.length; i++) {
        assertEquals(mixPositions[i], modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
    }

    // model without any features: creation itself must be rejected
    final ModelException expectedModelException = new ModelException("no features declared for model test");
    final int[] noPositions = new int[] {};
    features = makeFeatures(noPositions);
    norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    try {
        TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, makeFeatureWeights(features));
        // fail() always throws, so nothing may follow it inside this try block
        fail("unexpectedly got here instead of catching " + expectedModelException);
    } catch (ModelException actualModelException) {
        assertEquals(expectedModelException.toString(), actualModelException.toString());
    }

    // test normalizers
    features = makeFilterFeatures(mixPositions);
    final Normalizer norm = new Normalizer() {

        @Override
        public float normalize(float value) {
            // constant output makes the effect of normalization easy to assert
            return 42.42f;
        }

        @Override
        public LinkedHashMap<String, Object> paramsToMap() {
            return null;
        }

        @Override
        protected void validate() throws NormalizerException {
        }
    };
    norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), norm));
    final LTRScoringModel normMeta = TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, makeFeatureWeights(features));
    modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(normMeta));
    normMeta.normalizeFeaturesInPlace(modelWeight.getModelFeatureValuesNormalized());
    assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);
    for (int i = 0; i < mixPositions.length; i++) {
        assertEquals(42.42f, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
    }
    r.close();
    dir.close();
}
Also used : IndexSearcher(org.apache.lucene.search.IndexSearcher) BooleanQuery(org.apache.lucene.search.BooleanQuery) ArrayList(java.util.ArrayList) FloatDocValuesField(org.apache.lucene.document.FloatDocValuesField) Document(org.apache.lucene.document.Document) ValueFeature(org.apache.solr.ltr.feature.ValueFeature) Feature(org.apache.solr.ltr.feature.Feature) TopDocs(org.apache.lucene.search.TopDocs) Directory(org.apache.lucene.store.Directory) TermQuery(org.apache.lucene.search.TermQuery) ModelException(org.apache.solr.ltr.model.ModelException) Normalizer(org.apache.solr.ltr.norm.Normalizer) IdentityNormalizer(org.apache.solr.ltr.norm.IdentityNormalizer) Term(org.apache.lucene.index.Term) IndexReader(org.apache.lucene.index.IndexReader) RandomIndexWriter(org.apache.lucene.index.RandomIndexWriter) LTRScoringModel(org.apache.solr.ltr.model.LTRScoringModel) Test(org.junit.Test)

Example 2 with LTRScoringModel

use of org.apache.solr.ltr.model.LTRScoringModel in project lucene-solr by apache.

the class TestLTRReRankingPipeline method testDifferentTopN.

/**
 * Reranks progressively larger prefixes of a five-document result list and checks that the
 * reranked prefix comes back in reverse id order with score {@code id + 1}.
 *
 * <p>Each indexed document carries a "final-score" doc value equal to its id plus one; the
 * model is built from field-value features over that field. The test is currently disabled
 * via {@code @Ignore}.
 *
 * @throws IOException if the index cannot be written or read
 */
@Ignore
@Test
public void testDifferentTopN() throws IOException {
    final Directory dir = newDirectory();
    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
    // document i gets id "i", the i-th text below and a final-score of i + 1
    final String[] fieldTexts = {
        "wizard oz oz oz oz oz",
        "wizard oz oz oz oz the",
        "wizard oz oz oz the the ",
        "wizard oz oz the the the the ",
        "wizard oz the the the the the the"
    };
    for (int id = 0; id < fieldTexts.length; id++) {
        final Document doc = new Document();
        doc.add(newStringField("id", Integer.toString(id), Field.Store.YES));
        doc.add(newTextField("field", fieldTexts[id], Field.Store.NO));
        doc.add(new FloatDocValuesField("final-score", id + 1.0f));
        w.addDocument(doc);
    }
    final IndexReader r = w.getReader();
    w.close();

    // plain boolean query: wizard OR oz
    final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
    bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
    bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
    final IndexSearcher searcher = getSearcher(r);

    // baseline: the standard query returns ids 0..4 in that order
    TopDocs hits = searcher.search(bqBuilder.build(), 10);
    assertEquals(5, hits.totalHits);
    for (int rank = 0; rank < 5; rank++) {
        assertEquals(Integer.toString(rank), searcher.doc(hits.scoreDocs[rank].doc).get("id"));
    }

    final List<Feature> features = makeFieldValueFeatures(new int[] { 0, 1, 2 }, "final-score");
    final List<Normalizer> norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    final List<Feature> allFeatures = makeFieldValueFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }, "final-score");
    final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, null);
    final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));

    // rerank @ 0 should not change the order
    hits = rescorer.rescore(searcher, hits, 0);
    for (int rank = 0; rank < 5; rank++) {
        assertEquals(Integer.toString(rank), searcher.doc(hits.scoreDocs[rank].doc).get("id"));
    }

    for (int topN = 1; topN <= 5; topN++) {
        log.info("rerank {} documents ", topN);
        hits = searcher.search(bqBuilder.build(), 10);
        // keep only the first topN hits before handing them to the rescorer
        final ScoreDoc[] slice = new ScoreDoc[topN];
        System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
        hits = new TopDocs(hits.totalHits, slice, hits.getMaxScore());
        hits = rescorer.rescore(searcher, hits, topN);
        // position j of the reranked list is expected to hold id (topN - 1 - j)
        for (int j = 0; j < topN; j++) {
            final int expectedId = topN - 1 - j;
            final String actualId = searcher.doc(hits.scoreDocs[j].doc).get("id");
            log.info("doc {} in pos {}", actualId, j);
            assertEquals(expectedId, Integer.parseInt(actualId));
            assertEquals(expectedId + 1, hits.scoreDocs[j].score, 0.00001);
        }
    }
    r.close();
    dir.close();
}
Also used : IndexSearcher(org.apache.lucene.search.IndexSearcher) BooleanQuery(org.apache.lucene.search.BooleanQuery) TermQuery(org.apache.lucene.search.TermQuery) Normalizer(org.apache.solr.ltr.norm.Normalizer) IdentityNormalizer(org.apache.solr.ltr.norm.IdentityNormalizer) ArrayList(java.util.ArrayList) FloatDocValuesField(org.apache.lucene.document.FloatDocValuesField) Term(org.apache.lucene.index.Term) Document(org.apache.lucene.document.Document) FieldValueFeature(org.apache.solr.ltr.feature.FieldValueFeature) Feature(org.apache.solr.ltr.feature.Feature) ScoreDoc(org.apache.lucene.search.ScoreDoc) TopDocs(org.apache.lucene.search.TopDocs) IndexReader(org.apache.lucene.index.IndexReader) RandomIndexWriter(org.apache.lucene.index.RandomIndexWriter) LTRScoringModel(org.apache.solr.ltr.model.LTRScoringModel) Directory(org.apache.lucene.store.Directory) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 3 with LTRScoringModel

use of org.apache.solr.ltr.model.LTRScoringModel in project lucene-solr by apache.

the class TestLTRScoringQuery method testLTRScoringQueryEquality.

/**
 * Verifies {@code equals} and {@code hashCode} of {@link LTRScoringQuery}.
 *
 * <p>Two queries over the same model whose external feature info holds the same entries,
 * inserted in a different order, must be equal and share a hash code. Queries must differ
 * when the external feature info differs, when the model name differs, or when the
 * feature store name differs.
 *
 * @throws ModelException if one of the linear models cannot be created
 */
@Test
public void testLTRScoringQueryEquality() throws ModelException {
    final List<Feature> features = makeFeatures(new int[] { 0, 1, 2 });
    final List<Normalizer> norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
    final List<Feature> allFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
    final Map<String, Object> modelParams = makeFeatureWeights(features);
    final LTRScoringModel baseModel = TestLinearModel.createLinearModel("testModelName", features, norms, "testStoreName", allFeatures, modelParams);

    // query without any external feature info
    final LTRScoringQuery noEfiQuery = new LTRScoringQuery(baseModel);

    // query with efis, no thread module
    final HashMap<String, String[]> efiIntentFirst = new HashMap<>();
    efiIntentFirst.put("queryIntent", new String[] { "company" });
    efiIntentFirst.put("user_query", new String[] { "abc" });
    final LTRScoringQuery efiQuery = new LTRScoringQuery(baseModel, efiIntentFirst, false, null);

    // same efis inserted in the opposite order, this time with a thread module
    final HashMap<String, String[]> efiUserQueryFirst = new HashMap<>();
    efiUserQueryFirst.put("user_query", new String[] { "abc" });
    efiUserQueryFirst.put("queryIntent", new String[] { "company" });
    final LTRThreadModule threadManager = new LTRThreadModule(10, 10);
    final LTRScoringQuery reorderedEfiQuery = new LTRScoringQuery(baseModel, efiUserQueryFirst, false, threadManager);

    // Models with same algorithm and efis, just in different order should be the same
    assertEquals(efiQuery, reorderedEfiQuery);
    assertEquals(efiQuery.hashCode(), reorderedEfiQuery.hashCode());

    // Models with same algorithm, but different efi content should not match
    assertFalse(efiQuery.equals(noEfiQuery));
    assertFalse(efiQuery.hashCode() == noEfiQuery.hashCode());

    // a different model name must break equality
    final LTRScoringModel renamedModel = TestLinearModel.createLinearModel("testModelName2", features, norms, "testStoreName", allFeatures, modelParams);
    final LTRScoringQuery renamedModelQuery = new LTRScoringQuery(renamedModel);
    assertFalse(efiQuery.equals(renamedModelQuery));
    assertFalse(efiQuery.hashCode() == renamedModelQuery.hashCode());

    // a different feature store name must break equality
    final LTRScoringModel otherStoreModel = TestLinearModel.createLinearModel("testModelName", features, norms, "testStoreName3", allFeatures, modelParams);
    final LTRScoringQuery otherStoreQuery = new LTRScoringQuery(otherStoreModel);
    assertFalse(efiQuery.equals(otherStoreQuery));
    assertFalse(efiQuery.hashCode() == otherStoreQuery.hashCode());
}
Also used : HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Normalizer(org.apache.solr.ltr.norm.Normalizer) IdentityNormalizer(org.apache.solr.ltr.norm.IdentityNormalizer) ArrayList(java.util.ArrayList) ValueFeature(org.apache.solr.ltr.feature.ValueFeature) Feature(org.apache.solr.ltr.feature.Feature) LTRScoringModel(org.apache.solr.ltr.model.LTRScoringModel) Test(org.junit.Test)

Example 4 with LTRScoringModel

use of org.apache.solr.ltr.model.LTRScoringModel in project lucene-solr by apache.

the class TestRerankBase method createModelFromFiles.

/**
 * Builds an {@link LTRScoringModel} from a model JSON file and a feature JSON file found on
 * the classpath, and registers the model with the managed model store.
 *
 * <p>The model file is looked up under {@code /modelExamples/} and the feature file under
 * {@code /featureExamples/}. Before the features are loaded, the child named
 * {@code featureStoreName} is deleted from the managed feature store.
 *
 * @param modelFileName name of the model JSON file below {@code /modelExamples/}
 * @param featureFileName name of the feature JSON file below {@code /featureExamples/}
 * @param featureStoreName name of the feature store entry to delete before loading features
 * @return the model that was created and added to the managed model store
 * @throws IllegalArgumentException if either file cannot be found on the classpath
 * @throws ModelException if the feature JSON cannot be parsed
 * @throws Exception if reading the files or building the model fails for another reason
 */
public static LTRScoringModel createModelFromFiles(String modelFileName, String featureFileName, String featureStoreName) throws ModelException, Exception {
    final URL modelUrl = TestRerankBase.class.getResource("/modelExamples/" + modelFileName);
    // getResource returns null for a missing resource; fail with the path instead of an NPE
    if (modelUrl == null) {
        throw new IllegalArgumentException("model file not found on classpath: /modelExamples/" + modelFileName);
    }
    final String modelJson = FileUtils.readFileToString(new File(modelUrl.toURI()), "UTF-8");
    final ManagedModelStore ms = getManagedModelStore();
    final URL featureUrl = TestRerankBase.class.getResource("/featureExamples/" + featureFileName);
    if (featureUrl == null) {
        throw new IllegalArgumentException("feature file not found on classpath: /featureExamples/" + featureFileName);
    }
    final String featureJson = FileUtils.readFileToString(new File(featureUrl.toURI()), "UTF-8");
    final Object parsedFeatureJson;
    try {
        parsedFeatureJson = ObjectBuilder.fromJSON(featureJson);
    } catch (final IOException ioExc) {
        throw new ModelException("ObjectBuilder failed parsing json", ioExc);
    }
    final ManagedFeatureStore fs = getManagedFeatureStore();
    // NOTE(review): the original author questioned whether this delete is safe; it appears
    // to be needed because the managed feature store may not be fresh on each call - confirm.
    fs.doDeleteChild(null, featureStoreName);
    fs.applyUpdatesToManagedData(parsedFeatureJson);
    // NOTE(review): open question from the original author - could fs be used directly below
    // instead of routing it through the model store?
    ms.setManagedFeatureStore(fs);
    final LTRScoringModel ltrScoringModel = ManagedModelStore.fromLTRScoringModelMap(solrResourceLoader, mapFromJson(modelJson), ms.getManagedFeatureStore());
    ms.addModel(ltrScoringModel);
    return ltrScoringModel;
}
Also used : ManagedModelStore(org.apache.solr.ltr.store.rest.ManagedModelStore) ModelException(org.apache.solr.ltr.model.ModelException) IOException(java.io.IOException) File(java.io.File) URL(java.net.URL) ManagedFeatureStore(org.apache.solr.ltr.store.rest.ManagedFeatureStore) LTRScoringModel(org.apache.solr.ltr.model.LTRScoringModel)

Example 5 with LTRScoringModel

use of org.apache.solr.ltr.model.LTRScoringModel in project lucene-solr by apache.

the class TestSelectiveWeightCreation method testScoringQueryWeightCreation.

/**
 * Checks how many features are marked as used depending on whether feature extraction is
 * requested on the {@link LTRScoringQuery}.
 *
 * <p>With the boolean constructor argument set to {@code false}, the number of used features
 * must equal the number of model features. With it set to {@code true}, the number of
 * extracted feature weights and of used features must equal the size of the full feature list.
 *
 * @throws IOException if the index cannot be written or read
 * @throws ModelException if a linear model cannot be created
 */
@Test
public void testScoringQueryWeightCreation() throws IOException, ModelException {
    final Directory dir = newDirectory();
    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
    Document doc = new Document();
    doc.add(newStringField("id", "10", Field.Store.YES));
    doc.add(newTextField("field", "wizard the the the the the oz", Field.Store.NO));
    doc.add(new FloatDocValuesField("final-score", 1.0f));
    w.addDocument(doc);
    doc = new Document();
    doc.add(newStringField("id", "11", Field.Store.YES));
    // 1 extra token, but wizard and oz are close;
    doc.add(newTextField("field", "wizard oz the the the the the the", Field.Store.NO));
    doc.add(new FloatDocValuesField("final-score", 2.0f));
    w.addDocument(doc);
    final IndexReader r = w.getReader();
    w.close();
    // Do ordinary BooleanQuery:
    final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
    bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
    bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
    final IndexSearcher searcher = getSearcher(r);
    // first run the standard query
    final TopDocs hits = searcher.search(bqBuilder.build(), 10);
    assertEquals(2, hits.totalHits);
    assertEquals("10", searcher.doc(hits.scoreDocs[0].doc).get("id"));
    assertEquals("11", searcher.doc(hits.scoreDocs[1].doc).get("id"));
    List<Feature> features = makeFeatures(new int[] { 0, 1, 2 });
    final List<Feature> allFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
    final List<Normalizer> norms = new ArrayList<>();
    for (int k = 0; k < features.size(); ++k) {
        norms.add(IdentityNormalizer.INSTANCE);
    }
    // when features are NOT requested in the response, only the modelFeature weights should be created
    final LTRScoringModel ltrScoringModel1 = TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, makeFeatureWeights(features));
    LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, // features not requested in response
    new LTRScoringQuery(ltrScoringModel1, false));
    assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
    // expected value first, so a failure message reports expected/actual the right way round
    assertEquals(features.size(), countUsedFeatures(modelWeight.getFeaturesInfo()));
    // when features are requested in the response, weights should be created for all features
    final LTRScoringModel ltrScoringModel2 = TestLinearModel.createLinearModel("test", features, norms, "test", allFeatures, makeFeatureWeights(features));
    modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, // features requested in response
    new LTRScoringQuery(ltrScoringModel2, true));
    assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
    assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length);
    assertEquals(allFeatures.size(), countUsedFeatures(modelWeight.getFeaturesInfo()));
    assertU(delI("10"));
    assertU(delI("11"));
    r.close();
    dir.close();
}

/** Counts the non-null entries of {@code featuresInfo} whose {@code isUsed()} returns true. */
private static int countUsedFeatures(LTRScoringQuery.FeatureInfo[] featuresInfo) {
    int used = 0;
    for (final LTRScoringQuery.FeatureInfo info : featuresInfo) {
        if (info != null && info.isUsed()) {
            used++;
        }
    }
    return used;
}
Also used : IndexSearcher(org.apache.lucene.search.IndexSearcher) BooleanQuery(org.apache.lucene.search.BooleanQuery) TermQuery(org.apache.lucene.search.TermQuery) Normalizer(org.apache.solr.ltr.norm.Normalizer) IdentityNormalizer(org.apache.solr.ltr.norm.IdentityNormalizer) ArrayList(java.util.ArrayList) FloatDocValuesField(org.apache.lucene.document.FloatDocValuesField) Term(org.apache.lucene.index.Term) Document(org.apache.lucene.document.Document) ValueFeature(org.apache.solr.ltr.feature.ValueFeature) Feature(org.apache.solr.ltr.feature.Feature) TopDocs(org.apache.lucene.search.TopDocs) IndexReader(org.apache.lucene.index.IndexReader) RandomIndexWriter(org.apache.lucene.index.RandomIndexWriter) LTRScoringModel(org.apache.solr.ltr.model.LTRScoringModel) Directory(org.apache.lucene.store.Directory) Test(org.junit.Test)

Aggregations

LTRScoringModel (org.apache.solr.ltr.model.LTRScoringModel)7 ArrayList (java.util.ArrayList)5 Feature (org.apache.solr.ltr.feature.Feature)5 IdentityNormalizer (org.apache.solr.ltr.norm.IdentityNormalizer)5 Normalizer (org.apache.solr.ltr.norm.Normalizer)5 Test (org.junit.Test)5 Document (org.apache.lucene.document.Document)4 FloatDocValuesField (org.apache.lucene.document.FloatDocValuesField)4 IndexReader (org.apache.lucene.index.IndexReader)4 RandomIndexWriter (org.apache.lucene.index.RandomIndexWriter)4 Term (org.apache.lucene.index.Term)4 BooleanQuery (org.apache.lucene.search.BooleanQuery)4 IndexSearcher (org.apache.lucene.search.IndexSearcher)4 TermQuery (org.apache.lucene.search.TermQuery)4 TopDocs (org.apache.lucene.search.TopDocs)4 Directory (org.apache.lucene.store.Directory)4 ValueFeature (org.apache.solr.ltr.feature.ValueFeature)3 ModelException (org.apache.solr.ltr.model.ModelException)3 FieldValueFeature (org.apache.solr.ltr.feature.FieldValueFeature)2 Ignore (org.junit.Ignore)2