Use of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
From the class TransportSearchAction, method collectSearchShards.
static void collectSearchShards(IndicesOptions indicesOptions, String preference, String routing,
        AtomicInteger skippedClusters, Map<String, OriginalIndices> remoteIndicesByCluster,
        RemoteClusterService remoteClusterService, ThreadPool threadPool,
        ActionListener<Map<String, ClusterSearchShardsResponse>> listener) {
    final CountDown responsesCountDown = new CountDown(remoteIndicesByCluster.size());
    final Map<String, ClusterSearchShardsResponse> searchShardsResponses = new ConcurrentHashMap<>();
    final AtomicReference<Exception> exceptions = new AtomicReference<>();
    for (Map.Entry<String, OriginalIndices> entry : remoteIndicesByCluster.entrySet()) {
        final String clusterAlias = entry.getKey();
        boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
        Client clusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
        final String[] indices = entry.getValue().indices();
        ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(indices)
            .indicesOptions(indicesOptions).local(true).preference(preference).routing(routing);
        clusterClient.admin().cluster().searchShards(searchShardsRequest,
            new CCSActionListener<ClusterSearchShardsResponse, Map<String, ClusterSearchShardsResponse>>(
                clusterAlias, skipUnavailable, responsesCountDown, skippedClusters, exceptions, listener) {
                @Override
                void innerOnResponse(ClusterSearchShardsResponse clusterSearchShardsResponse) {
                    searchShardsResponses.put(clusterAlias, clusterSearchShardsResponse);
                }

                @Override
                Map<String, ClusterSearchShardsResponse> createFinalResponse() {
                    return searchShardsResponses;
                }
            });
    }
}
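The pattern above is a fan-out: one searchShards call per remote cluster, each response stored in a concurrent map, and the shared CountDown inside CCSActionListener completes the outer listener exactly once, when the last cluster has answered or been skipped. The following stand-alone sketch illustrates that shape with hypothetical names, using an AtomicInteger in place of CountDown so it runs without OpenSearch on the classpath.

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

// Hypothetical stand-in for the fan-out/count-down pattern: one asynchronous call per
// cluster, results gathered in a concurrent map, and the final callback invoked exactly
// once by whichever response arrives last.
public class ClusterFanOutSketch {

    /** Asks every "cluster" for a result and calls onComplete once, after the last reply. */
    static void collectResponses(
        List<String> clusterAliases,
        BiConsumer<String, Consumer<String>> asyncCall,   // (clusterAlias, responseHandler)
        Consumer<Map<String, String>> onComplete
    ) {
        final AtomicInteger remaining = new AtomicInteger(clusterAliases.size());
        final Map<String, String> responses = new ConcurrentHashMap<>();
        for (String alias : clusterAliases) {
            asyncCall.accept(alias, response -> {
                responses.put(alias, response);
                // The thread handling the last response builds the final result, mirroring
                // CountDown.countDown() returning true exactly once.
                if (remaining.decrementAndGet() == 0) {
                    onComplete.accept(responses);
                }
            });
        }
    }

    public static void main(String[] args) {
        collectResponses(
            List.of("cluster_a", "cluster_b"),
            (alias, handler) -> handler.accept("shards-of-" + alias),   // pretend remote call
            all -> System.out.println("all clusters answered: " + all)
        );
    }
}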
Use of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
From the class SearchScrollAsyncAction, method run.
private void run(BiFunction<String, String, DiscoveryNode> clusterNodeLookup, final SearchContextIdForNode[] context) {
    final CountDown counter = new CountDown(scrollId.getContext().length);
    for (int i = 0; i < context.length; i++) {
        SearchContextIdForNode target = context[i];
        final int shardIndex = i;
        final Transport.Connection connection;
        try {
            DiscoveryNode node = clusterNodeLookup.apply(target.getClusterAlias(), target.getNode());
            if (node == null) {
                throw new IllegalStateException("node [" + target.getNode() + "] is not available");
            }
            connection = getConnection(target.getClusterAlias(), node);
        } catch (Exception ex) {
            onShardFailure("query", counter, target.getSearchContextId(), ex, null,
                () -> SearchScrollAsyncAction.this.moveToNextPhase(clusterNodeLookup));
            continue;
        }
        final InternalScrollSearchRequest internalRequest =
            TransportSearchHelper.internalScrollSearchRequest(target.getSearchContextId(), request);
        // we can't create a SearchShardTarget here since we don't know the index and shard ID we are talking to
        // we only know the node and the search context ID. Yet, the response will contain the SearchShardTarget
        // from the target node instead... that's why we pass null here
        SearchActionListener<T> searchActionListener = new SearchActionListener<T>(null, shardIndex) {
            @Override
            protected void setSearchShardTarget(T response) {
                // don't do this - it's part of the response...
                assert response.getSearchShardTarget() != null : "search shard target must not be null";
                if (target.getClusterAlias() != null) {
                    // re-create the search target and add the cluster alias if there is any,
                    // we need this down the road for subsequent phases
                    SearchShardTarget searchShardTarget = response.getSearchShardTarget();
                    response.setSearchShardTarget(new SearchShardTarget(searchShardTarget.getNodeId(),
                        searchShardTarget.getShardId(), target.getClusterAlias(), null));
                }
            }

            @Override
            protected void innerOnResponse(T result) {
                assert shardIndex == result.getShardIndex()
                    : "shard index mismatch: " + shardIndex + " but got: " + result.getShardIndex();
                onFirstPhaseResult(shardIndex, result);
                if (counter.countDown()) {
                    SearchPhase phase = moveToNextPhase(clusterNodeLookup);
                    try {
                        phase.run();
                    } catch (Exception e) {
                        // we need to fail the entire request here - the entire phase just blew up
                        // don't call onShardFailure or onFailure here since otherwise we'd countDown the counter
                        // again which would result in an exception
                        listener.onFailure(new SearchPhaseExecutionException(phase.getName(), "Phase failed", e,
                            ShardSearchFailure.EMPTY_ARRAY));
                    }
                }
            }

            @Override
            public void onFailure(Exception t) {
                onShardFailure("query", counter, target.getSearchContextId(), t, null,
                    () -> SearchScrollAsyncAction.this.moveToNextPhase(clusterNodeLookup));
            }
        };
        executeInitialPhase(connection, internalRequest, searchActionListener);
    }
}
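In this action, every shard callback, success or failure, decrements the same CountDown, and only the callback that brings it to zero builds and runs the next phase; if that phase itself throws, the whole request is failed directly instead of counting down again. Below is a minimal sketch of that completion logic, with hypothetical names and an AtomicInteger standing in for CountDown.

import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch of the per-shard completion logic: every callback decrements the
// counter, and only the callback that reaches zero moves the action to the next phase.
public class ScrollPhaseSketch {

    interface Phase {
        String name();
        void run() throws Exception;
    }

    private final AtomicInteger counter;
    private final Phase nextPhase;

    ScrollPhaseSketch(int shards, Phase nextPhase) {
        this.counter = new AtomicInteger(shards);
        this.nextPhase = nextPhase;
    }

    void onShardResult(int shardIndex, Object result) {
        // record the result for shardIndex here ...
        maybeMoveToNextPhase();
    }

    void onShardFailure(int shardIndex, Exception e) {
        // record the failure here ...
        maybeMoveToNextPhase();
    }

    private void maybeMoveToNextPhase() {
        if (counter.decrementAndGet() == 0) {
            try {
                nextPhase.run();
            } catch (Exception e) {
                // the phase itself blew up: fail the whole request instead of
                // decrementing the counter again
                failRequest(nextPhase.name(), e);
            }
        }
    }

    private void failRequest(String phaseName, Exception cause) {
        System.err.println("phase [" + phaseName + "] failed: " + cause);
    }

    public static void main(String[] args) {
        ScrollPhaseSketch action = new ScrollPhaseSketch(2, new Phase() {
            public String name() { return "fetch"; }
            public void run() { System.out.println("moving to fetch phase"); }
        });
        action.onShardResult(0, "hit");
        action.onShardFailure(1, new RuntimeException("shard down")); // last callback triggers the phase
    }
}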
Use of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
From the class AzureBlobContainerRetriesTests, method testReadRangeBlobWithRetries.
public void testReadRangeBlobWithRetries() throws Exception {
    // The request retry policy counts the first attempt as a retry, so we need to
    // account for that and increase the max retry count by one.
    final int maxRetries = randomIntBetween(2, 6);
    final CountDown countDownGet = new CountDown(maxRetries - 1);
    final byte[] bytes = randomBlobContent();
    httpServer.createContext("/container/read_range_blob_max_retries", exchange -> {
        try {
            Streams.readFully(exchange.getRequestBody());
            if ("HEAD".equals(exchange.getRequestMethod())) {
                exchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
                exchange.getResponseHeaders().add("Content-Length", String.valueOf(bytes.length));
                exchange.getResponseHeaders().add("x-ms-blob-type", "blockblob");
                exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1);
                return;
            } else if ("GET".equals(exchange.getRequestMethod())) {
                if (countDownGet.countDown()) {
                    final int rangeStart = getRangeStart(exchange);
                    assertThat(rangeStart, lessThan(bytes.length));
                    final Optional<Integer> rangeEnd = getRangeEnd(exchange);
                    assertThat(rangeEnd.isPresent(), is(true));
                    assertThat(rangeEnd.get(), greaterThanOrEqualTo(rangeStart));
                    final int length = (rangeEnd.get() - rangeStart) + 1;
                    assertThat(length, lessThanOrEqualTo(bytes.length - rangeStart));
                    exchange.getResponseHeaders().add("Content-Type", "application/octet-stream");
                    exchange.getResponseHeaders().add("Content-Length", String.valueOf(length));
                    exchange.getResponseHeaders().add("x-ms-blob-type", "blockblob");
                    exchange.sendResponseHeaders(RestStatus.OK.getStatus(), length);
                    exchange.getResponseBody().write(bytes, rangeStart, length);
                    return;
                }
            }
            if (randomBoolean()) {
                AzureHttpHandler.sendError(exchange, randomFrom(RestStatus.INTERNAL_SERVER_ERROR, RestStatus.SERVICE_UNAVAILABLE));
            }
        } finally {
            exchange.close();
        }
    });
    final BlobContainer blobContainer = createBlobContainer(maxRetries);
    final int position = randomIntBetween(0, bytes.length - 1);
    final int length = randomIntBetween(1, bytes.length - position);
    try (InputStream inputStream = blobContainer.readBlob("read_range_blob_max_retries", position, length)) {
        final byte[] bytesRead = BytesReference.toBytes(Streams.readFully(inputStream));
        assertArrayEquals(Arrays.copyOfRange(bytes, position, Math.min(bytes.length, position + length)), bytesRead);
        assertThat(countDownGet.isCountedDown(), is(true));
    }
}
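Here the CountDown acts as a failure budget for the handler: the GET branch only serves the requested range once maxRetries - 2 earlier attempts have been consumed, so the SDK's retry policy has to fire before the read can succeed, and the final assertion checks that the budget was fully drained. A rough sketch of that gating idea follows, with hypothetical names and an AtomicInteger in place of CountDown.

import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical retry gate: reject requests until the expected number of attempts has
// been seen, then let the request succeed. Roughly mirrors countDown()/isCountedDown().
public class RetryGateSketch {

    private final AtomicInteger remaining;

    RetryGateSketch(int attemptsBeforeSuccess) {
        this.remaining = new AtomicInteger(attemptsBeforeSuccess);
    }

    /** True exactly once, on the attempt that exhausts the budget, like CountDown.countDown(). */
    boolean shouldServe() {
        return remaining.decrementAndGet() == 0;
    }

    /** True once the budget has been drained, like CountDown.isCountedDown(). */
    boolean isExhausted() {
        return remaining.get() <= 0;
    }

    public static void main(String[] args) {
        RetryGateSketch gate = new RetryGateSketch(3);
        System.out.println(gate.shouldServe()); // false -> first attempt fails
        System.out.println(gate.shouldServe()); // false -> second attempt fails
        System.out.println(gate.shouldServe()); // true  -> third attempt is served
        System.out.println(gate.isExhausted()); // true
    }
}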
Use of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
From the class AzureBlobContainerRetriesTests, method testWriteLargeBlob.
public void testWriteLargeBlob() throws Exception {
    // The request retry policy counts the first attempt as a retry, so we need to
    // account for that and increase the max retry count by one.
    final int maxRetries = randomIntBetween(3, 6);
    final int nbBlocks = randomIntBetween(1, 2);
    final byte[] data = randomBytes(BlobClient.BLOB_DEFAULT_UPLOAD_BLOCK_SIZE * nbBlocks);
    // we want all requests to fail at least once
    final int nbErrors = 2;
    final AtomicInteger countDownUploads = new AtomicInteger(nbErrors * nbBlocks);
    final CountDown countDownComplete = new CountDown(nbErrors);
    final Map<String, BytesReference> blocks = new ConcurrentHashMap<>();
    httpServer.createContext("/container/write_large_blob", exchange -> {
        if ("PUT".equals(exchange.getRequestMethod())) {
            final Map<String, String> params = new HashMap<>();
            if (exchange.getRequestURI().getQuery() != null) {
                RestUtils.decodeQueryString(exchange.getRequestURI().getQuery(), 0, params);
            }
            final String blockId = params.get("blockid");
            if (Strings.hasText(blockId) && (countDownUploads.decrementAndGet() % 2 == 0)) {
                blocks.put(blockId, Streams.readFully(exchange.getRequestBody()));
                exchange.getResponseHeaders().add("x-ms-request-server-encrypted", "false");
                exchange.sendResponseHeaders(RestStatus.CREATED.getStatus(), -1);
                exchange.close();
                return;
            }
            final String complete = params.get("comp");
            if ("blocklist".equals(complete) && (countDownComplete.countDown())) {
                final String blockList = Streams.copyToString(new InputStreamReader(exchange.getRequestBody(), UTF_8));
                final List<String> blockUids = Arrays.stream(blockList.split("<Latest>"))
                    .filter(line -> line.contains("</Latest>"))
                    .map(line -> line.substring(0, line.indexOf("</Latest>")))
                    .collect(Collectors.toList());
                final ByteArrayOutputStream blob = new ByteArrayOutputStream();
                for (String blockUid : blockUids) {
                    BytesReference block = blocks.remove(blockUid);
                    assert block != null;
                    block.writeTo(blob);
                }
                assertArrayEquals(data, blob.toByteArray());
                exchange.getResponseHeaders().add("x-ms-request-server-encrypted", "false");
                exchange.sendResponseHeaders(RestStatus.CREATED.getStatus(), -1);
                exchange.close();
                return;
            }
        }
        if (randomBoolean()) {
            Streams.readFully(exchange.getRequestBody());
            AzureHttpHandler.sendError(exchange, randomFrom(RestStatus.INTERNAL_SERVER_ERROR, RestStatus.SERVICE_UNAVAILABLE));
        }
        exchange.close();
    });
    final BlobContainer blobContainer = createBlobContainer(maxRetries);
    try (InputStream stream = new InputStreamIndexInput(new ByteArrayIndexInput("desc", data), data.length)) {
        blobContainer.writeBlob("write_large_blob", stream, data.length, false);
    }
    assertThat(countDownUploads.get(), equalTo(0));
    assertThat(countDownComplete.isCountedDown(), is(true));
}
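Two counters drive this test: countDownUploads starts at nbErrors * nbBlocks and accepts a block only when the decremented value is even, so upload attempts alternate between rejected and accepted, and the final assertion that it reaches exactly zero means each block went through one failed attempt and one successful retry; countDownComplete applies the same idea to the block-list commit. A small sketch of the alternating accept/reject counter follows, with hypothetical names.

import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch of the "every upload fails at least once" counter: it starts at
// nbErrors * nbBlocks with nbErrors == 2, rejects attempts that see an odd value and
// accepts the even ones, so it ends at exactly zero once every block has been retried.
public class AlternatingFailureSketch {

    private final AtomicInteger countDownUploads;

    AlternatingFailureSketch(int nbBlocks) {
        this.countDownUploads = new AtomicInteger(2 * nbBlocks);
    }

    /** true when this attempt should be accepted, false when it should be failed. */
    boolean acceptUpload() {
        return countDownUploads.decrementAndGet() % 2 == 0;
    }

    public static void main(String[] args) {
        AlternatingFailureSketch sketch = new AlternatingFailureSketch(2);
        System.out.println(sketch.acceptUpload()); // false -> counter now 3, attempt rejected
        System.out.println(sketch.acceptUpload()); // true  -> counter now 2, retry accepted
        System.out.println(sketch.acceptUpload()); // false -> counter now 1, attempt rejected
        System.out.println(sketch.acceptUpload()); // true  -> counter now 0, retry accepted
    }
}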
Use of org.opensearch.common.util.concurrent.CountDown in the OpenSearch project (opensearch-project).
From the class S3BlobContainerRetriesTests, method testWriteLargeBlob.
public void testWriteLargeBlob() throws Exception {
    final boolean useTimeout = rarely();
    final TimeValue readTimeout = useTimeout ? TimeValue.timeValueMillis(randomIntBetween(100, 500)) : null;
    final ByteSizeValue bufferSize = new ByteSizeValue(5, ByteSizeUnit.MB);
    final BlobContainer blobContainer = createBlobContainer(null, readTimeout, true, bufferSize);
    final int parts = randomIntBetween(1, 5);
    final long lastPartSize = randomLongBetween(10, 512);
    final long blobSize = (parts * bufferSize.getBytes()) + lastPartSize;
    // we want all requests to fail at least once
    final int nbErrors = 2;
    final CountDown countDownInitiate = new CountDown(nbErrors);
    final AtomicInteger countDownUploads = new AtomicInteger(nbErrors * (parts + 1));
    final CountDown countDownComplete = new CountDown(nbErrors);
    httpServer.createContext("/bucket/write_large_blob", exchange -> {
        final long contentLength = Long.parseLong(exchange.getRequestHeaders().getFirst("Content-Length"));
        if ("POST".equals(exchange.getRequestMethod()) && exchange.getRequestURI().getQuery().equals("uploads")) {
            // initiate multipart upload request
            if (countDownInitiate.countDown()) {
                byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
                    + "<InitiateMultipartUploadResult>\n"
                    + " <Bucket>bucket</Bucket>\n"
                    + " <Key>write_large_blob</Key>\n"
                    + " <UploadId>TEST</UploadId>\n"
                    + "</InitiateMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
                exchange.getResponseHeaders().add("Content-Type", "application/xml");
                exchange.sendResponseHeaders(HttpStatus.SC_OK, response.length);
                exchange.getResponseBody().write(response);
                exchange.close();
                return;
            }
        } else if ("PUT".equals(exchange.getRequestMethod())
            && exchange.getRequestURI().getQuery().contains("uploadId=TEST")
            && exchange.getRequestURI().getQuery().contains("partNumber=")) {
            // upload part request
            MD5DigestCalculatingInputStream md5 = new MD5DigestCalculatingInputStream(exchange.getRequestBody());
            BytesReference bytes = Streams.readFully(md5);
            assertThat((long) bytes.length(), anyOf(equalTo(lastPartSize), equalTo(bufferSize.getBytes())));
            assertThat(contentLength, anyOf(equalTo(lastPartSize), equalTo(bufferSize.getBytes())));
            if (countDownUploads.decrementAndGet() % 2 == 0) {
                exchange.getResponseHeaders().add("ETag", Base16.encodeAsString(md5.getMd5Digest()));
                exchange.sendResponseHeaders(HttpStatus.SC_OK, -1);
                exchange.close();
                return;
            }
        } else if ("POST".equals(exchange.getRequestMethod()) && exchange.getRequestURI().getQuery().equals("uploadId=TEST")) {
            // complete multipart upload request
            if (countDownComplete.countDown()) {
                Streams.readFully(exchange.getRequestBody());
                byte[] response = ("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n"
                    + "<CompleteMultipartUploadResult>\n"
                    + " <Bucket>bucket</Bucket>\n"
                    + " <Key>write_large_blob</Key>\n"
                    + "</CompleteMultipartUploadResult>").getBytes(StandardCharsets.UTF_8);
                exchange.getResponseHeaders().add("Content-Type", "application/xml");
                exchange.sendResponseHeaders(HttpStatus.SC_OK, response.length);
                exchange.getResponseBody().write(response);
                exchange.close();
                return;
            }
        }
        // send an error back or let the request time out
        if (useTimeout == false) {
            if (randomBoolean() && contentLength > 0) {
                Streams.readFully(exchange.getRequestBody(), new byte[randomIntBetween(1, Math.toIntExact(contentLength - 1))]);
            } else {
                Streams.readFully(exchange.getRequestBody());
                exchange.sendResponseHeaders(randomFrom(HttpStatus.SC_INTERNAL_SERVER_ERROR, HttpStatus.SC_BAD_GATEWAY,
                    HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_GATEWAY_TIMEOUT), -1);
            }
            exchange.close();
        }
    });
    blobContainer.writeBlob("write_large_blob", new ZeroInputStream(blobSize), blobSize, false);
    assertThat(countDownInitiate.isCountedDown(), is(true));
    assertThat(countDownUploads.get(), equalTo(0));
    assertThat(countDownComplete.isCountedDown(), is(true));
}
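The sizes asserted in the part-upload branch follow directly from blobSize = parts * bufferSize + lastPartSize: every part except the last carries exactly bufferSize bytes and the final part carries the remainder, which is also why countDownUploads starts at nbErrors * (parts + 1). Below is a small sketch of that arithmetic; the helper is hypothetical and not part of the repository code.

// Hypothetical helper: split a blob into multipart-upload part sizes, full buffer-sized
// parts followed by one trailing part holding the remainder.
public class MultipartSplitSketch {

    static long[] partSizes(long blobSize, long bufferSize) {
        int fullParts = (int) (blobSize / bufferSize);
        long lastPartSize = blobSize % bufferSize;
        int total = fullParts + (lastPartSize > 0 ? 1 : 0);
        long[] sizes = new long[total];
        for (int i = 0; i < fullParts; i++) {
            sizes[i] = bufferSize;
        }
        if (lastPartSize > 0) {
            sizes[total - 1] = lastPartSize;
        }
        return sizes;
    }

    public static void main(String[] args) {
        long bufferSize = 5L * 1024 * 1024;   // 5 MB, as in the test
        long blobSize = 3 * bufferSize + 137;  // 3 full parts plus a 137-byte tail
        for (long size : partSizes(blobSize, bufferSize)) {
            System.out.println(size);          // 5242880, 5242880, 5242880, 137
        }
    }
}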