Use of org.apache.lucene.search.Collector in project lucene-solr by apache.
The class JoinUtil, method createJoinQuery:
/**
 * Method for query time joining for numeric fields. It supports multi- and single-valued ints, longs,
 * floats and doubles. All considerations from {@link JoinUtil#createJoinQuery(String, boolean, String, Query, IndexSearcher, ScoreMode)}
 * are applicable here too, though memory consumption might be higher.
 * <p>
 *
 * @param fromField                 The from field to join from
 * @param multipleValuesPerDocument Whether the from field has multiple values per document;
 *                                  when true fromField might be {@link DocValuesType#SORTED_NUMERIC},
 *                                  otherwise fromField should be {@link DocValuesType#NUMERIC}
 * @param toField                   The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link FloatPoint}
 *                                  or {@link DoublePoint}
 * @param numericType               Either {@link java.lang.Integer}, {@link java.lang.Long}, {@link java.lang.Float}
 *                                  or {@link java.lang.Double}; it should correspond to the toField type
 * @param fromQuery                 The query to match documents on the from side
 * @param fromSearcher              The searcher that executed the specified fromQuery
 * @param scoreMode                 Instructs how scores from the fromQuery are mapped to the returned query
 * @return a {@link Query} instance that can be used to join documents based on the values
 *         in the from and to field
 * @throws IOException If I/O related errors occur
 */
public static Query createJoinQuery(String fromField, boolean multipleValuesPerDocument, String toField, Class<? extends Number> numericType, Query fromQuery, IndexSearcher fromSearcher, ScoreMode scoreMode) throws IOException {
  TreeSet<Long> joinValues = new TreeSet<>();
  Map<Long, Float> aggregatedScores = new HashMap<>();
  Map<Long, Integer> occurrences = new HashMap<>();
  boolean needsScore = scoreMode != ScoreMode.None;
  BiConsumer<Long, Float> scoreAggregator;
  if (scoreMode == ScoreMode.Max) {
    scoreAggregator = (key, score) -> {
      Float currentValue = aggregatedScores.putIfAbsent(key, score);
      if (currentValue != null) {
        aggregatedScores.put(key, Math.max(currentValue, score));
      }
    };
  } else if (scoreMode == ScoreMode.Min) {
    scoreAggregator = (key, score) -> {
      Float currentValue = aggregatedScores.putIfAbsent(key, score);
      if (currentValue != null) {
        aggregatedScores.put(key, Math.min(currentValue, score));
      }
    };
  } else if (scoreMode == ScoreMode.Total) {
    scoreAggregator = (key, score) -> {
      Float currentValue = aggregatedScores.putIfAbsent(key, score);
      if (currentValue != null) {
        aggregatedScores.put(key, currentValue + score);
      }
    };
  } else if (scoreMode == ScoreMode.Avg) {
    scoreAggregator = (key, score) -> {
      Float currentScore = aggregatedScores.putIfAbsent(key, score);
      if (currentScore != null) {
        aggregatedScores.put(key, currentScore + score);
      }
      Integer currentOccurrence = occurrences.putIfAbsent(key, 1);
      if (currentOccurrence != null) {
        occurrences.put(key, ++currentOccurrence);
      }
    };
  } else {
    scoreAggregator = (key, score) -> {
      throw new UnsupportedOperationException();
    };
  }

  LongFunction<Float> joinScorer;
  if (scoreMode == ScoreMode.Avg) {
    // The average is derived lazily: total aggregated score divided by the number of occurrences.
    joinScorer = (joinValue) -> {
      Float aggregatedScore = aggregatedScores.get(joinValue);
      Integer occurrence = occurrences.get(joinValue);
      return aggregatedScore / occurrence;
    };
  } else {
    joinScorer = aggregatedScores::get;
  }

  // First pass: collect the join values (and, if needed, scores) from the from side via doc values.
  Collector collector;
  if (multipleValuesPerDocument) {
    collector = new SimpleCollector() {

      SortedNumericDocValues sortedNumericDocValues;
      Scorer scorer;

      @Override
      public void collect(int doc) throws IOException {
        if (doc > sortedNumericDocValues.docID()) {
          sortedNumericDocValues.advance(doc);
        }
        if (doc == sortedNumericDocValues.docID()) {
          for (int i = 0; i < sortedNumericDocValues.docValueCount(); i++) {
            long value = sortedNumericDocValues.nextValue();
            joinValues.add(value);
            if (needsScore) {
              scoreAggregator.accept(value, scorer.score());
            }
          }
        }
      }

      @Override
      protected void doSetNextReader(LeafReaderContext context) throws IOException {
        sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
      }

      @Override
      public void setScorer(Scorer scorer) throws IOException {
        this.scorer = scorer;
      }

      @Override
      public boolean needsScores() {
        return needsScore;
      }
    };
  } else {
    collector = new SimpleCollector() {

      NumericDocValues numericDocValues;
      Scorer scorer;
      private int lastDocID = -1;

      private boolean docsInOrder(int docID) {
        if (docID < lastDocID) {
          throw new AssertionError("docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
        }
        lastDocID = docID;
        return true;
      }

      @Override
      public void collect(int doc) throws IOException {
        assert docsInOrder(doc);
        int dvDocID = numericDocValues.docID();
        if (dvDocID < doc) {
          dvDocID = numericDocValues.advance(doc);
        }
        long value;
        if (dvDocID == doc) {
          value = numericDocValues.longValue();
        } else {
          // Documents without a value contribute the doc-values default of 0.
          value = 0;
        }
        joinValues.add(value);
        if (needsScore) {
          scoreAggregator.accept(value, scorer.score());
        }
      }

      @Override
      protected void doSetNextReader(LeafReaderContext context) throws IOException {
        numericDocValues = DocValues.getNumeric(context.reader(), fromField);
        lastDocID = -1;
      }

      @Override
      public void setScorer(Scorer scorer) throws IOException {
        this.scorer = scorer;
      }

      @Override
      public boolean needsScores() {
        return needsScore;
      }
    };
  }
  fromSearcher.search(fromQuery, collector);

  // Second pass: stream the collected values, point-encoded, into a point-in-set query on the to side.
  Iterator<Long> iterator = joinValues.iterator();

  final int bytesPerDim;
  final BytesRef encoded = new BytesRef();
  final PointInSetIncludingScoreQuery.Stream stream;
  if (Integer.class.equals(numericType)) {
    bytesPerDim = Integer.BYTES;
    stream = new PointInSetIncludingScoreQuery.Stream() {
      @Override
      public BytesRef next() {
        if (iterator.hasNext()) {
          long value = iterator.next();
          IntPoint.encodeDimension((int) value, encoded.bytes, 0);
          if (needsScore) {
            score = joinScorer.apply(value);
          }
          return encoded;
        } else {
          return null;
        }
      }
    };
  } else if (Long.class.equals(numericType)) {
    bytesPerDim = Long.BYTES;
    stream = new PointInSetIncludingScoreQuery.Stream() {
      @Override
      public BytesRef next() {
        if (iterator.hasNext()) {
          long value = iterator.next();
          LongPoint.encodeDimension(value, encoded.bytes, 0);
          if (needsScore) {
            score = joinScorer.apply(value);
          }
          return encoded;
        } else {
          return null;
        }
      }
    };
  } else if (Float.class.equals(numericType)) {
    bytesPerDim = Float.BYTES;
    stream = new PointInSetIncludingScoreQuery.Stream() {
      @Override
      public BytesRef next() {
        if (iterator.hasNext()) {
          long value = iterator.next();
          FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), encoded.bytes, 0);
          if (needsScore) {
            score = joinScorer.apply(value);
          }
          return encoded;
        } else {
          return null;
        }
      }
    };
  } else if (Double.class.equals(numericType)) {
    bytesPerDim = Double.BYTES;
    stream = new PointInSetIncludingScoreQuery.Stream() {
      @Override
      public BytesRef next() {
        if (iterator.hasNext()) {
          long value = iterator.next();
          DoublePoint.encodeDimension(Double.longBitsToDouble(value), encoded.bytes, 0);
          if (needsScore) {
            score = joinScorer.apply(value);
          }
          return encoded;
        } else {
          return null;
        }
      }
    };
  } else {
    throw new IllegalArgumentException("unsupported numeric type, only Integer, Long, Float and Double are supported");
  }
  encoded.bytes = new byte[bytesPerDim];
  encoded.length = bytesPerDim;

  if (needsScore) {
    return new PointInSetIncludingScoreQuery(scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) {
      @Override
      protected String toString(byte[] value) {
        return toString.apply(value, numericType);
      }
    };
  } else {
    return new PointInSetQuery(toField, 1, bytesPerDim, stream) {
      @Override
      protected String toString(byte[] value) {
        return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
      }
    };
  }
}
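For context, a hedged usage sketch of the overload above. The document layout, the field names (type, productId) and the choice of Long.class are illustrative assumptions, not taken from the snippet: the from side carries the join value as NUMERIC doc values, while the to side indexes the same value as a LongPoint.

import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.JoinUtil;
import org.apache.lucene.search.join.ScoreMode;

public class NumericJoinExample {

  static Document fromDoc(long productId) {
    Document doc = new Document();
    // The from side must expose the join value as NUMERIC (or SORTED_NUMERIC) doc values.
    doc.add(new NumericDocValuesField("productId", productId));
    return doc;
  }

  static Document toDoc(long productId) {
    Document doc = new Document();
    // The to side must index the join value as a point of the matching type.
    doc.add(new LongPoint("productId", productId));
    return doc;
  }

  static TopDocs join(IndexSearcher searcher) throws IOException {
    Query fromQuery = new TermQuery(new Term("type", "offer")); // hypothetical selector query
    Query joinQuery = JoinUtil.createJoinQuery(
        "productId",    // fromField, read via doc values
        false,          // single-valued per document
        "productId",    // toField, indexed as LongPoint
        Long.class,     // must correspond to the toField point type
        fromQuery, searcher, ScoreMode.Max);
    return searcher.search(joinQuery, 10);
  }
}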
Use of org.apache.lucene.search.Collector in project lucene-solr by apache.
The class SimpleFacets, method getGroupedCounts:
public NamedList<Integer> getGroupedCounts(SolrIndexSearcher searcher, DocSet base, String field, boolean multiToken, int offset, int limit, int mincount, boolean missing, String sort, String prefix, Predicate<BytesRef> termFilter) throws IOException {
  GroupingSpecification groupingSpecification = rb.getGroupingSpec();
  final String groupField = groupingSpecification != null ? groupingSpecification.getFields()[0] : null;
  if (groupField == null) {
    throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Specify the group.field as parameter or local parameter");
  }
  BytesRef prefixBytesRef = prefix != null ? new BytesRef(prefix) : null;
  final TermGroupFacetCollector collector = TermGroupFacetCollector.createTermGroupFacetCollector(groupField, field, multiToken, prefixBytesRef, 128);
  Collector groupWrapper = getInsanityWrapper(groupField, collector);
  Collector fieldWrapper = getInsanityWrapper(field, groupWrapper);
  // When GroupedFacetCollector can handle numerics we can remove the wrapped collectors
  searcher.search(base.getTopFilter(), fieldWrapper);
  boolean orderByCount = sort.equals(FacetParams.FACET_SORT_COUNT) || sort.equals(FacetParams.FACET_SORT_COUNT_LEGACY);
  TermGroupFacetCollector.GroupedFacetResult result = collector.mergeSegmentResults(limit < 0 ? Integer.MAX_VALUE : (offset + limit), mincount, orderByCount);
  CharsRefBuilder charsRef = new CharsRefBuilder();
  FieldType facetFieldType = searcher.getSchema().getFieldType(field);
  NamedList<Integer> facetCounts = new NamedList<>();
  List<TermGroupFacetCollector.FacetEntry> scopedEntries = result.getFacetEntries(offset, limit < 0 ? Integer.MAX_VALUE : limit);
  for (TermGroupFacetCollector.FacetEntry facetEntry : scopedEntries) {
    // TODO: can we filter earlier than this to make it more efficient?
    if (termFilter != null && !termFilter.test(facetEntry.getValue())) {
      continue;
    }
    facetFieldType.indexedToReadable(facetEntry.getValue(), charsRef);
    facetCounts.add(charsRef.toString(), facetEntry.getCount());
  }
  if (missing) {
    facetCounts.add(null, result.getTotalMissingCount());
  }
  return facetCounts;
}
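This path computes facet counts per group rather than per matching document. As a rough illustration (the core and field names are made up, and the exact parameter routing is an assumption), a Solr request along these lines exercises getGroupedCounts:

/select?q=*:*&group=true&group.field=manufacturer&facet=true&facet.field=cat&group.facet=true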
Use of org.apache.lucene.search.Collector in project lucene-solr by apache.
The class GroupingSearch, method groupByFieldOrFunction:
@SuppressWarnings({ "unchecked", "rawtypes" })
protected TopGroups groupByFieldOrFunction(IndexSearcher searcher, Query query, int groupOffset, int groupLimit) throws IOException {
int topN = groupOffset + groupLimit;
final FirstPassGroupingCollector firstPassCollector = new FirstPassGroupingCollector(grouper, groupSort, topN);
final AllGroupsCollector allGroupsCollector = allGroups ? new AllGroupsCollector(grouper) : null;
final AllGroupHeadsCollector allGroupHeadsCollector = allGroupHeads ? AllGroupHeadsCollector.newCollector(grouper, sortWithinGroup) : null;
final Collector firstRound = MultiCollector.wrap(firstPassCollector, allGroupsCollector, allGroupHeadsCollector);
CachingCollector cachedCollector = null;
if (maxCacheRAMMB != null || maxDocsToCache != null) {
if (maxCacheRAMMB != null) {
cachedCollector = CachingCollector.create(firstRound, cacheScores, maxCacheRAMMB);
} else {
cachedCollector = CachingCollector.create(firstRound, cacheScores, maxDocsToCache);
}
searcher.search(query, cachedCollector);
} else {
searcher.search(query, firstRound);
}
matchingGroups = allGroups ? allGroupsCollector.getGroups() : Collections.emptyList();
matchingGroupHeads = allGroupHeads ? allGroupHeadsCollector.retrieveGroupHeads(searcher.getIndexReader().maxDoc()) : new Bits.MatchNoBits(searcher.getIndexReader().maxDoc());
Collection<SearchGroup> topSearchGroups = firstPassCollector.getTopGroups(groupOffset, fillSortFields);
if (topSearchGroups == null) {
return new TopGroups(new SortField[0], new SortField[0], 0, 0, new GroupDocs[0], Float.NaN);
}
int topNInsideGroup = groupDocsOffset + groupDocsLimit;
TopGroupsCollector secondPassCollector = new TopGroupsCollector(grouper, topSearchGroups, groupSort, sortWithinGroup, topNInsideGroup, includeScores, includeMaxScore, fillSortFields);
if (cachedCollector != null && cachedCollector.isCached()) {
cachedCollector.replay(secondPassCollector);
} else {
searcher.search(query, secondPassCollector);
}
if (allGroups) {
return new TopGroups(secondPassCollector.getTopGroups(groupDocsOffset), matchingGroups.size());
} else {
return secondPassCollector.getTopGroups(groupDocsOffset);
}
}
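A hedged sketch of the public API that funnels into groupByFieldOrFunction. The field names, the query and the cache size are assumptions; grouping by a doc-values field yields BytesRef group values.

import java.io.IOException;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.util.BytesRef;

public class GroupingSearchExample {

  static void groupByAuthor(IndexSearcher searcher) throws IOException {
    GroupingSearch groupingSearch = new GroupingSearch("author"); // hypothetical group field
    groupingSearch.setGroupSort(Sort.RELEVANCE);
    groupingSearch.setAllGroups(true);        // populates matchingGroups, as in the code above
    groupingSearch.setCachingInMB(8.0, true); // enables the CachingCollector branch
    groupingSearch.setGroupDocsLimit(5);      // docs returned per group

    TopGroups<BytesRef> result = groupingSearch.search(
        searcher, new TermQuery(new Term("body", "lucene")), 0, 10);
    for (GroupDocs<BytesRef> group : result.groups) {
      String value = group.groupValue == null ? "(no value)" : group.groupValue.utf8ToString();
      System.out.println(value + " -> " + group.totalHits);
    }
  }
}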
Use of org.apache.lucene.search.Collector in project lucene-solr by apache.
The class CommandHandler, method execute:
@SuppressWarnings("unchecked")
public void execute() throws IOException {
final int nrOfCommands = commands.size();
List<Collector> collectors = new ArrayList<>(nrOfCommands);
for (Command command : commands) {
collectors.addAll(command.create());
}
ProcessedFilter filter = searcher.getProcessedFilter(queryCommand.getFilter(), queryCommand.getFilterList());
Query query = QueryUtils.makeQueryable(queryCommand.getQuery());
if (truncateGroups) {
docSet = computeGroupedDocSet(query, filter, collectors);
} else if (needDocset) {
docSet = computeDocSet(query, filter, collectors);
} else if (!collectors.isEmpty()) {
searchWithTimeLimiter(query, filter, MultiCollector.wrap(collectors.toArray(new Collector[nrOfCommands])));
} else {
searchWithTimeLimiter(query, filter, null);
}
}
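The MultiCollector.wrap call above fans a single query execution out to several collectors in one pass over the matching documents. A minimal standalone sketch; the collector choices are illustrative, and the single-argument TopScoreDocCollector.create overload matches the lucene-solr era of these snippets.

import java.io.IOException;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.TotalHitCountCollector;

public class MultiCollectorExample {

  static void searchOnce(IndexSearcher searcher, Query query) throws IOException {
    TotalHitCountCollector counter = new TotalHitCountCollector();
    TopScoreDocCollector topDocs = TopScoreDocCollector.create(10);
    // One traversal of the hits feeds both collectors; wrap() also skips null
    // collectors, which is why optional collectors can be passed unconditionally.
    searcher.search(query, MultiCollector.wrap(counter, topDocs));
    System.out.println(counter.getTotalHits() + " hits; "
        + topDocs.topDocs().scoreDocs.length + " top docs collected");
  }
}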
Use of org.apache.lucene.search.Collector in project lucene-solr by apache.
The class SearchGroupsFieldCommand, method create:
@Override
public List<Collector> create() throws IOException {
  final List<Collector> collectors = new ArrayList<>(2);
  final FieldType fieldType = field.getType();
  if (topNGroups > 0) {
    if (fieldType.getNumberType() != null) {
      ValueSource vs = fieldType.getValueSource(field, null);
      firstPassGroupingCollector = new FirstPassGroupingCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()), groupSort, topNGroups);
    } else {
      firstPassGroupingCollector = new FirstPassGroupingCollector<>(new TermGroupSelector(field.getName()), groupSort, topNGroups);
    }
    collectors.add(firstPassGroupingCollector);
  }
  if (includeGroupCount) {
    if (fieldType.getNumberType() != null) {
      ValueSource vs = fieldType.getValueSource(field, null);
      allGroupsCollector = new AllGroupsCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()));
    } else {
      allGroupsCollector = new AllGroupsCollector<>(new TermGroupSelector(field.getName()));
    }
    collectors.add(allGroupsCollector);
  }
  return collectors;
}
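Outside of Solr, the same pair of collectors can be wired up directly. A hedged sketch using the term-based group selector; the field name is assumed, and the constructor signatures mirror the ones used in create() above.

import java.io.IOException;
import java.util.Collection;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.grouping.AllGroupsCollector;
import org.apache.lucene.search.grouping.FirstPassGroupingCollector;
import org.apache.lucene.search.grouping.SearchGroup;
import org.apache.lucene.search.grouping.TermGroupSelector;
import org.apache.lucene.util.BytesRef;

public class SearchGroupsExample {

  static void collectGroups(IndexSearcher searcher, Query query) throws IOException {
    // First-pass collector for the top 10 groups, plus an all-groups counter.
    FirstPassGroupingCollector<BytesRef> firstPass =
        new FirstPassGroupingCollector<>(new TermGroupSelector("category"), Sort.RELEVANCE, 10);
    AllGroupsCollector<BytesRef> allGroups =
        new AllGroupsCollector<>(new TermGroupSelector("category"));
    searcher.search(query, MultiCollector.wrap(firstPass, allGroups));

    // getTopGroups may return null when no groups matched.
    Collection<SearchGroup<BytesRef>> top = firstPass.getTopGroups(0, false);
    System.out.println(allGroups.getGroupCount() + " distinct groups; top groups collected: " + (top != null));
  }
}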