Search in sources:

Example 1 with ByteArrayWrapper

Use of org.apache.flink.streaming.api.utils.ByteArrayWrapper in the Apache Flink project.

Method handleMapGetRequest of class SimpleStateRequestHandler.

/**
 * Handles a GET state request issued against a map state.
 *
 * <p>The first byte of the request's continuation token is a flag that selects the operation:
 * <ul>
 *   <li>GET: {@code [flag (1 byte)][serialized map key]} - looks up the value stored under one key.</li>
 *   <li>CHECK_EMPTY: {@code [flag (1 byte)]} - checks whether the map state is empty.</li>
 *   <li>ITERATE: {@code [flag (1 byte)][iterate type (1 byte)][iterator token length (int32)][iterator token]}
 *       - fetches a batch of an iteration. A token length of 0 or less passes a {@code null}
 *       iterator token on to {@code handleMapIterateRequest}.</li>
 * </ul>
 *
 * @param request the state request; it identifies the map state and carries the continuation token
 * @return an already-completed future holding the response builder, tagged with the request id
 * @throws RuntimeException if the continuation token is empty, if the ITERATE token length exceeds
 *     the bytes actually present, or if the flag is not supported
 * @throws Exception propagated from state access and deserialization
 */
private CompletionStage<BeamFnApi.StateResponse.Builder> handleMapGetRequest(BeamFnApi.StateRequest request) throws Exception {
    MapState<ByteArrayWrapper, byte[]> mapState = getMapState(request);
    byte[] getRequest = request.getGet().getContinuationToken().toByteArray();
    // Every request type carries at least the 1-byte flag; fail with a clear message instead of
    // an ArrayIndexOutOfBoundsException when it is missing.
    if (getRequest.length == 0) {
        throw new RuntimeException("Empty continuation token in get request for map state.");
    }
    byte getFlag = getRequest[0];
    BeamFnApi.StateGetResponse.Builder response;
    switch (getFlag) {
        case GET_FLAG:
            // Point the reusable wrapper at the key bytes in place (everything after the flag)
            // rather than copying them into a new array.
            reuseByteArrayWrapper.setData(getRequest);
            reuseByteArrayWrapper.setOffset(1);
            reuseByteArrayWrapper.setLimit(getRequest.length);
            response = handleMapGetValueRequest(reuseByteArrayWrapper, mapState);
            break;
        case CHECK_EMPTY_FLAG:
            response = handleMapCheckEmptyRequest(mapState);
            break;
        case ITERATE_FLAG:
            bais.setBuffer(getRequest, 1, getRequest.length - 1);
            IterateType iterateType = IterateType.fromOrd(baisWrapper.readByte());
            int iterateTokenLength = baisWrapper.readInt();
            ByteArrayWrapper iterateToken = null;
            if (iterateTokenLength > 0) {
                // The stream position is used directly as an index into getRequest.
                int tokenStart = bais.getPosition();
                // Written as a subtraction so that a huge declared length cannot overflow the check.
                if (iterateTokenLength > getRequest.length - tokenStart) {
                    throw new RuntimeException(String.format(
                            "Malformed iterate request for map state: iterator token length %d exceeds the %d remaining bytes.",
                            iterateTokenLength, getRequest.length - tokenStart));
                }
                reuseByteArrayWrapper.setData(getRequest);
                reuseByteArrayWrapper.setOffset(tokenStart);
                reuseByteArrayWrapper.setLimit(tokenStart + iterateTokenLength);
                iterateToken = reuseByteArrayWrapper;
            }
            response = handleMapIterateRequest(mapState, iterateType, iterateToken);
            break;
        default:
            throw new RuntimeException(String.format("Unsupported get request type: '%d' for map state.", getFlag));
    }
    return CompletableFuture.completedFuture(BeamFnApi.StateResponse.newBuilder().setId(request.getId()).setGet(response));
}
Also used : ByteArrayWrapper(org.apache.flink.streaming.api.utils.ByteArrayWrapper)

Example 2 with ByteArrayWrapper

Use of org.apache.flink.streaming.api.utils.ByteArrayWrapper in the Apache Flink project.

Method handleMapIterateRequest of class SimpleStateRequestHandler.

/**
 * Serializes the next batch (at most {@code mapStateIterateResponseBatchSize} elements) of a map
 * state iteration into the response data.
 *
 * <p>For ITEMS and VALUES, each element is written as the key bytes, a boolean "value present"
 * flag, and the value bytes when present. For KEYS, only the key bytes are written.
 *
 * <p>When {@code iteratorToken} is {@code null}, a new iterator is created from the map state.
 * Otherwise the iterator previously cached under that token is resumed. If the iterator still
 * has elements after this batch, it is cached under a token. A fresh random-UUID token is
 * created for a new iteration. That token is returned as the continuation token. Once the
 * iterator is exhausted, its cache entry is removed and no continuation token is set.
 *
 * @param mapState the map state being iterated
 * @param iterateType which view of the map to iterate (ITEMS, VALUES or KEYS)
 * @param iteratorToken token of a cached iterator to resume, or {@code null} to start a new iteration
 * @return a response builder holding the serialized batch and, if more elements remain, the continuation token
 * @throws RuntimeException if the iterate type is unsupported or no iterator is cached under {@code iteratorToken}
 * @throws Exception propagated from state access
 */
private BeamFnApi.StateGetResponse.Builder handleMapIterateRequest(MapState<ByteArrayWrapper, byte[]> mapState, IterateType iterateType, ByteArrayWrapper iteratorToken) throws Exception {
    final Iterator<?> iterator;
    if (iteratorToken == null) {
        switch (iterateType) {
            case ITEMS:
            case VALUES:
                iterator = mapState.iterator();
                break;
            case KEYS:
                iterator = mapState.keys().iterator();
                break;
            default:
                throw new RuntimeException("Unsupported iterate type: " + iterateType);
        }
    } else {
        iterator = mapStateIteratorCache.get(iteratorToken);
        if (iterator == null) {
            throw new RuntimeException("The cached iterator does not exist!");
        }
    }
    baos.reset();
    switch (iterateType) {
        case ITEMS:
        case VALUES:
            // Iterators for ITEMS/VALUES are created from mapState.iterator() above, so their
            // elements are map entries; the cast only restores the erased element type.
            @SuppressWarnings("unchecked")
            Iterator<Map.Entry<ByteArrayWrapper, byte[]>> entryIterator =
                    (Iterator<Map.Entry<ByteArrayWrapper, byte[]>>) iterator;
            for (int i = 0; i < mapStateIterateResponseBatchSize && entryIterator.hasNext(); i++) {
                Map.Entry<ByteArrayWrapper, byte[]> entry = entryIterator.next();
                ByteArrayWrapper key = entry.getKey();
                byte[] value = entry.getValue();
                baosWrapper.write(key.getData(), key.getOffset(), key.getLimit() - key.getOffset());
                // A boolean marker lets the reader distinguish a null value from an empty one.
                baosWrapper.writeBoolean(value != null);
                if (value != null) {
                    baosWrapper.write(value);
                }
            }
            break;
        case KEYS:
            // Iterators for KEYS are created from mapState.keys().iterator() above.
            @SuppressWarnings("unchecked")
            Iterator<ByteArrayWrapper> keyIterator = (Iterator<ByteArrayWrapper>) iterator;
            for (int i = 0; i < mapStateIterateResponseBatchSize && keyIterator.hasNext(); i++) {
                ByteArrayWrapper key = keyIterator.next();
                baosWrapper.write(key.getData(), key.getOffset(), key.getLimit() - key.getOffset());
            }
            break;
        default:
            throw new RuntimeException("Unsupported iterate type: " + iterateType);
    }

    ByteArrayWrapper continuationToken = iteratorToken;
    if (iterator.hasNext()) {
        if (continuationToken == null) {
            // UUID text is ASCII; an explicit charset keeps the token bytes platform-independent.
            continuationToken = new ByteArrayWrapper(
                    UUID.randomUUID().toString().getBytes(java.nio.charset.StandardCharsets.UTF_8));
        }
        mapStateIteratorCache.put(continuationToken, iterator);
    } else {
        // Iteration finished: release the cached iterator and signal completion by omitting the token.
        if (continuationToken != null) {
            mapStateIteratorCache.remove(continuationToken);
        }
        continuationToken = null;
    }

    BeamFnApi.StateGetResponse.Builder responseBuilder =
            BeamFnApi.StateGetResponse.newBuilder().setData(ByteString.copyFrom(baos.toByteArray()));
    if (continuationToken != null) {
        responseBuilder.setContinuationToken(ByteString.copyFrom(
                continuationToken.getData(),
                continuationToken.getOffset(),
                continuationToken.getLimit() - continuationToken.getOffset()));
    }
    return responseBuilder;
}
Also used : ByteArrayWrapper(org.apache.flink.streaming.api.utils.ByteArrayWrapper) Iterator(java.util.Iterator) HashMap(java.util.HashMap) Map(java.util.Map)

Example 3 with ByteArrayWrapper

Use of org.apache.flink.streaming.api.utils.ByteArrayWrapper in the Apache Flink project.

Method getMapState of class SimpleStateRequestHandler.

/**
 * Resolves the map state that a state request refers to.
 *
 * <p>The side input id of the request's multimap-side-input state key is Base64-decoded and
 * parsed as a {@code FlinkFnApi.StateDescriptor}. Its name, prefixed with
 * {@code PYTHON_STATE_PREFIX}, is looked up in {@code stateDescriptorCache}. On a cache miss a new
 * {@code MapStateDescriptor} is created (with TTL enabled when the proto carries a TTL config) and
 * cached.
 *
 * <p>A non-empty window in the state key is deserialized with {@code namespaceSerializer} and used
 * as the namespace; otherwise {@code VoidNamespace} is used.
 *
 * @param request the state request whose state key identifies the map state
 * @return the partitioned map state for the resolved namespace
 * @throws RuntimeException if the state name is already registered for a non-MAP state
 * @throws Exception propagated from proto parsing, namespace deserialization or state access
 */
private MapState<ByteArrayWrapper, byte[]> getMapState(BeamFnApi.StateRequest request) throws Exception {
    final BeamFnApi.StateKey.MultimapSideInput sideInput = request.getStateKey().getMultimapSideInput();
    final FlinkFnApi.StateDescriptor protoDescriptor =
            FlinkFnApi.StateDescriptor.parseFrom(Base64.getDecoder().decode(sideInput.getSideInputId()));
    final String stateName = PYTHON_STATE_PREFIX + protoDescriptor.getStateName();

    final StateDescriptor existing = stateDescriptorCache.get(stateName);
    final MapStateDescriptor<ByteArrayWrapper, byte[]> descriptor;
    if (existing == null) {
        // First access under this name: build and remember the descriptor.
        descriptor = new MapStateDescriptor<>(stateName, ByteArrayWrapperSerializer.INSTANCE, valueSerializer);
        if (protoDescriptor.hasStateTtlConfig()) {
            descriptor.enableTimeToLive(ProtoUtils.parseStateTtlConfigFromProto(protoDescriptor.getStateTtlConfig()));
        }
        stateDescriptorCache.put(stateName, descriptor);
    } else if (existing instanceof MapStateDescriptor) {
        descriptor = (MapStateDescriptor<ByteArrayWrapper, byte[]>) existing;
    } else {
        // The same name is already bound to a different kind of state.
        throw new RuntimeException(String.format("State name corrupt detected: " + "'%s' is used both as MAP state and '%s' state at the same time.", stateName, existing.getType()));
    }

    final byte[] windowBytes = sideInput.getWindow().toByteArray();
    if (windowBytes.length == 0) {
        return (MapState<ByteArrayWrapper, byte[]>) keyedStateBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
    }
    bais.setBuffer(windowBytes, 0, windowBytes.length);
    final Object namespace = namespaceSerializer.deserialize(baisWrapper);
    return (MapState<ByteArrayWrapper, byte[]>) keyedStateBackend.getPartitionedState(namespace, namespaceSerializer, descriptor);
}
Also used : MapStateDescriptor(org.apache.flink.api.common.state.MapStateDescriptor) MapState(org.apache.flink.api.common.state.MapState) ByteString(org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString) StateTtlConfig(org.apache.flink.api.common.state.StateTtlConfig) FlinkFnApi(org.apache.flink.fnexecution.v1.FlinkFnApi) ByteArrayWrapper(org.apache.flink.streaming.api.utils.ByteArrayWrapper) MapStateDescriptor(org.apache.flink.api.common.state.MapStateDescriptor) ListStateDescriptor(org.apache.flink.api.common.state.ListStateDescriptor) StateDescriptor(org.apache.flink.api.common.state.StateDescriptor)

Aggregations

ByteArrayWrapper (org.apache.flink.streaming.api.utils.ByteArrayWrapper)3 HashMap (java.util.HashMap)1 Iterator (java.util.Iterator)1 Map (java.util.Map)1 ByteString (org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString)1 ListStateDescriptor (org.apache.flink.api.common.state.ListStateDescriptor)1 MapState (org.apache.flink.api.common.state.MapState)1 MapStateDescriptor (org.apache.flink.api.common.state.MapStateDescriptor)1 StateDescriptor (org.apache.flink.api.common.state.StateDescriptor)1 StateTtlConfig (org.apache.flink.api.common.state.StateTtlConfig)1 FlinkFnApi (org.apache.flink.fnexecution.v1.FlinkFnApi)1