Use of org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest in the Apache Beam project.
Class FakeWindmillServer, method getDataStream.
/**
 * Builds an in-process {@link GetDataStream} for the fake server.
 *
 * <p>Each request is wrapped into a {@code GetDataRequest}, handed to {@code getData}, and the
 * single expected entry is unwrapped from the reply. The stream's start time is the instant at
 * which this method was invoked.
 */
@Override
public GetDataStream getDataStream() {
  final Instant createdAt = Instant.now();
  return new GetDataStream() {
    /**
     * Fetches keyed state for one computation.
     *
     * @return the only keyed response in the reply, or {@code null} when the reply carries no
     *     computation entry or the computation entry carries no keyed data
     */
    @Override
    public Windmill.KeyedGetDataResponse requestKeyedData(
        String computation, KeyedGetDataRequest request) {
      ComputationGetDataRequest perComputation =
          ComputationGetDataRequest.newBuilder()
              .setComputationId(computation)
              .addRequests(request)
              .build();
      GetDataResponse reply =
          getData(GetDataRequest.newBuilder().addRequests(perComputation).build());

      // No computation-level entry at all: nothing to hand back.
      if (reply.getDataList().isEmpty()) {
        return null;
      }
      assertEquals(1, reply.getDataCount());

      // The computation entry exists but holds no keyed data.
      if (reply.getData(0).getDataList().isEmpty()) {
        return null;
      }
      assertEquals(1, reply.getData(0).getDataCount());
      return reply.getData(0).getData(0);
    }

    /**
     * Fetches a single piece of global data.
     *
     * @return the only global data entry in the reply, or {@code null} when the reply has none
     */
    @Override
    public Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request) {
      GetDataResponse reply =
          getData(GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build());
      if (reply.getGlobalDataList().isEmpty()) {
        return null;
      }
      assertEquals(1, reply.getGlobalDataCount());
      return reply.getGlobalData(0);
    }

    /** Intentionally does nothing. */
    @Override
    public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {
    }

    /** Intentionally does nothing. */
    @Override
    public void close() {
    }

    /** Always reports termination immediately. */
    @Override
    public boolean awaitTermination(int time, TimeUnit unit) {
      return true;
    }

    /** Returns the instant captured when the enclosing {@code getDataStream()} call ran. */
    @Override
    public Instant startTime() {
      return createdAt;
    }
  };
}
Use of org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest in the Apache Beam project.
Class FakeWindmillServer, method validateGetDataRequest.
/**
 * Records a check failure on {@code errorCollector} for any keyed request in {@code request}
 * that lacks a work token, has a sharding key outside the open range (0, Long.MAX_VALUE), or has
 * a negative max-bytes value.
 *
 * @param request the GetData request whose per-computation keyed requests are inspected
 */
private void validateGetDataRequest(Windmill.GetDataRequest request) {
  request.getRequestsList().stream()
      // Flatten computation -> keyed requests, preserving the original iteration order.
      .flatMap(perComputation -> perComputation.getRequestsList().stream())
      .forEach(
          keyed -> {
            errorCollector.checkThat(keyed.hasWorkToken(), equalTo(true));
            errorCollector.checkThat(
                keyed.getShardingKey(), allOf(greaterThan(0L), lessThan(Long.MAX_VALUE)));
            errorCollector.checkThat(keyed.getMaxBytes(), greaterThanOrEqualTo(0L));
          });
}
Use of org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest in the Apache Beam project.
Class WindmillStateReaderTest, method testBatching.
/**
 * Requests one watermark hold and one bag, then checks that resolving the first future issues a
 * single state fetch covering both tags and that the second future is served without another
 * fetch.
 */
@Test
public void testBatching() throws Exception {
  // Queue a watermark-hold read and a bag read; creating the futures alone must not hit Windmill.
  Future<Instant> holdFuture = underTest.watermarkFuture(STATE_KEY_2, STATE_FAMILY);
  Future<Iterable<Integer>> bagValues =
      underTest.bagFuture(STATE_KEY_1, STATE_FAMILY, INT_CODER);
  Mockito.verifyNoMoreInteractions(mockWindmill);

  ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
      ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);

  // Canned reply carrying both the hold and the bag contents.
  Windmill.WatermarkHold.Builder holdReply =
      Windmill.WatermarkHold.newBuilder()
          .setTag(STATE_KEY_2)
          .setStateFamily(STATE_FAMILY)
          .addTimestamps(5000000)
          .addTimestamps(6000000);
  Windmill.TagBag.Builder bagReply =
      Windmill.TagBag.newBuilder()
          .setTag(STATE_KEY_1)
          .setStateFamily(STATE_FAMILY)
          .addValues(intData(5))
          .addValues(intData(100));
  Windmill.KeyedGetDataResponse.Builder cannedResponse =
      Windmill.KeyedGetDataResponse.newBuilder()
          .setKey(DATA_KEY)
          .addWatermarkHolds(holdReply)
          .addBags(bagReply);
  Mockito.when(
          mockWindmill.getStateData(
              Mockito.eq(COMPUTATION), Mockito.isA(Windmill.KeyedGetDataRequest.class)))
      .thenReturn(cannedResponse.build());

  // Resolving the hold is what triggers the (single) fetch.
  Instant holdValue = holdFuture.get();
  Mockito.verify(mockWindmill).getStateData(Mockito.eq(COMPUTATION), requestCaptor.capture());

  // The captured request must name the key/work token and ask for exactly one bag and one hold.
  KeyedGetDataRequest sent = requestCaptor.getValue();
  assertThat(sent.getKey(), Matchers.equalTo(DATA_KEY));
  assertThat(sent.getWorkToken(), Matchers.equalTo(WORK_TOKEN));
  assertThat(sent.getBagsToFetchCount(), Matchers.equalTo(1));
  assertThat(sent.getBagsToFetch(0).getDeleteAll(), Matchers.equalTo(false));
  assertThat(sent.getBagsToFetch(0).getTag(), Matchers.equalTo(STATE_KEY_1));
  assertThat(sent.getWatermarkHoldsToFetchCount(), Matchers.equalTo(1));
  assertThat(sent.getWatermarkHoldsToFetch(0).getTag(), Matchers.equalTo(STATE_KEY_2));

  // Values surfaced to the caller; reading the bag must not cause any further interaction.
  assertThat(holdValue, Matchers.equalTo(new Instant(5000)));
  Mockito.verifyNoMoreInteractions(mockWindmill);
  assertThat(bagFuture(bagValues), Matchers.contains(5, 100));
  Mockito.verifyNoMoreInteractions(mockWindmill);

  // Asking for the same hold again yields a future that is already complete.
  Future<Instant> holdFutureAgain = underTest.watermarkFuture(STATE_KEY_2, STATE_FAMILY);
  assertTrue(holdFutureAgain.isDone());
  assertNoReader(holdFuture);
  assertNoReader(holdFutureAgain);
}

/** Unwraps the bag future; kept separate so the assertion above reads as a single expression. */
private static Iterable<Integer> bagFuture(Future<Iterable<Integer>> pending) throws Exception {
  return pending.get();
}
Use of org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest in the Apache Beam project.
Class GrpcWindmillServerTest, method testStreamingGetData.
/**
 * Drives 100 keyed-data and 100 global-data requests through a {@code GetDataStream} backed by an
 * in-test service that echoes each request back, and checks every response against the value
 * expected for that request.
 *
 * <p>The fake service randomly chooses between batching responses and splitting a response into
 * 10-byte chunks, so both reassembly paths on the client are exercised.
 *
 * <p>Worker tasks always count down the completion latch, even when a request throws, so a
 * failing request is reported through {@code errorCollector} instead of hanging the test. The
 * executor is shut down in a {@code finally} block so a failed assertion does not leak its
 * threads.
 */
@Test
@SuppressWarnings("FutureReturnValueIgnored")
public void testStreamingGetData() throws Exception {
  // This server responds to GetDataRequests with responses that mirror the requests.
  serviceRegistry.addService(
      new CloudWindmillServiceV1Alpha1ImplBase() {
        @Override
        public StreamObserver<StreamingGetDataRequest> getDataStream(
            StreamObserver<StreamingGetDataResponse> responseObserver) {
          return new StreamObserver<StreamingGetDataRequest>() {
            // The first message on a stream is expected to be the job header.
            boolean sawHeader = false;
            // Request ids already answered on this stream; used to detect duplicates.
            HashSet<Long> seenIds = new HashSet<>();
            ResponseErrorInjector injector = new ResponseErrorInjector(responseObserver);
            // Accumulates non-chunked responses until they are flushed as one batch.
            StreamingGetDataResponse.Builder responseBuilder =
                StreamingGetDataResponse.newBuilder();

            @Override
            public void onNext(StreamingGetDataRequest chunk) {
              maybeInjectError(responseObserver);
              try {
                if (!sawHeader) {
                  LOG.info("Received header");
                  errorCollector.checkThat(
                      chunk.getHeader(),
                      Matchers.equalTo(
                          JobHeader.newBuilder()
                              .setJobId("job")
                              .setProjectId("project")
                              .setWorkerId("worker")
                              .build()));
                  sawHeader = true;
                } else {
                  LOG.info(
                      "Received get data of {} global data, {} data requests",
                      chunk.getGlobalDataRequestCount(),
                      chunk.getStateRequestCount());
                  errorCollector.checkThat(
                      chunk.getSerializedSize(), Matchers.lessThanOrEqualTo(STREAM_CHUNK_SIZE));
                  // Request ids are indexed with global-data requests first, then state
                  // requests, so a single running index spans both loops.
                  int i = 0;
                  for (GlobalDataRequest request : chunk.getGlobalDataRequestList()) {
                    long requestId = chunk.getRequestId(i++);
                    errorCollector.checkThat(seenIds.add(requestId), Matchers.is(true));
                    sendResponse(requestId, processGlobalDataRequest(request));
                  }
                  for (ComputationGetDataRequest request : chunk.getStateRequestList()) {
                    long requestId = chunk.getRequestId(i++);
                    errorCollector.checkThat(seenIds.add(requestId), Matchers.is(true));
                    sendResponse(requestId, processStateRequest(request));
                  }
                  flushResponse();
                }
              } catch (Exception e) {
                errorCollector.addError(e);
              }
            }

            @Override
            public void onError(Throwable throwable) {
            }

            @Override
            public void onCompleted() {
              injector.cancel();
              responseObserver.onCompleted();
            }

            /** Echoes a global-data request: the payload is the request's tag as UTF-8. */
            private ByteString processGlobalDataRequest(GlobalDataRequest request) {
              errorCollector.checkThat(request.getStateFamily(), Matchers.is("family"));
              return GlobalData.newBuilder()
                  .setDataId(request.getDataId())
                  .setStateFamily("family")
                  .setData(ByteString.copyFromUtf8(request.getDataId().getTag()))
                  .build()
                  .toByteString();
            }

            /** Echoes a state request using the tag of its first value-to-fetch. */
            private ByteString processStateRequest(ComputationGetDataRequest compRequest) {
              errorCollector.checkThat(compRequest.getRequestsCount(), Matchers.is(1));
              errorCollector.checkThat(
                  compRequest.getComputationId(), Matchers.is("computation"));
              KeyedGetDataRequest request = compRequest.getRequests(0);
              KeyedGetDataResponse response =
                  makeGetDataResponse(request.getValuesToFetch(0).getTag().toStringUtf8());
              return response.toByteString();
            }

            /** Sends chunked roughly one time in four; otherwise batches the response. */
            private void sendResponse(long id, ByteString serializedResponse) {
              if (ThreadLocalRandom.current().nextInt(4) == 0) {
                sendChunkedResponse(id, serializedResponse);
              } else {
                responseBuilder.addRequestId(id).addSerializedResponse(serializedResponse);
                if (responseBuilder.getRequestIdCount() > 10) {
                  flushResponse();
                }
              }
            }

            /** Splits one response into consecutive pieces of at most 10 bytes each. */
            private void sendChunkedResponse(long id, ByteString serializedResponse) {
              final int chunkBytes = 10;
              final int totalBytes = serializedResponse.size();
              // Ceiling division so the logged count matches the number of loop iterations.
              LOG.info(
                  "Sending response with {} chunks", (totalBytes + chunkBytes - 1) / chunkBytes);
              for (int i = 0; i < totalBytes; i += chunkBytes) {
                int end = Math.min(totalBytes, i + chunkBytes);
                try {
                  responseObserver.onNext(
                      StreamingGetDataResponse.newBuilder()
                          .addRequestId(id)
                          .addSerializedResponse(serializedResponse.substring(i, end))
                          .setRemainingBytesForResponse(totalBytes - end)
                          .build());
                } catch (IllegalStateException e) {
                  // Stream is already closed.
                }
              }
            }

            /** Emits any batched responses and resets the batch. */
            private void flushResponse() {
              if (responseBuilder.getRequestIdCount() > 0) {
                LOG.info(
                    "Sending batched response of {} ids", responseBuilder.getRequestIdCount());
                try {
                  responseObserver.onNext(responseBuilder.build());
                } catch (IllegalStateException e) {
                  // Stream is already closed.
                }
                responseBuilder.clear();
              }
            }
          };
        }
      });

  GetDataStream stream = client.getDataStream();

  // Make requests of varying sizes to test chunking, and verify the responses.
  ExecutorService executor = Executors.newFixedThreadPool(50);
  try {
    final CountDownLatch done = new CountDownLatch(200);
    for (int i = 0; i < 100; ++i) {
      final String key = "key" + i;
      final String s = i % 5 == 0 ? largeString(i) : "tag";
      executor.execute(
          () -> {
            try {
              errorCollector.checkThat(
                  stream.requestKeyedData("computation", makeGetDataRequest(key, s)),
                  Matchers.equalTo(makeGetDataResponse(s)));
            } catch (Throwable t) {
              // Surface the failure in the test report rather than losing it on a pool thread.
              errorCollector.addError(t);
            } finally {
              // Always count down, otherwise done.await() below would block forever.
              done.countDown();
            }
          });
      executor.execute(
          () -> {
            try {
              errorCollector.checkThat(
                  stream.requestGlobalData(makeGlobalDataRequest(key)),
                  Matchers.equalTo(makeGlobalDataResponse(key)));
            } catch (Throwable t) {
              errorCollector.addError(t);
            } finally {
              done.countDown();
            }
          });
    }
    done.await();
    stream.close();
    assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS));
  } finally {
    executor.shutdown();
  }
}
Aggregations