Use of org.apache.flink.streaming.api.utils.ByteArrayWrapper in the Apache Flink project:
class SimpleStateRequestHandler, method handleMapGetRequest.
/**
 * Handles a GET request issued against a map state.
 *
 * <p>The first byte of the request's continuation token selects the operation. The remaining
 * bytes carry that operation's arguments:
 *
 * <ul>
 *   <li>GET: {@code [flag (1 byte)][serialized map key]}
 *   <li>CHECK_EMPTY: {@code [flag (1 byte)]}
 *   <li>ITERATE: {@code [flag (1 byte)][iterate type (1 byte)][iterator token length (int32)][iterator token]}
 * </ul>
 *
 * <p>The map key (GET) and the iterator token (ITERATE) are handed on as views over the token
 * bytes through the shared {@code reuseByteArrayWrapper}. That wrapper is overwritten by later
 * requests that reach the same branches.
 *
 * @param request the state request; its continuation token must contain at least the flag byte
 * @return an already-completed stage holding the response builder
 * @throws Exception if resolving the map state or serving the operation fails
 * @throws RuntimeException if the token is empty, the flag is unknown, or the declared iterator
 *     token length exceeds the bytes remaining in the token
 */
private CompletionStage<BeamFnApi.StateResponse.Builder> handleMapGetRequest(BeamFnApi.StateRequest request) throws Exception {
    MapState<ByteArrayWrapper, byte[]> mapState = getMapState(request);
    byte[] getRequest = request.getGet().getContinuationToken().toByteArray();
    if (getRequest.length == 0) {
        // There is no flag byte to dispatch on; report that explicitly instead of failing with
        // an ArrayIndexOutOfBoundsException.
        throw new RuntimeException("Empty continuation token in get request for map state.");
    }
    byte getFlag = getRequest[0];
    BeamFnApi.StateGetResponse.Builder response;
    switch (getFlag) {
        case GET_FLAG:
            // The serialized map key is everything after the flag byte.
            reuseByteArrayWrapper.setData(getRequest);
            reuseByteArrayWrapper.setOffset(1);
            reuseByteArrayWrapper.setLimit(getRequest.length);
            response = handleMapGetValueRequest(reuseByteArrayWrapper, mapState);
            break;
        case CHECK_EMPTY_FLAG:
            response = handleMapCheckEmptyRequest(mapState);
            break;
        case ITERATE_FLAG: {
            bais.setBuffer(getRequest, 1, getRequest.length - 1);
            IterateType iterateType = IterateType.fromOrd(baisWrapper.readByte());
            int iterateTokenLength = baisWrapper.readInt();
            ByteArrayWrapper iterateToken = null;
            if (iterateTokenLength > 0) {
                // bais.getPosition() is used as an absolute index into getRequest (the buffer
                // handed to bais above); it now points at the first byte of the iterator token.
                int tokenStart = bais.getPosition();
                int remaining = getRequest.length - tokenStart;
                if (iterateTokenLength > remaining) {
                    throw new RuntimeException(String.format("Iterator token length %d exceeds the %d remaining bytes of the get request for map state.", iterateTokenLength, remaining));
                }
                reuseByteArrayWrapper.setData(getRequest);
                reuseByteArrayWrapper.setOffset(tokenStart);
                reuseByteArrayWrapper.setLimit(tokenStart + iterateTokenLength);
                iterateToken = reuseByteArrayWrapper;
            }
            // A null token makes handleMapIterateRequest start a fresh iteration.
            response = handleMapIterateRequest(mapState, iterateType, iterateToken);
            break;
        }
        default:
            throw new RuntimeException(String.format("Unsupported get request type: '%d' for map state.", getFlag));
    }
    return CompletableFuture.completedFuture(BeamFnApi.StateResponse.newBuilder().setId(request.getId()).setGet(response));
}
Use of org.apache.flink.streaming.api.utils.ByteArrayWrapper in the Apache Flink project:
class SimpleStateRequestHandler, method handleMapIterateRequest.
/**
 * Serves one batch (at most {@code mapStateIterateResponseBatchSize} elements) of an iteration
 * over a map state.
 *
 * <p>Response data layout, written through {@code baosWrapper}:
 *
 * <ul>
 *   <li>ITEMS / VALUES: for each entry, the raw key bytes, a boolean telling whether the value is
 *       non-null, then the value bytes when it is non-null.
 *   <li>KEYS: for each key, the raw key bytes.
 * </ul>
 *
 * <p>Key bytes are written without a length prefix. Presumably the serialized keys are
 * self-delimiting for the Python reader (TODO confirm).
 *
 * <p>If elements remain after this batch, the iterator is stored in
 * {@code mapStateIteratorCache} and its token is returned as the continuation token. A new
 * random UUID token is generated on the first batch. When the iterator is exhausted, any cached
 * entry for the incoming token is removed and no continuation token is set.
 *
 * @param mapState the map state being iterated
 * @param iterateType what to emit per element; for a continued iteration it must match the type
 *     used when the iteration was started
 * @param iteratorToken {@code null} to start a new iteration, otherwise the token of a cached one
 * @return the response builder holding this batch and, if more remain, the continuation token
 * @throws Exception if accessing the map state fails
 * @throws RuntimeException if the iterate type is unsupported or no iterator is cached for the
 *     token
 */
private BeamFnApi.StateGetResponse.Builder handleMapIterateRequest(MapState<ByteArrayWrapper, byte[]> mapState, IterateType iterateType, ByteArrayWrapper iteratorToken) throws Exception {
    final Iterator<?> iterator;
    if (iteratorToken == null) {
        switch (iterateType) {
            case ITEMS:
            case VALUES:
                iterator = mapState.iterator();
                break;
            case KEYS:
                iterator = mapState.keys().iterator();
                break;
            default:
                throw new RuntimeException("Unsupported iterate type: " + iterateType);
        }
    } else {
        iterator = mapStateIteratorCache.get(iteratorToken);
        if (iterator == null) {
            throw new RuntimeException("The cached iterator does not exist!");
        }
    }

    baos.reset();
    switch (iterateType) {
        case ITEMS:
        case VALUES: {
            // For these types the iterator came from mapState.iterator() (directly or via the
            // cache), so its elements are map entries.
            @SuppressWarnings("unchecked")
            Iterator<Map.Entry<ByteArrayWrapper, byte[]>> entryIterator = (Iterator<Map.Entry<ByteArrayWrapper, byte[]>>) iterator;
            for (int i = 0; i < mapStateIterateResponseBatchSize && entryIterator.hasNext(); i++) {
                Map.Entry<ByteArrayWrapper, byte[]> entry = entryIterator.next();
                ByteArrayWrapper key = entry.getKey();
                // Read the value once; it is used for both the null flag and the payload.
                byte[] value = entry.getValue();
                baosWrapper.write(key.getData(), key.getOffset(), key.getLimit() - key.getOffset());
                baosWrapper.writeBoolean(value != null);
                if (value != null) {
                    baosWrapper.write(value);
                }
            }
            break;
        }
        case KEYS: {
            // For KEYS the iterator came from mapState.keys().iterator().
            @SuppressWarnings("unchecked")
            Iterator<ByteArrayWrapper> keyIterator = (Iterator<ByteArrayWrapper>) iterator;
            for (int i = 0; i < mapStateIterateResponseBatchSize && keyIterator.hasNext(); i++) {
                ByteArrayWrapper key = keyIterator.next();
                baosWrapper.write(key.getData(), key.getOffset(), key.getLimit() - key.getOffset());
            }
            break;
        }
        default:
            throw new RuntimeException("Unsupported iterate type: " + iterateType);
    }

    ByteArrayWrapper continuationToken = iteratorToken;
    if (iterator.hasNext()) {
        if (continuationToken == null) {
            // Use an explicit charset so the token bytes do not depend on the platform default.
            continuationToken = new ByteArrayWrapper(UUID.randomUUID().toString().getBytes(java.nio.charset.StandardCharsets.UTF_8));
        }
        // NOTE(review): for a continued iteration the caller passes a reused wrapper whose key is
        // already present in the cache; this put only refreshes that existing mapping.
        mapStateIteratorCache.put(continuationToken, iterator);
    } else {
        if (continuationToken != null) {
            mapStateIteratorCache.remove(continuationToken);
        }
        continuationToken = null;
    }

    BeamFnApi.StateGetResponse.Builder responseBuilder = BeamFnApi.StateGetResponse.newBuilder().setData(ByteString.copyFrom(baos.toByteArray()));
    if (continuationToken != null) {
        responseBuilder.setContinuationToken(ByteString.copyFrom(continuationToken.getData(), continuationToken.getOffset(), continuationToken.getLimit() - continuationToken.getOffset()));
    }
    return responseBuilder;
}
Use of org.apache.flink.streaming.api.utils.ByteArrayWrapper in the Apache Flink project:
class SimpleStateRequestHandler, method getMapState.
/**
 * Resolves the map state addressed by a state request.
 *
 * <p>The request's multimap side-input id is a base64-encoded {@code FlinkFnApi.StateDescriptor}.
 * Its state name, prefixed with {@code PYTHON_STATE_PREFIX}, is the key into
 * {@code stateDescriptorCache}. On first use, a {@code MapStateDescriptor} is created with the
 * {@code ByteArrayWrapperSerializer} key serializer and {@code valueSerializer}. It gets TTL
 * enabled when the proto carries a TTL config, and is then cached. A non-empty window field is
 * deserialized with {@code namespaceSerializer} as the state namespace; an empty one selects the
 * void namespace.
 *
 * @param request the state request identifying the map state
 * @return the partitioned map state for the resolved descriptor and namespace
 * @throws Exception if decoding the descriptor or namespace, or accessing the backend, fails
 * @throws RuntimeException if the name is already registered for a non-map state
 */
private MapState<ByteArrayWrapper, byte[]> getMapState(BeamFnApi.StateRequest request) throws Exception {
    final BeamFnApi.StateKey.MultimapSideInput sideInput = request.getStateKey().getMultimapSideInput();
    final FlinkFnApi.StateDescriptor protoDescriptor = FlinkFnApi.StateDescriptor.parseFrom(Base64.getDecoder().decode(sideInput.getSideInputId()));
    final String stateName = PYTHON_STATE_PREFIX + protoDescriptor.getStateName();

    final StateDescriptor existing = stateDescriptorCache.get(stateName);
    final MapStateDescriptor<ByteArrayWrapper, byte[]> descriptor;
    if (existing == null) {
        // First request for this name: build and register a new map descriptor.
        descriptor = new MapStateDescriptor<>(stateName, ByteArrayWrapperSerializer.INSTANCE, valueSerializer);
        if (protoDescriptor.hasStateTtlConfig()) {
            descriptor.enableTimeToLive(ProtoUtils.parseStateTtlConfigFromProto(protoDescriptor.getStateTtlConfig()));
        }
        stateDescriptorCache.put(stateName, descriptor);
    } else if (existing instanceof MapStateDescriptor) {
        descriptor = (MapStateDescriptor<ByteArrayWrapper, byte[]>) existing;
    } else {
        // The same name is already bound to a state of a different kind.
        throw new RuntimeException(String.format("State name corrupt detected: " + "'%s' is used both as MAP state and '%s' state at the same time.", stateName, existing.getType()));
    }

    final byte[] windowBytes = sideInput.getWindow().toByteArray();
    if (windowBytes.length == 0) {
        return (MapState<ByteArrayWrapper, byte[]>) keyedStateBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
    }
    bais.setBuffer(windowBytes, 0, windowBytes.length);
    final Object namespace = namespaceSerializer.deserialize(baisWrapper);
    return (MapState<ByteArrayWrapper, byte[]>) keyedStateBackend.getPartitionedState(namespace, namespaceSerializer, descriptor);
}
Aggregations