Usage of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
Class AzureBlobContainerRetriesTests, method testWriteBlobWithRetries:
/**
 * Verifies that {@code writeBlob} retries a failed PUT and eventually succeeds.
 * <p>
 * The stub handler lets only the PUT that brings {@code countDown} to zero succeed, i.e. the
 * {@code (maxRetries - 1)}-th one. It answers that PUT with 201 if the received body equals the
 * generated blob, and 400 otherwise. Every earlier PUT is failed at random in one of three ways:
 * <ul>
 *   <li>closed without a response and without reading the body;</li>
 *   <li>closed without a response after reading only part of the body;</li>
 *   <li>answered with a 500/503 after the whole body has been read.</li>
 * </ul>
 * The exchange is always closed in a {@code finally}. Requests with any other method, or an I/O error
 * while reading or answering, therefore cannot leave the client waiting on an open connection.
 */
public void testWriteBlobWithRetries() throws Exception {
    // The request retry policy counts the first attempt as retry, so we need to
    // account for that and increase the max retry count by one.
    final int maxRetries = randomIntBetween(2, 6);
    final CountDown countDown = new CountDown(maxRetries - 1);
    final byte[] bytes = randomBlobContent();
    httpServer.createContext("/container/write_blob_max_retries", exchange -> {
        try {
            if ("PUT".equals(exchange.getRequestMethod())) {
                exchange.getResponseHeaders().add("x-ms-request-server-encrypted", "false");
                if (countDown.countDown()) {
                    // This is the attempt allowed to succeed: check the uploaded content.
                    final BytesReference body = Streams.readFully(exchange.getRequestBody());
                    if (Objects.deepEquals(bytes, BytesReference.toBytes(body))) {
                        exchange.sendResponseHeaders(RestStatus.CREATED.getStatus(), -1);
                    } else {
                        AzureHttpHandler.sendError(exchange, RestStatus.BAD_REQUEST);
                    }
                    return;
                }
                // Earlier attempts: fail in one of several ways so different client error paths are exercised.
                if (randomBoolean()) {
                    if (randomBoolean()) {
                        // Read only part of the body, then drop the connection without answering.
                        Streams.readFully(exchange.getRequestBody(), new byte[randomIntBetween(1, Math.max(1, bytes.length - 1))]);
                    } else {
                        Streams.readFully(exchange.getRequestBody());
                        AzureHttpHandler.sendError(exchange, randomFrom(RestStatus.INTERNAL_SERVER_ERROR, RestStatus.SERVICE_UNAVAILABLE));
                    }
                }
            }
        } finally {
            // Single close point for every path, including unexpected methods and I/O failures.
            exchange.close();
        }
    });
    final BlobContainer blobContainer = createBlobContainer(maxRetries);
    try (InputStream stream = new InputStreamIndexInput(new ByteArrayIndexInput("desc", bytes), bytes.length)) {
        blobContainer.writeBlob("write_blob_max_retries", stream, bytes.length, false);
    }
    assertThat(countDown.isCountedDown(), is(true));
}
Usage of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
Class AzureBlobContainerRetriesTests, method testReadBlobWithRetries:
/**
 * Verifies that {@code readBlob} retries both its HEAD and its GET requests before succeeding.
 * <p>
 * HEAD and GET each get their own countdown of {@code maxRetries - 1}. The request that brings a countdown
 * to zero is served successfully. Any other request is either answered with a 500/503 or closed without a
 * response. A successful GET starts at the requested range offset and returns the rest of the blob.
 */
public void testReadBlobWithRetries() throws Exception {
    // The request retry policy counts the first attempt as retry, so we need to
    // account for that and increase the max retry count by one.
    final int maxRetries = randomIntBetween(2, 6);
    final CountDown headCountDown = new CountDown(maxRetries - 1);
    final CountDown getCountDown = new CountDown(maxRetries - 1);
    final byte[] blobBytes = randomBlobContent();
    httpServer.createContext("/container/read_blob_max_retries", httpExchange -> {
        try {
            // Always drain the request body before deciding how to answer.
            Streams.readFully(httpExchange.getRequestBody());
            final String method = httpExchange.getRequestMethod();
            final boolean isHead = "HEAD".equals(method);
            final boolean isGet = isHead == false && "GET".equals(method);
            if (isHead && headCountDown.countDown()) {
                httpExchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
                httpExchange.getResponseHeaders().add("Content-Length", String.valueOf(blobBytes.length));
                httpExchange.getResponseHeaders().add("x-ms-blob-type", "blockblob");
                httpExchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
            } else if (isGet && getCountDown.countDown()) {
                final int offset = getRangeStart(httpExchange);
                assertThat(offset, lessThan(blobBytes.length));
                final int remaining = blobBytes.length - offset;
                httpExchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
                httpExchange.getResponseHeaders().add("Content-Length", String.valueOf(remaining));
                httpExchange.getResponseHeaders().add("x-ms-blob-type", "blockblob");
                httpExchange.sendResponseHeaders(RestStatus.OK.getStatus(), remaining);
                httpExchange.getResponseBody().write(blobBytes, offset, remaining);
            } else if (randomBoolean()) {
                // A HEAD/GET whose turn has not come yet (or any other method): answer with a 500/503;
                // when randomBoolean() is false the exchange is simply closed without a response.
                AzureHttpHandler.sendError(httpExchange, randomFrom(RestStatus.INTERNAL_SERVER_ERROR, RestStatus.SERVICE_UNAVAILABLE));
            }
        } finally {
            httpExchange.close();
        }
    });
    final BlobContainer blobContainer = createBlobContainer(maxRetries);
    try (InputStream inputStream = blobContainer.readBlob("read_blob_max_retries")) {
        final byte[] downloaded = BytesReference.toBytes(Streams.readFully(inputStream));
        assertArrayEquals(blobBytes, downloaded);
        assertThat(headCountDown.isCountedDown(), is(true));
        assertThat(getCountDown.isCountedDown(), is(true));
    }
}
Usage of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
Class TransportSearchAction, method ccsRemoteReduce:
/**
 * Runs a cross-cluster search in which each involved cluster reduces its own results.
 * <p>
 * <b>Single remote cluster, no local indices.</b> The original request is forwarded to that cluster, which
 * performs the final reduction itself. The answer is re-wrapped with the locally measured took time and
 * reported as one successful cluster. If that cluster fails and is configured with skip_unavailable, an empty
 * response counting one skipped cluster is returned. Otherwise the wrapped failure is passed to
 * {@code listener}.
 * <p>
 * <b>Any other combination.</b> One sub-request is sent to every remote cluster, plus one for the local
 * indices when present. The local request goes through {@code localSearchConsumer}. All the per-cluster
 * listeners come from {@code createCCSListener} and share one {@link CountDown} sized to the number of
 * clusters, a skipped-cluster counter, an exception holder and a {@link SearchResponseMerger}.
 */
static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map<String, OriginalIndices> remoteIndices, SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder, RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener<SearchResponse> listener, BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer) {
    final boolean singleRemoteCluster = localIndices == null && remoteIndices.size() == 1;
    if (singleRemoteCluster == false) {
        // Fan out one sub-request per cluster; responses are merged once every cluster has answered.
        final SearchResponseMerger merger = createSearchResponseMerger(searchRequest.source(), timeProvider, aggReduceContextBuilder);
        final AtomicInteger skippedClusters = new AtomicInteger(0);
        final AtomicReference<Exception> failures = new AtomicReference<>();
        final int clusterCount = remoteIndices.size() + (localIndices == null ? 0 : 1);
        final CountDown pendingClusters = new CountDown(clusterCount);
        for (Map.Entry<String, OriginalIndices> remote : remoteIndices.entrySet()) {
            final String alias = remote.getKey();
            final boolean skipIfUnavailable = remoteClusterService.isSkipUnavailable(alias);
            final SearchRequest remoteRequest = SearchRequest.subSearchRequest(searchRequest, remote.getValue().indices(), alias, timeProvider.getAbsoluteStartMillis(), false);
            final ActionListener<SearchResponse> remoteListener = createCCSListener(alias, skipIfUnavailable, pendingClusters, skippedClusters, failures, merger, clusterCount, listener);
            remoteClusterService.getRemoteClusterClient(threadPool, alias).search(remoteRequest, remoteListener);
        }
        if (localIndices != null) {
            // The local cluster is never skipped when unavailable.
            final ActionListener<SearchResponse> localListener = createCCSListener(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, false, pendingClusters, skippedClusters, failures, merger, clusterCount, listener);
            final SearchRequest localRequest = SearchRequest.subSearchRequest(searchRequest, localIndices.indices(), RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false);
            localSearchConsumer.accept(localRequest, localListener);
        }
        return;
    }

    // if we are searching against a single remote cluster, we simply forward the original search request to such cluster
    // and we directly perform final reduction in the remote cluster
    final Map.Entry<String, OriginalIndices> onlyRemote = remoteIndices.entrySet().iterator().next();
    final String clusterAlias = onlyRemote.getKey();
    final boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
    final SearchRequest forwardedRequest = SearchRequest.subSearchRequest(searchRequest, onlyRemote.getValue().indices(), clusterAlias, timeProvider.getAbsoluteStartMillis(), true);
    final Client remoteClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
    remoteClient.search(forwardedRequest, new ActionListener<SearchResponse>() {
        @Override
        public void onResponse(SearchResponse searchResponse) {
            final Map<String, ProfileShardResult> profileResults = searchResponse.getProfileResults();
            final SearchProfileShardResults profile;
            if (profileResults == null || profileResults.isEmpty()) {
                profile = null;
            } else {
                profile = new SearchProfileShardResults(profileResults);
            }
            final InternalSearchResponse internalResponse = new InternalSearchResponse(
                searchResponse.getHits(),
                (InternalAggregations) searchResponse.getAggregations(),
                searchResponse.getSuggest(),
                profile,
                searchResponse.isTimedOut(),
                searchResponse.isTerminatedEarly(),
                searchResponse.getNumReducePhases()
            );
            // One cluster involved, and it answered successfully.
            final SearchResponse.Clusters clusters = new SearchResponse.Clusters(1, 1, 0);
            listener.onResponse(
                new SearchResponse(
                    internalResponse,
                    searchResponse.getScrollId(),
                    searchResponse.getTotalShards(),
                    searchResponse.getSuccessfulShards(),
                    searchResponse.getSkippedShards(),
                    timeProvider.buildTookInMillis(),
                    searchResponse.getShardFailures(),
                    clusters,
                    searchResponse.pointInTimeId()
                )
            );
        }

        @Override
        public void onFailure(Exception e) {
            if (skipUnavailable == false) {
                listener.onFailure(wrapRemoteClusterFailure(clusterAlias, e));
                return;
            }
            // The only cluster was skipped: report an empty response rather than a failure.
            listener.onResponse(SearchResponse.empty(timeProvider::buildTookInMillis, new SearchResponse.Clusters(1, 0, 1)));
        }
    });
}
Usage of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
Class SearchScrollQueryThenFetchAsyncAction, method moveToNextPhase:
/**
 * Builds the "fetch" phase that follows the scroll query phase.
 * <p>
 * When run, the phase first reduces the collected query results via
 * {@code searchPhaseController.reducedScrollQueryPhase}. If no score docs remain, it responds immediately.
 * Otherwise it issues one fetch-scroll request per shard that has documents to load. The final response is
 * sent once every entry of {@code docIdsToLoad} has been accounted for.
 *
 * @param clusterNodeLookup resolves a (cluster alias, node id) pair to the node that holds a shard's search context
 * @return the fetch phase
 */
@Override
protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
return new SearchPhase("fetch") {
@Override
public void run() {
final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase(queryResults.asList());
ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs.scoreDocs;
// Nothing to fetch: send the (empty) response right away.
if (scoreDocs.length == 0) {
sendResponse(reducedQueryPhase, fetchResults);
return;
}
// Per-shard doc ids to load, indexed like queryResults; entries are null for shards with nothing to fetch.
final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), scoreDocs);
final ScoreDoc[] lastEmittedDocPerShard = searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, queryResults.length());
// One slot per docIdsToLoad entry, null entries included; whoever counts down last sends the response.
final CountDown counter = new CountDown(docIdsToLoad.length);
for (int i = 0; i < docIdsToLoad.length; i++) {
final int index = i;
final IntArrayList docIds = docIdsToLoad[index];
if (docIds != null) {
final QuerySearchResult querySearchResult = queryResults.get(index);
ScoreDoc lastEmittedDoc = lastEmittedDocPerShard[index];
ShardFetchRequest shardFetchRequest = new ShardFetchRequest(querySearchResult.getContextId(), docIds, lastEmittedDoc);
SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
DiscoveryNode node = clusterNodeLookup.apply(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
assert node != null : "target node is null in secondary phase";
Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), node);
searchTransportService.sendExecuteFetchScroll(connection, shardFetchRequest, task, new SearchActionListener<FetchSearchResult>(querySearchResult.getSearchShardTarget(), index) {
@Override
protected void innerOnResponse(FetchSearchResult response) {
fetchResults.setOnce(response.getShardIndex(), response);
if (counter.countDown()) {
sendResponse(reducedQueryPhase, fetchResults);
}
}
@Override
public void onFailure(Exception t) {
// NOTE(review): the shared counter and a supplier for the response phase are handed to onShardFailure,
// presumably so it counts this shard down and runs that phase when it is the last one — confirm in the superclass.
onShardFailure(getName(), counter, querySearchResult.getContextId(), t, querySearchResult.getSearchShardTarget(), () -> sendResponsePhase(reducedQueryPhase, fetchResults));
}
});
} else {
// the counter is set to the total size of docIdsToLoad,
// which can have null values so we have to count them down too
if (counter.countDown()) {
sendResponse(reducedQueryPhase, fetchResults);
}
}
}
}
};
}
Usage of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
Class TransportBroadcastReplicationAction, method doExecute:
/**
 * Sends one shard-level request per target shard and notifies {@code listener} once every shard has answered.
 * <p>
 * The target shards are resolved once from a single cluster state snapshot. Shard failures are not passed to
 * {@code listener.onFailure}. Instead, each failed shard contributes a synthesized {@code ShardResponse} whose
 * shard info reports zero successful copies out of {@code numberOfReplicas + 1}. The final response therefore
 * always covers every targeted shard.
 *
 * @param task     the task this execution runs under, passed through to {@code shardExecute}
 * @param request  the broadcast request used to resolve the target shards
 * @param listener completed via {@code finishAndNotifyListener} once all shard responses are collected
 */
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
    final ClusterState clusterState = clusterService.state();
    final List<ShardId> shards = shards(request, clusterState);
    // Thread-safe collection: shard callbacks may complete on different threads
    // (assumption: shardExecute is asynchronous — confirm).
    final CopyOnWriteArrayList<ShardResponse> shardsResponses = new CopyOnWriteArrayList<>();
    if (shards.isEmpty()) {
        // Nothing to fan out to: answer right away, and stop here so correctness does not hinge on the loop
        // below happening to be empty.
        finishAndNotifyListener(listener, shardsResponses);
        return;
    }
    final CountDown responsesCountDown = new CountDown(shards.size());
    for (final ShardId shardId : shards) {
        ActionListener<ShardResponse> shardActionListener = new ActionListener<ShardResponse>() {
            @Override
            public void onResponse(ShardResponse shardResponse) {
                shardsResponses.add(shardResponse);
                logger.trace("{}: got response from {}", actionName, shardId);
                if (responsesCountDown.countDown()) {
                    finishAndNotifyListener(listener, shardsResponses);
                }
            }

            @Override
            public void onFailure(Exception e) {
                logger.trace("{}: got failure from {}", actionName, shardId);
                // Primary plus replicas, taken from the same snapshot the shards were resolved from.
                int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1;
                ShardResponse shardResponse = newShardResponse();
                ReplicationResponse.ShardInfo.Failure[] failures;
                if (TransportActions.isShardNotAvailableException(e)) {
                    // Shard not available: record no per-copy failures.
                    failures = new ReplicationResponse.ShardInfo.Failure[0];
                } else {
                    // Attribute the same failure to every copy of the shard.
                    ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure(shardId, null, e, ExceptionsHelper.status(e), true);
                    failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies];
                    Arrays.fill(failures, failure);
                }
                shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures));
                shardsResponses.add(shardResponse);
                if (responsesCountDown.countDown()) {
                    finishAndNotifyListener(listener, shardsResponses);
                }
            }
        };
        shardExecute(task, request, shardId, shardActionListener);
    }
}
Aggregations