Use of org.apache.solr.ltr.feature.Feature in project lucene-solr by apache.
The class LTRScoringQuery, method createWeightsParallel.
private void createWeightsParallel(IndexSearcher searcher, boolean needsScores,
    List<Feature.FeatureWeight> featureWeights, Collection<Feature> features) throws RuntimeException {
  final SolrQueryRequest req = getRequest();
  List<Future<Feature.FeatureWeight>> futures = new ArrayList<>(features.size());
  try {
    for (final Feature f : features) {
      CreateWeightCallable callable = new CreateWeightCallable(f, searcher, needsScores, req);
      RunnableFuture<Feature.FeatureWeight> runnableFuture = new FutureTask<>(callable);
      // always acquire before the ltrSemaphore is acquired, to guarantee that the current query stays within the limit for max threads
      querySemaphore.acquire();
      // may block and/or interrupt
      ltrThreadMgr.acquireLTRSemaphore();
      // releases the semaphore when done
      ltrThreadMgr.execute(runnableFuture);
      futures.add(runnableFuture);
    }
    // Loop over futures to get the feature weight objects
    for (final Future<Feature.FeatureWeight> future : futures) {
      // future.get() will block if the job is still running
      featureWeights.add(future.get());
    }
  } catch (Exception e) {
    // catches both InterruptedException and ExecutionException
    log.info("Error while creating weights in LTR", e);
    throw new RuntimeException("Error while creating weights in LTR: " + e.getMessage(), e);
  }
}
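The parallel path above submits one CreateWeightCallable per feature, but that inner class is not shown in this snippet. Below is a minimal sketch of what such a callable could look like, assuming it simply delegates to Feature.createWeight and releases both semaphores in a finally block; the fields querySemaphore, ltrThreadMgr, originalQuery and efi come from the enclosing LTRScoringQuery, everything else is illustrative and may differ from the real inner class.

// Sketch only: an inner class of LTRScoringQuery; details of the real CreateWeightCallable may differ.
class CreateWeightCallable implements Callable<Feature.FeatureWeight> {
  private final Feature f;
  private final IndexSearcher searcher;
  private final boolean needsScores;
  private final SolrQueryRequest req;

  CreateWeightCallable(Feature f, IndexSearcher searcher, boolean needsScores, SolrQueryRequest req) {
    this.f = f;
    this.searcher = searcher;
    this.needsScores = needsScores;
    this.req = req;
  }

  @Override
  public Feature.FeatureWeight call() throws Exception {
    try {
      // same per-feature call as the serial createWeights(...) shown next
      return f.createWeight(searcher, needsScores, req, originalQuery, efi);
    } finally {
      // release in reverse order of acquisition so the next task can be scheduled
      querySemaphore.release();
      ltrThreadMgr.releaseLTRSemaphore();
    }
  }
}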
Use of org.apache.solr.ltr.feature.Feature in project lucene-solr by apache.
The class LTRScoringQuery, method createWeights.
private void createWeights(IndexSearcher searcher, boolean needsScores,
    List<Feature.FeatureWeight> featureWeights, Collection<Feature> features) throws IOException {
  final SolrQueryRequest req = getRequest();
  // since the feature store is a LinkedHashMap, iteration order is preserved
  for (final Feature f : features) {
    try {
      Feature.FeatureWeight fw = f.createWeight(searcher, needsScores, req, originalQuery, efi);
      featureWeights.add(fw);
    } catch (final Exception e) {
      throw new RuntimeException("Exception from createWeight for " + f.toString() + " " + e.getMessage(), e);
    }
  }
}
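The comment about the feature store being a LinkedHashMap is what lets the i-th entry of featureWeights correspond to the i-th feature handed in. A tiny self-contained illustration of that insertion-order property, using plain JDK collections rather than Solr types:

import java.util.LinkedHashMap;
import java.util.Map;

public class LinkedHashMapOrderDemo {
  public static void main(String[] args) {
    // LinkedHashMap iterates in insertion order, which is why the loop in
    // createWeights can rely on featureWeights positions matching feature positions.
    Map<String, String> store = new LinkedHashMap<>();
    store.put("f0", "first feature");
    store.put("f1", "second feature");
    store.put("f2", "third feature");
    // prints the values in insertion order: first, second, third
    store.values().forEach(System.out::println);
  }
}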
Use of org.apache.solr.ltr.feature.Feature in project lucene-solr by apache.
The class LTRScoringQuery, method createWeight.
@Override
public ModelWeight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
  final Collection<Feature> modelFeatures = ltrScoringModel.getFeatures();
  final Collection<Feature> allFeatures = ltrScoringModel.getAllFeatures();
  int modelFeatSize = modelFeatures.size();
  Collection<Feature> features = null;
  if (this.extractAllFeatures) {
    features = allFeatures;
  } else {
    features = modelFeatures;
  }
  final Feature.FeatureWeight[] extractedFeatureWeights = new Feature.FeatureWeight[features.size()];
  final Feature.FeatureWeight[] modelFeaturesWeights = new Feature.FeatureWeight[modelFeatSize];
  List<Feature.FeatureWeight> featureWeights = new ArrayList<>(features.size());
  if (querySemaphore == null) {
    createWeights(searcher, needsScores, featureWeights, features);
  } else {
    createWeightsParallel(searcher, needsScores, featureWeights, features);
  }
  int i = 0, j = 0;
  if (this.extractAllFeatures) {
    for (final Feature.FeatureWeight fw : featureWeights) {
      extractedFeatureWeights[i++] = fw;
    }
    for (final Feature f : modelFeatures) {
      // we can look up by feature index because all features are extracted when this.extractAllFeatures is set
      modelFeaturesWeights[j++] = extractedFeatureWeights[f.getIndex()];
    }
  } else {
    for (final Feature.FeatureWeight fw : featureWeights) {
      extractedFeatureWeights[i++] = fw;
      modelFeaturesWeights[j++] = fw;
    }
  }
  return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size());
}
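The lookup extractedFeatureWeights[f.getIndex()] works because, when extractAllFeatures is set, one weight is created per feature in the store and each feature's index points at its own slot. A small stand-alone illustration of that mapping, using plain strings in place of FeatureWeight objects and hypothetical sizes:

public class FeatureIndexLookupDemo {
  public static void main(String[] args) {
    // hypothetical store of 5 extracted weights, one per feature, in store order
    String[] extractedFeatureWeights = { "w0", "w1", "w2", "w3", "w4" };
    // the model only uses the features whose getIndex() is 1 and 3
    int[] modelFeatureIndices = { 1, 3 };

    String[] modelFeaturesWeights = new String[modelFeatureIndices.length];
    for (int j = 0; j < modelFeatureIndices.length; j++) {
      // same lookup as modelFeaturesWeights[j++] = extractedFeatureWeights[f.getIndex()]
      modelFeaturesWeights[j] = extractedFeatureWeights[modelFeatureIndices[j]];
    }
    System.out.println(java.util.Arrays.toString(modelFeaturesWeights)); // prints [w1, w3]
  }
}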
Use of org.apache.solr.ltr.feature.Feature in project lucene-solr by apache.
The class TestLTRScoringQuery, method testLTRScoringQuery.
@Test
public void testLTRScoringQuery() throws IOException, ModelException {
  final Directory dir = newDirectory();
  final RandomIndexWriter w = new RandomIndexWriter(random(), dir);

  Document doc = new Document();
  doc.add(newStringField("id", "0", Field.Store.YES));
  doc.add(newTextField("field", "wizard the the the the the oz", Field.Store.NO));
  doc.add(new FloatDocValuesField("final-score", 1.0f));
  w.addDocument(doc);

  doc = new Document();
  doc.add(newStringField("id", "1", Field.Store.YES));
  // 1 extra token, but wizard and oz are close;
  doc.add(newTextField("field", "wizard oz the the the the the the", Field.Store.NO));
  doc.add(new FloatDocValuesField("final-score", 2.0f));
  w.addDocument(doc);

  final IndexReader r = w.getReader();
  w.close();

  // Do ordinary BooleanQuery:
  final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
  bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
  bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
  final IndexSearcher searcher = getSearcher(r);

  // first run the standard query
  final TopDocs hits = searcher.search(bqBuilder.build(), 10);
  assertEquals(2, hits.totalHits);
  assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
  assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));

  List<Feature> features = makeFeatures(new int[] { 0, 1, 2 });
  final List<Feature> allFeatures = makeFeatures(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
  List<Normalizer> norms = new ArrayList<Normalizer>(
      Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
  LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
      features, norms, "test", allFeatures, makeFeatureWeights(features));

  LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher,
      hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
  assertEquals(3, modelWeight.getModelFeatureValuesNormalized().length);
  for (int i = 0; i < 3; i++) {
    assertEquals(i, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
  }

  int[] posVals = new int[] { 0, 1, 2 };
  int pos = 0;
  for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getFeaturesInfo()) {
    if (fInfo == null) {
      continue;
    }
    assertEquals(posVals[pos], fInfo.getValue(), 0.0001);
    assertEquals("f" + posVals[pos], fInfo.getName());
    pos++;
  }

  final int[] mixPositions = new int[] { 8, 2, 4, 9, 0 };
  features = makeFeatures(mixPositions);
  norms = new ArrayList<Normalizer>(
      Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
  ltrScoringModel = TestLinearModel.createLinearModel("test",
      features, norms, "test", allFeatures, makeFeatureWeights(features));

  modelWeight = performQuery(hits, searcher,
      hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
  assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);
  for (int i = 0; i < mixPositions.length; i++) {
    assertEquals(mixPositions[i], modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
  }

  final ModelException expectedModelException = new ModelException("no features declared for model test");
  final int[] noPositions = new int[] {};
  features = makeFeatures(noPositions);
  norms = new ArrayList<Normalizer>(
      Collections.nCopies(features.size(), IdentityNormalizer.INSTANCE));
  try {
    ltrScoringModel = TestLinearModel.createLinearModel("test",
        features, norms, "test", allFeatures, makeFeatureWeights(features));
    fail("unexpectedly got here instead of catching " + expectedModelException);
    modelWeight = performQuery(hits, searcher,
        hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
    assertEquals(0, modelWeight.getModelFeatureWeights().length);
  } catch (ModelException actualModelException) {
    assertEquals(expectedModelException.toString(), actualModelException.toString());
  }

  // test normalizers
  features = makeFilterFeatures(mixPositions);
  final Normalizer norm = new Normalizer() {

    @Override
    public float normalize(float value) {
      return 42.42f;
    }

    @Override
    public LinkedHashMap<String, Object> paramsToMap() {
      return null;
    }

    @Override
    protected void validate() throws NormalizerException {
    }
  };
  norms = new ArrayList<Normalizer>(Collections.nCopies(features.size(), norm));
  final LTRScoringModel normMeta = TestLinearModel.createLinearModel("test",
      features, norms, "test", allFeatures, makeFeatureWeights(features));

  modelWeight = performQuery(hits, searcher,
      hits.scoreDocs[0].doc, new LTRScoringQuery(normMeta));
  normMeta.normalizeFeaturesInPlace(modelWeight.getModelFeatureValuesNormalized());
  assertEquals(mixPositions.length, modelWeight.getModelFeatureWeights().length);
  for (int i = 0; i < mixPositions.length; i++) {
    assertEquals(42.42f, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
  }

  r.close();
  dir.close();
}
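The test relies on helpers such as makeFeatures, makeFilterFeatures and makeFeatureWeights that are defined elsewhere in TestLTRScoringQuery and not shown here. As an orientation aid, here is a hedged sketch of what a makeFeatures-style helper could look like, assuming one ValueFeature per id, named "f" + id, with its index set to the id; the Feature.getInstance factory signature and the solrResourceLoader field are assumptions, and the real helper may differ.

// Sketch only: would live inside the test class; the real helper may differ.
private static List<Feature> makeFeatures(int[] featureIds) {
  final List<Feature> features = new ArrayList<>();
  for (final int i : featureIds) {
    final Map<String, Object> params = new HashMap<>();
    params.put("value", i);
    // assumed factory call; a ValueFeature simply returns its configured value as the feature score
    final Feature f = Feature.getInstance(solrResourceLoader,
        ValueFeature.class.getCanonicalName(), "f" + i, params);
    f.setIndex(i);
    features.add(f);
  }
  return features;
}

This would explain why, in the assertions above, each FeatureInfo's value equals its position id and its name is "f" followed by that id.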
Use of org.apache.solr.ltr.feature.Feature in project lucene-solr by apache.
The class ManagedFeatureStore, method addFeature.
public synchronized void addFeature(Map<String, Object> map, String featureStore) {
  log.info("register feature based on {}", map);
  final FeatureStore fstore = getFeatureStore(featureStore);
  final Feature feature = fromFeatureMap(solrResourceLoader, map);
  fstore.add(feature);
}
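A short usage sketch for the method above. The map layout mirrors the managed feature store's JSON format (name / class / params); the store name "_DEFAULT_" and the managedFeatureStore variable are illustrative assumptions rather than values taken from this snippet.

// Sketch only: register an OriginalScoreFeature under an assumed default store name.
Map<String, Object> featureMap = new LinkedHashMap<>();
featureMap.put("name", "originalScore");
featureMap.put("class", "org.apache.solr.ltr.feature.OriginalScoreFeature");
featureMap.put("params", new LinkedHashMap<String, Object>());
managedFeatureStore.addFeature(featureMap, "_DEFAULT_");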