Use of org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk in the Apache Beam project.
From the class GrpcWindmillServerTest, method testStreamingCommit.
/**
 * Sends 500 commit requests of varying size over a {@code CommitWorkStream} and verifies that a
 * fake server receives each of them intact and that every commit is acknowledged.
 *
 * <p>The fake service reassembles commits that arrive split across several chunks, parses each
 * completed commit, and compares it with the request stored under the parsed work token
 * (presumably the index passed to {@code makeCommitRequest} -- confirm against that helper).
 *
 * <p>All checks performed inside server or client callbacks are recorded through
 * {@code errorCollector} rather than thrown, so that a failed check is reported by the test
 * itself instead of being raised on a callback thread.
 */
@Test
public void testStreamingCommit() throws Exception {
  List<WorkItemCommitRequest> commitRequestList = new ArrayList<>();
  List<CountDownLatch> latches = new ArrayList<>();
  Map<Long, WorkItemCommitRequest> commitRequests = new HashMap<>();
  for (int i = 0; i < 500; ++i) {
    // Build some requests of varying size with a few big ones.
    WorkItemCommitRequest request = makeCommitRequest(i, i * (i < 480 ? 8 : 128));
    commitRequestList.add(request);
    commitRequests.put((long) i, request);
    latches.add(new CountDownLatch(1));
  }
  // Send in random order; the server looks requests up by work token, not by arrival order.
  Collections.shuffle(commitRequestList);

  // This server receives WorkItemCommitRequests, and verifies they are equal to the above
  // commitRequest.
  serviceRegistry.addService(
      new CloudWindmillServiceV1Alpha1ImplBase() {
        @Override
        public StreamObserver<StreamingCommitWorkRequest> commitWorkStream(
            StreamObserver<StreamingCommitResponse> responseObserver) {
          return new StreamObserver<StreamingCommitWorkRequest>() {
            boolean sawHeader = false;
            // Bytes of the commit currently being reassembled; null between commits.
            InputStream buffer = null;
            // Bytes still expected for the commit in progress, as reported by the last chunk.
            long remainingBytes = 0;
            ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);

            @Override
            public void onNext(StreamingCommitWorkRequest request) {
              maybeInjectError(responseObserver);

              if (!sawHeader) {
                // The first message on the stream must carry the job header.
                errorCollector.checkThat(
                    request.getHeader(),
                    Matchers.equalTo(
                        JobHeader.newBuilder()
                            .setJobId("job")
                            .setProjectId("project")
                            .setWorkerId("worker")
                            .build()));
                sawHeader = true;
                LOG.info("Received header");
                return;
              }

              boolean first = true;
              LOG.info("Received request with {} chunks", request.getCommitChunkCount());
              for (StreamingCommitRequestChunk chunk : request.getCommitChunkList()) {
                // Recorded via errorCollector (not assertTrue) so an oversized chunk is
                // reported by the test instead of throwing out of this callback.
                errorCollector.checkThat(
                    chunk.getSerializedWorkItemCommit().size() <= STREAM_CHUNK_SIZE,
                    Matchers.is(true));
                if (first || chunk.hasComputationId()) {
                  errorCollector.checkThat(
                      chunk.getComputationId(), Matchers.equalTo("computation"));
                }

                if (remainingBytes != 0) {
                  // Continuation of a commit started by an earlier chunk: the bytes in this
                  // chunk plus what it says is still to come must match what was outstanding.
                  errorCollector.checkThat(buffer, Matchers.notNullValue());
                  errorCollector.checkThat(
                      remainingBytes,
                      Matchers.is(
                          chunk.getSerializedWorkItemCommit().size()
                              + chunk.getRemainingBytesForWorkItem()));
                  buffer =
                      new SequenceInputStream(
                          buffer, chunk.getSerializedWorkItemCommit().newInput());
                } else {
                  // Start of a new commit: nothing may be left over from the previous one.
                  errorCollector.checkThat(buffer, Matchers.nullValue());
                  buffer = chunk.getSerializedWorkItemCommit().newInput();
                }

                remainingBytes = chunk.getRemainingBytesForWorkItem();
                if (remainingBytes == 0) {
                  // The commit is complete: parse it, compare, and acknowledge it.
                  try {
                    WorkItemCommitRequest received = WorkItemCommitRequest.parseFrom(buffer);
                    errorCollector.checkThat(
                        received,
                        Matchers.equalTo(commitRequests.get(received.getWorkToken())));
                    try {
                      responseObserver.onNext(
                          StreamingCommitResponse.newBuilder()
                              .addRequestId(chunk.getRequestId())
                              .build());
                    } catch (IllegalStateException e) {
                      // Stream is closed.
                    }
                  } catch (Exception e) {
                    errorCollector.addError(e);
                  }
                  buffer = null;
                } else {
                  // A chunk that leaves its commit unfinished must be the first chunk of its
                  // request.
                  errorCollector.checkThat(first, Matchers.is(true));
                }
                first = false;
              }
            }

            @Override
            public void onError(Throwable throwable) {}

            @Override
            public void onCompleted() {
              injector.cancel();
              responseObserver.onCompleted();
            }
          };
        }
      });

  // Make the commit requests, waiting for each of them to be verified and acknowledged.
  CommitWorkStream stream = client.commitWorkStream();
  for (int i = 0; i < commitRequestList.size(); ) {
    final CountDownLatch latch = latches.get(i);
    boolean accepted =
        stream.commitWorkItem(
            "computation",
            commitRequestList.get(i),
            (CommitStatus status) -> {
              // Record an unexpected status instead of throwing from the callback, and always
              // release the latch so a bad status does not turn into a one-minute wait.
              errorCollector.checkThat(status, Matchers.equalTo(CommitStatus.OK));
              latch.countDown();
            });
    if (accepted) {
      i++;
    } else {
      // The request was not accepted; flush and retry the same request.
      stream.flush();
    }
  }
  stream.flush();

  for (CountDownLatch latch : latches) {
    assertTrue(latch.await(1, TimeUnit.MINUTES));
  }

  stream.close();
  assertTrue(stream.awaitTermination(30, TimeUnit.SECONDS));
}
Aggregations