Use of org.apache.lucene.search.grouping.SearchGroup in the Apache lucene-solr project.
Class SearchGroupShardResponseProcessor, method process.
/**
 * {@inheritDoc}
 *
 * <p>Consumes the first-phase grouping responses of all shards in {@code shardRequest}:
 * <ul>
 *   <li>optionally records per-shard details under {@code shards.info + ".firstPhase"},</li>
 *   <li>sums the per-field group counts into {@code rb.mergedGroupCounts},</li>
 *   <li>sums the shard hit counts into {@code rb.totalHitCount} and stores the largest shard
 *       elapsed time in {@code rb.firstPhaseElapsedTime},</li>
 *   <li>merges the per-shard search groups of every grouping field with
 *       {@code SearchGroup.merge} into {@code rb.mergedSearchGroups}, and remembers in
 *       {@code rb.searchGroupToShards} which shards reported each merged group.</li>
 * </ul>
 *
 * @param rb           the response builder holding the grouping spec and receiving the merged state
 * @param shardRequest the shard request whose responses are processed
 */
@Override
public void process(ResponseBuilder rb, ShardRequest shardRequest) {
  SortSpec groupSortSpec = rb.getGroupingSpec().getGroupSortSpec();
  Sort groupSort = rb.getGroupingSpec().getGroupSort();
  final String[] fields = rb.getGroupingSpec().getFields();
  Sort withinGroupSort = rb.getGroupingSpec().getSortWithinGroup();
  assert withinGroupSort != null;

  // Per grouping field: the search groups returned by each shard, and the shards that reported each group.
  final Map<String, List<Collection<SearchGroup<BytesRef>>>> commandSearchGroups = new HashMap<>(fields.length, 1.0f);
  final Map<String, Map<SearchGroup<BytesRef>, Set<String>>> tempSearchGroupToShards = new HashMap<>(fields.length, 1.0f);
  for (String field : fields) {
    commandSearchGroups.put(field, new ArrayList<Collection<SearchGroup<BytesRef>>>(shardRequest.responses.size()));
    tempSearchGroupToShards.put(field, new HashMap<SearchGroup<BytesRef>, Set<String>>());
    if (!rb.searchGroupToShards.containsKey(field)) {
      rb.searchGroupToShards.put(field, new HashMap<SearchGroup<BytesRef>, Set<String>>());
    }
  }

  SearchGroupsResultTransformer serializer = new SearchGroupsResultTransformer(rb.req.getSearcher());
  int maxElapsedTime = 0;
  int hitCountDuringFirstPhase = 0;

  NamedList<Object> shardInfo = null;
  if (rb.req.getParams().getBool(ShardParams.SHARDS_INFO, false)) {
    shardInfo = new SimpleOrderedMap<>(shardRequest.responses.size());
    rb.rsp.getValues().add(ShardParams.SHARDS_INFO + ".firstPhase", shardInfo);
  }

  // The request parameters do not change while iterating, so read the flag once.
  final boolean shardsTolerant = rb.req.getParams().getBool(ShardParams.SHARDS_TOLERANT, false);

  for (ShardResponse srsp : shardRequest.responses) {
    if (shardInfo != null) {
      SimpleOrderedMap<Object> nl = new SimpleOrderedMap<>(4);
      if (srsp.getException() != null) {
        Throwable t = srsp.getException();
        // Report the underlying cause of a SolrServerException, but only when there is one:
        // getCause() may return null, which previously led to a NullPointerException below.
        if (t instanceof SolrServerException && t.getCause() != null) {
          t = t.getCause();
        }
        nl.add("error", t.toString());
        StringWriter trace = new StringWriter();
        t.printStackTrace(new PrintWriter(trace));
        nl.add("trace", trace.toString());
      } else {
        nl.add("numFound", (Integer) srsp.getSolrResponse().getResponse().get("totalHitCount"));
      }
      if (srsp.getSolrResponse() != null) {
        nl.add("time", srsp.getSolrResponse().getElapsedTime());
      }
      if (srsp.getShardAddress() != null) {
        nl.add("shardAddress", srsp.getShardAddress());
      }
      shardInfo.add(srsp.getShard(), nl);
    }

    if (shardsTolerant && srsp.getException() != null) {
      if (rb.rsp.getResponseHeader().get(SolrQueryResponse.RESPONSE_HEADER_PARTIAL_RESULTS_KEY) == null) {
        rb.rsp.getResponseHeader().add(SolrQueryResponse.RESPONSE_HEADER_PARTIAL_RESULTS_KEY, Boolean.TRUE);
      }
      // continue if there was an error and we're tolerant.
      continue;
    }

    maxElapsedTime = (int) Math.max(maxElapsedTime, srsp.getSolrResponse().getElapsedTime());
    @SuppressWarnings("unchecked") NamedList<NamedList> firstPhaseResult = (NamedList<NamedList>) srsp.getSolrResponse().getResponse().get("firstPhase");
    final Map<String, SearchGroupsFieldCommandResult> result = serializer.transformToNative(firstPhaseResult, groupSort, withinGroupSort, srsp.getShard());

    for (Map.Entry<String, List<Collection<SearchGroup<BytesRef>>>> commandEntry : commandSearchGroups.entrySet()) {
      final String field = commandEntry.getKey();
      final SearchGroupsFieldCommandResult firstPhaseCommandResult = result.get(field);

      final Integer groupCount = firstPhaseCommandResult.getGroupCount();
      if (groupCount != null) {
        Integer existingGroupCount = rb.mergedGroupCounts.get(field);
        // Assuming groups don't cross shard boundary...
        rb.mergedGroupCounts.put(field, existingGroupCount != null ? Integer.valueOf(existingGroupCount + groupCount) : groupCount);
      }

      final Collection<SearchGroup<BytesRef>> searchGroups = firstPhaseCommandResult.getSearchGroups();
      if (searchGroups == null) {
        continue;
      }
      commandEntry.getValue().add(searchGroups);

      // Remember that this shard reported each of these groups.
      final Map<SearchGroup<BytesRef>, Set<String>> groupToShards = tempSearchGroupToShards.get(field);
      for (SearchGroup<BytesRef> searchGroup : searchGroups) {
        Set<String> shards = groupToShards.get(searchGroup);
        if (shards == null) {
          shards = new HashSet<>();
          groupToShards.put(searchGroup, shards);
        }
        shards.add(srsp.getShard());
      }
    }
    hitCountDuringFirstPhase += (Integer) srsp.getSolrResponse().getResponse().get("totalHitCount");
  }

  rb.totalHitCount = hitCountDuringFirstPhase;
  rb.firstPhaseElapsedTime = maxElapsedTime;

  for (Map.Entry<String, List<Collection<SearchGroup<BytesRef>>>> commandEntry : commandSearchGroups.entrySet()) {
    final String groupField = commandEntry.getKey();
    Collection<SearchGroup<BytesRef>> mergedTopGroups = SearchGroup.merge(commandEntry.getValue(), groupSortSpec.getOffset(), groupSortSpec.getCount(), groupSort);
    if (mergedTopGroups == null) {
      continue;
    }
    rb.mergedSearchGroups.put(groupField, mergedTopGroups);

    // Only the groups that survived the merge are copied from the temporary map into the response builder.
    final Map<SearchGroup<BytesRef>, Set<String>> mergedGroupToShards = rb.searchGroupToShards.get(groupField);
    final Map<SearchGroup<BytesRef>, Set<String>> groupToShards = tempSearchGroupToShards.get(groupField);
    for (SearchGroup<BytesRef> mergedTopGroup : mergedTopGroups) {
      mergedGroupToShards.put(mergedTopGroup, groupToShards.get(mergedTopGroup));
    }
  }
}
Use of org.apache.lucene.search.grouping.SearchGroup in the Apache lucene-solr project.
Class SearchGroupsResultTransformer, method transformToNative.
/**
 * {@inheritDoc}
 *
 * <p>For every command entry of {@code shardResponse} this builds a
 * {@code SearchGroupsFieldCommandResult} from the value stored under {@code GROUP_COUNT} and
 * the search groups rebuilt from the entries stored under {@code TOP_GROUPS} (an empty list
 * when that entry is absent). A non-null group key is converted to its indexed form through
 * the schema field named after the command when such a field exists, otherwise it is wrapped
 * in a {@code BytesRef} as-is. Each sort value is unmarshalled using the schema field of the
 * corresponding {@code groupSort} sort field, if any.
 *
 * <p>{@code withinGroupSort} and {@code shard} are not read by this implementation.
 */
@Override
public Map<String, SearchGroupsFieldCommandResult> transformToNative(NamedList<NamedList> shardResponse, Sort groupSort, Sort withinGroupSort, String shard) {
  final Map<String, SearchGroupsFieldCommandResult> commandResults = new HashMap<>(shardResponse.size());

  for (Map.Entry<String, NamedList> command : shardResponse) {
    final List<SearchGroup<BytesRef>> groups = new ArrayList<>();
    final NamedList commandData = command.getValue();

    @SuppressWarnings("unchecked")
    final NamedList<List<Comparable>> serializedGroups = (NamedList<List<Comparable>>) commandData.get(TOP_GROUPS);
    if (serializedGroups != null) {
      for (Map.Entry<String, List<Comparable>> serializedGroup : serializedGroups) {
        final SearchGroup<BytesRef> group = new SearchGroup<>();

        // Group value: null key stays null, otherwise convert via the schema when possible.
        final String readableValue = serializedGroup.getKey();
        if (readableValue == null) {
          group.groupValue = null;
        } else {
          final SchemaField groupField = searcher.getSchema().getFieldOrNull(command.getKey());
          if (groupField == null) {
            group.groupValue = new BytesRef(readableValue);
          } else {
            final BytesRefBuilder indexedValue = new BytesRefBuilder();
            groupField.getType().readableToIndexed(readableValue, indexedValue);
            group.groupValue = indexedValue.get();
          }
        }

        // Sort values: copy the serialized values, then unmarshal each one in place.
        final List<Comparable> serializedSortValues = serializedGroup.getValue();
        group.sortValues = serializedSortValues.toArray(new Comparable[serializedSortValues.size()]);
        for (int i = 0; i < group.sortValues.length; i++) {
          final String sortFieldName = groupSort.getSort()[i].getField();
          SchemaField sortSchemaField = null;
          if (sortFieldName != null) {
            sortSchemaField = searcher.getSchema().getFieldOrNull(sortFieldName);
          }
          group.sortValues[i] = ShardResultTransformerUtils.unmarshalSortValue(group.sortValues[i], sortSchemaField);
        }

        groups.add(group);
      }
    }

    final Integer groupCount = (Integer) commandData.get(GROUP_COUNT);
    commandResults.put(command.getKey(), new SearchGroupsFieldCommandResult(groupCount, groups));
  }

  return commandResults;
}
Aggregations