
Example 6 with GraphQlMessage

Use of org.springframework.graphql.web.support.GraphQlMessage in project spring-graphql by spring-projects.

From class GraphQlWebSocketHandler (the WebFlux variant, built on org.springframework.web.reactive.socket), method handle.

@Override
public Mono<Void> handle(WebSocketSession session) {
    HandshakeInfo handshakeInfo = session.getHandshakeInfo();
    if ("graphql-ws".equalsIgnoreCase(handshakeInfo.getSubProtocol())) {
        if (logger.isDebugEnabled()) {
            logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. " + "Please, use https://github.com/enisdenjo/graphql-ws.");
        }
        return session.close(GraphQlStatus.INVALID_MESSAGE_STATUS);
    }
    // Session state
    AtomicReference<Map<String, Object>> connectionInitPayloadRef = new AtomicReference<>();
    Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();
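    // If no connection_init arrives within initTimeoutDuration, close with INIT_TIMEOUT_STATUS.
    // Setting the empty map also makes a late connection_init fail the compareAndSet in the CONNECTION_INIT branch.
    Mono.delay(this.initTimeoutDuration)
            .then(Mono.defer(() -> connectionInitPayloadRef.compareAndSet(null, Collections.emptyMap()) ?
                    session.close(GraphQlStatus.INIT_TIMEOUT_STATUS) : Mono.empty()))
            .subscribe();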
    session.closeStatus().doOnSuccess(closeStatus -> {
        Map<String, Object> connectionInitPayload = connectionInitPayloadRef.get();
        if (connectionInitPayload == null) {
            return;
        }
        // 1005 is the RFC 6455 code for "no status code received"
        int statusCode = (closeStatus != null ? closeStatus.getCode() : 1005);
        this.webSocketInterceptor.handleConnectionClosed(session.getId(), statusCode, connectionInitPayload);
    }).subscribe();
    return session.send(session.receive().flatMap(webSocketMessage -> {
        GraphQlMessage message = this.codecDelegate.decode(webSocketMessage);
        String id = message.getId();
        Map<String, Object> payload = message.getPayload();
        switch(message.resolvedType()) {
            case SUBSCRIBE:
                if (connectionInitPayloadRef.get() == null) {
                    return GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS);
                }
                if (id == null) {
                    return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
                }
                WebInput input = new WebInput(handshakeInfo.getUri(), handshakeInfo.getHeaders(), payload, id, null);
                if (logger.isDebugEnabled()) {
                    logger.debug("Executing: " + input);
                }
                return this.graphQlHandler.handleRequest(input)
                        .flatMapMany((output) -> handleWebOutput(session, id, subscriptions, output))
                        .doOnTerminate(() -> subscriptions.remove(id));
            case PING:
                return Flux.just(this.codecDelegate.encode(session, GraphQlMessage.pong(null)));
            case COMPLETE:
                if (id != null) {
                    Subscription subscription = subscriptions.remove(id);
                    if (subscription != null) {
                        subscription.cancel();
                    }
                    return this.webSocketInterceptor.handleCancelledSubscription(session.getId(), id).thenMany(Flux.empty());
                }
                return Flux.empty();
            case CONNECTION_INIT:
                if (!connectionInitPayloadRef.compareAndSet(null, payload)) {
                    return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
                }
                return this.webSocketInterceptor.handleConnectionInitialization(session.getId(), payload)
                        .defaultIfEmpty(Collections.emptyMap())
                        .map(ackPayload -> this.codecDelegate.encodeConnectionAck(session, ackPayload))
                        .flux()
                        .onErrorResume(ex -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
            default:
                return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
        }
    }));
}
Also used : GraphQlMessage(org.springframework.graphql.web.support.GraphQlMessage) Arrays(java.util.Arrays) WebSocketSession(org.springframework.web.reactive.socket.WebSocketSession) CloseStatus(org.springframework.web.reactive.socket.CloseStatus) WebOutput(org.springframework.graphql.web.WebOutput) AtomicReference(java.util.concurrent.atomic.AtomicReference) ExecutionResult(graphql.ExecutionResult) Duration(java.time.Duration) Map(java.util.Map) WebSocketHandler(org.springframework.web.reactive.socket.WebSocketHandler) WebGraphQlHandler(org.springframework.graphql.web.WebGraphQlHandler) HandshakeInfo(org.springframework.web.reactive.socket.HandshakeInfo) Publisher(org.reactivestreams.Publisher) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) Mono(reactor.core.publisher.Mono) CodecConfigurer(org.springframework.http.codec.CodecConfigurer) Flux(reactor.core.publisher.Flux) List(java.util.List) WebInput(org.springframework.graphql.web.WebInput) CollectionUtils(org.springframework.util.CollectionUtils) Subscription(org.reactivestreams.Subscription) WebSocketMessage(org.springframework.web.reactive.socket.WebSocketMessage) Log(org.apache.commons.logging.Log) LogFactory(org.apache.commons.logging.LogFactory) Collections(java.util.Collections) WebSocketInterceptor(org.springframework.graphql.web.WebSocketInterceptor) Assert(org.springframework.util.Assert)
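Both handlers implement the graphql-transport-ws protocol from the graphql-ws project linked in the log message. Under that protocol:

- `connection_init` is accepted only once.
- `subscribe` requires a prior `connection_init` and an `id`.
- `ping` gets a reply.
- `complete` cancels an active subscription.
- Any other message type closes the session.

The sketch below models that dispatch using only the JDK. It is an illustration, not Spring GraphQL code. The class `ProtocolStateSketch` and its methods are hypothetical. The numeric close codes (4400, 4401, 4408, 4429) are taken from the graphql-ws protocol document, on the assumption that the `GraphQlStatus` constants map to them.

The sketch also shows why the init timeout uses `compareAndSet(null, emptyMap())`. Once the timer has claimed the slot, a late `connection_init` fails the same check as a duplicate one.

```java
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical, JDK-only model of the dispatch in handle(); not part of Spring GraphQL.
class ProtocolStateSketch {

    // Close codes from the graphql-ws protocol; OK (0) means "keep the session open".
    static final int OK = 0;
    static final int INVALID_MESSAGE = 4400;
    static final int UNAUTHORIZED = 4401;
    static final int INIT_TIMEOUT = 4408;
    static final int TOO_MANY_INIT_REQUESTS = 4429;

    private final AtomicReference<Map<String, Object>> initPayload = new AtomicReference<>();
    private final Map<String, Runnable> subscriptions = new ConcurrentHashMap<>();

    // Invoked by a timer. Like compareAndSet(null, emptyMap()) in handle(), it only
    // reports a timeout if connection_init has not been seen, and it claims the slot.
    int onInitTimeout() {
        return initPayload.compareAndSet(null, Collections.emptyMap()) ? INIT_TIMEOUT : OK;
    }

    // Returns OK, or the close code the session would be closed with.
    int onMessage(String type, String id, Map<String, Object> payload) {
        switch (type) {
            case "connection_init":
                // Assumption: a missing payload counts as empty so the state still changes.
                Map<String, Object> value = (payload != null ? payload : Collections.emptyMap());
                return initPayload.compareAndSet(null, value) ? OK : TOO_MANY_INIT_REQUESTS;
            case "subscribe":
                if (initPayload.get() == null) {
                    return UNAUTHORIZED;
                }
                if (id == null) {
                    return INVALID_MESSAGE;
                }
                // Hypothetical cancel hook; the real handler keeps a Reactive Streams Subscription.
                subscriptions.put(id, () -> { });
                return OK;
            case "complete":
                if (id != null) {
                    Runnable cancel = subscriptions.remove(id);
                    if (cancel != null) {
                        cancel.run();
                    }
                }
                return OK;
            case "ping":
                return OK; // the real handler answers with a pong message
            default:
                return INVALID_MESSAGE;
        }
    }

    int activeSubscriptions() {
        return subscriptions.size();
    }
}
```

The sketch returns codes instead of closing a session so that each transition can be checked in isolation.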

Example 7 with GraphQlMessage

Use of org.springframework.graphql.web.support.GraphQlMessage in project spring-graphql by spring-projects.

From class GraphQlWebSocketHandler (the Servlet variant, built on org.springframework.web.socket and TextWebSocketHandler), method handleTextMessage.

@Override
protected void handleTextMessage(WebSocketSession session, TextMessage webSocketMessage) throws Exception {
    GraphQlMessage message = decode(webSocketMessage);
    String id = message.getId();
    Map<String, Object> payload = message.getPayload();
    SessionState sessionState = getSessionInfo(session);
    switch(message.resolvedType()) {
        case SUBSCRIBE:
            if (sessionState.getConnectionInitPayload() == null) {
                GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
                return;
            }
            if (id == null) {
                GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
                return;
            }
            URI uri = session.getUri();
            Assert.notNull(uri, "Expected handshake url");
            HttpHeaders headers = session.getHandshakeHeaders();
            WebInput input = new WebInput(uri, headers, payload, id, null);
            if (logger.isDebugEnabled()) {
                logger.debug("Executing: " + input);
            }
            this.graphQlHandler.handleRequest(input)
                    .flatMapMany((output) -> handleWebOutput(session, input.getId(), output))
                    .publishOn(sessionState.getScheduler()) // Serial blocking send via single thread
                    .subscribe(new SendMessageSubscriber(id, session, sessionState));
            return;
        case PING:
            session.sendMessage(encode(GraphQlMessage.pong(null)));
            return;
        case COMPLETE:
            if (id != null) {
                Subscription subscription = sessionState.getSubscriptions().remove(id);
                if (subscription != null) {
                    subscription.cancel();
                }
                this.webSocketInterceptor.handleCancelledSubscription(session.getId(), id).block(Duration.ofSeconds(10));
            }
            return;
        case CONNECTION_INIT:
            if (!sessionState.setConnectionInitPayload(payload)) {
                GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
                return;
            }
            this.webSocketInterceptor.handleConnectionInitialization(session.getId(), payload)
                    .defaultIfEmpty(Collections.emptyMap())
                    .publishOn(sessionState.getScheduler()) // Serial blocking send via single thread
                    .doOnNext(ackPayload -> {
                        TextMessage outputMessage = encode(GraphQlMessage.connectionAck(ackPayload));
                        try {
                            session.sendMessage(outputMessage);
                        } catch (IOException ex) {
                            throw new IllegalStateException(ex);
                        }
                    })
                    .onErrorResume(ex -> {
                        GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
                        return Mono.empty();
                    })
                    .block(Duration.ofSeconds(10));
            return;
        default:
            GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
    }
}
Also used : GraphQlMessage(org.springframework.graphql.web.support.GraphQlMessage) Arrays(java.util.Arrays) ByteArrayOutputStream(java.io.ByteArrayOutputStream) WebOutput(org.springframework.graphql.web.WebOutput) Scheduler(reactor.core.scheduler.Scheduler) AtomicReference(java.util.concurrent.atomic.AtomicReference) WebSocketSession(org.springframework.web.socket.WebSocketSession) CloseStatus(org.springframework.web.socket.CloseStatus) ExecutionResult(graphql.ExecutionResult) TextMessage(org.springframework.web.socket.TextMessage) TextWebSocketHandler(org.springframework.web.socket.handler.TextWebSocketHandler) ByteArrayInputStream(java.io.ByteArrayInputStream) GraphQLError(graphql.GraphQLError) Duration(java.time.Duration) Map(java.util.Map) Schedulers(reactor.core.scheduler.Schedulers) Nullable(org.springframework.lang.Nullable) URI(java.net.URI) OutputStream(java.io.OutputStream) WebGraphQlHandler(org.springframework.graphql.web.WebGraphQlHandler) HttpHeaders(org.springframework.http.HttpHeaders) ExceptionWebSocketHandlerDecorator(org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator) Publisher(org.reactivestreams.Publisher) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) IOException(java.io.IOException) Mono(reactor.core.publisher.Mono) GraphqlErrorBuilder(graphql.GraphqlErrorBuilder) BaseSubscriber(reactor.core.publisher.BaseSubscriber) Flux(reactor.core.publisher.Flux) List(java.util.List) GraphQlMessageType(org.springframework.graphql.web.support.GraphQlMessageType) HttpInputMessage(org.springframework.http.HttpInputMessage) WebInput(org.springframework.graphql.web.WebInput) HttpMessageConverter(org.springframework.http.converter.HttpMessageConverter) CollectionUtils(org.springframework.util.CollectionUtils) Subscription(org.reactivestreams.Subscription) Log(org.apache.commons.logging.Log) LogFactory(org.apache.commons.logging.LogFactory) Collections(java.util.Collections) WebSocketInterceptor(org.springframework.graphql.web.WebSocketInterceptor) SubProtocolCapable(org.springframework.web.socket.SubProtocolCapable) InputStream(java.io.InputStream) HttpOutputMessage(org.springframework.http.HttpOutputMessage) GenericHttpMessageConverter(org.springframework.http.converter.GenericHttpMessageConverter) Assert(org.springframework.util.Assert)
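Example 7 targets the Servlet-based `org.springframework.web.socket.WebSocketSession`, whose `sendMessage` blocks. Spring's WebSocket documentation notes that the underlying standard (JSR-356) session does not allow concurrent sends. The `publishOn(sessionState.getScheduler())` calls and their "Serial blocking send via single thread" comment address exactly that.

The sketch below shows the same idea using only the JDK. A hypothetical `SerialSender` routes every send for one session through a single-thread executor, so sends never overlap and go out in submission order. The `Consumer<String>` that stands in for `session.sendMessage(...)` is an assumption made to keep the example self-contained.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

// Hypothetical helper: serializes all sends for one WebSocket session on one thread.
class SerialSender implements AutoCloseable {

    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private final Consumer<String> blockingSend; // stands in for session.sendMessage(...)

    SerialSender(Consumer<String> blockingSend) {
        this.blockingSend = blockingSend;
    }

    // May be called from any thread; the single worker runs sends one at a time, in order.
    void send(String message) {
        executor.execute(() -> blockingSend.accept(message));
    }

    // Lets already queued sends finish (up to 5 seconds), then stops the worker thread.
    @Override
    public void close() {
        executor.shutdown();
        try {
            executor.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
    }
}
```

Spring also provides `ConcurrentWebSocketSessionDecorator` for guarding a Servlet `WebSocketSession` against concurrent sends. A per-session scheduler, as in Example 7, fits naturally with Reactor pipelines.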

Example 8 with GraphQlMessage

Use of org.springframework.graphql.web.support.GraphQlMessage in project spring-graphql by spring-projects.

From class GraphQlWebSocketHandlerTests (tests for the WebFlux handler), method subscriptionExists.

@Test
void subscriptionExists() {
    // connection_init followed by the same subscription twice, i.e. a reused subscription id
    Flux<WebSocketMessage> messageFlux = Flux.just(
            toWebSocketMessage("{\"type\":\"connection_init\"}"),
            toWebSocketMessage(BOOK_SUBSCRIPTION),
            toWebSocketMessage(BOOK_SUBSCRIPTION));
    TestWebSocketSession session = handle(messageFlux, new ConsumeOneAndNeverCompleteInterceptor());
    // Collect messages until session closed
    List<GraphQlMessage> messages = new ArrayList<>();
    session.getOutput().subscribe((message) -> messages.add(decode(message)));
    StepVerifier.create(session.closeStatus())
            .expectNext(new CloseStatus(4409, "Subscriber for " + SUBSCRIPTION_ID + " already exists"))
            .expectComplete()
            .verify(TIMEOUT);
    assertThat(messages.size()).isEqualTo(2);
    assertThat(messages.get(0).resolvedType()).isEqualTo(GraphQlMessageType.CONNECTION_ACK);
    assertThat(messages.get(1).resolvedType()).isEqualTo(GraphQlMessageType.NEXT);
}
Also used : ConsumeOneAndNeverCompleteInterceptor(org.springframework.graphql.web.ConsumeOneAndNeverCompleteInterceptor) GraphQlMessage(org.springframework.graphql.web.support.GraphQlMessage) ArrayList(java.util.ArrayList) WebSocketMessage(org.springframework.web.reactive.socket.WebSocketMessage) CloseStatus(org.springframework.web.reactive.socket.CloseStatus) Test(org.junit.jupiter.api.Test)
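Example 8 sends the same subscription twice while a `ConsumeOneAndNeverCompleteInterceptor` keeps the first one open. It expects close status 4409 with the reason "Subscriber for <id> already exists".

The code that performs this check lives inside `handleWebOutput`, which is not shown above. The following JDK sketch is therefore only one possible way to do it. A hypothetical `SubscriptionRegistry` is backed by a concurrent key set, whose `add` atomically reports whether the id was already present. Removing the id when the stream terminates, as `doOnTerminate(() -> subscriptions.remove(id))` does in Example 6, frees the id for reuse.

```java
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical registry of active subscription ids for one session.
class SubscriptionRegistry {

    // graphql-ws close code for an operation id that is already in use.
    static final int SUBSCRIBER_EXISTS = 4409;

    private final Set<String> activeIds = ConcurrentHashMap.newKeySet();

    // Registers the id, or returns the close reason if it is already active.
    // add() on a concurrent key set is atomic, so two racing subscribes cannot both succeed.
    Optional<String> register(String id) {
        if (activeIds.add(id)) {
            return Optional.empty();
        }
        return Optional.of("Subscriber for " + id + " already exists");
    }

    // Called on "complete" or when the result stream terminates.
    void unregister(String id) {
        activeIds.remove(id);
    }
}
```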

Example 9 with GraphQlMessage

Use of org.springframework.graphql.web.support.GraphQlMessage in project spring-graphql by spring-projects.

From class GraphQlWebSocketHandlerTests, method connectionInitHandling.

@Test
void connectionInitHandling() {
    TestWebSocketSession session = handle(Flux.just(toWebSocketMessage("{\"type\":\"connection_init\",\"payload\":{\"key\":\"A\"}}")), new WebSocketInterceptor() {

        @Override
        public Mono<Object> handleConnectionInitialization(String sessionId, Map<String, Object> payload) {
            Object value = payload.get("key");
            return Mono.just(Collections.singletonMap("key", value + " acknowledged"));
        }
    });
    StepVerifier.create(session.getOutput()).consumeNextWith((message) -> {
        GraphQlMessage actual = decode(message);
        assertThat(actual.resolvedType()).isEqualTo(GraphQlMessageType.CONNECTION_ACK);
        assertThat(actual.<Map<String, Object>>getPayload()).containsEntry("key", "A acknowledged");
    }).expectComplete().verify(TIMEOUT);
}
Also used : WebSocketInterceptor(org.springframework.graphql.web.WebSocketInterceptor) Mono(reactor.core.publisher.Mono) GraphQlMessage(org.springframework.graphql.web.support.GraphQlMessage) Map(java.util.Map) Test(org.junit.jupiter.api.Test)
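In Example 9, the map returned by the interceptor's `handleConnectionInitialization` becomes the `connection_ack` payload. Both handlers apply `defaultIfEmpty(Collections.emptyMap())`, so an interceptor that returns nothing still produces an ack. An error from the interceptor closes the session with the unauthorized status.

The JDK-only sketch below models those three outcomes. It uses `Optional` in place of `Mono` and a thrown exception in place of an error signal. `InitInterceptor`, `Outcome`, and `onConnectionInit` are hypothetical names, not Spring GraphQL types. The value 4401 for the unauthorized close code is taken from the graphql-ws protocol.

```java
import java.util.Collections;
import java.util.Map;
import java.util.Optional;

// Hypothetical, JDK-only model of connection_init handling; Optional stands in for Mono.
class ConnectionInitSketch {

    static final int UNAUTHORIZED = 4401;

    // Stand-in for WebSocketInterceptor#handleConnectionInitialization (not the real type).
    interface InitInterceptor {
        default Optional<Map<String, Object>> handleConnectionInitialization(
                String sessionId, Map<String, Object> payload) {
            return Optional.empty(); // like an interceptor that returns an empty Mono
        }
    }

    // Either an ack payload to send, or a close code.
    static final class Outcome {
        final Map<String, Object> ackPayload; // null when the session is closed
        final int closeCode;                  // 0 when an ack is sent

        Outcome(Map<String, Object> ackPayload, int closeCode) {
            this.ackPayload = ackPayload;
            this.closeCode = closeCode;
        }
    }

    static Outcome onConnectionInit(InitInterceptor interceptor, String sessionId, Map<String, Object> payload) {
        try {
            Map<String, Object> ack = interceptor
                    .handleConnectionInitialization(sessionId, payload)
                    .orElse(Collections.emptyMap()); // mirrors defaultIfEmpty(Collections.emptyMap())
            return new Outcome(ack, 0);
        } catch (RuntimeException ex) {
            return new Outcome(null, UNAUTHORIZED); // mirrors onErrorResume(... UNAUTHORIZED_STATUS)
        }
    }
}
```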

Example 10 with GraphQlMessage

Use of org.springframework.graphql.web.support.GraphQlMessage in project spring-graphql by spring-projects.

From class GraphQlWebSocketHandlerTests, method assertMessageType.

private void assertMessageType(WebSocketMessage webSocketMessage, GraphQlMessageType messageType) {
    GraphQlMessage message = decode(webSocketMessage);
    assertThat(message.resolvedType()).isEqualTo(messageType);
    if (messageType != GraphQlMessageType.CONNECTION_ACK && messageType != GraphQlMessageType.PONG) {
        assertThat(message.getId()).isEqualTo(SUBSCRIPTION_ID);
    }
}
Also used : GraphQlMessage(org.springframework.graphql.web.support.GraphQlMessage)
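`assertMessageType` skips the id check for `connection_ack` and `pong`. In the graphql-ws protocol, only operation messages (`subscribe`, `next`, `error`, `complete`) carry an `id`; `connection_init` and `ping` have none either.

A JDK-only sketch of that rule follows. The `WsMessageType` enum and its `fromWire` lookup are hypothetical. They are not Spring GraphQL's `GraphQlMessageType`, whose `resolvedType()` handling of unknown values is not shown here.

```java
import java.util.Arrays;
import java.util.Optional;

// Hypothetical enum mirroring the graphql-ws message types; not GraphQlMessageType itself.
enum WsMessageType {
    CONNECTION_INIT("connection_init", false),
    CONNECTION_ACK("connection_ack", false),
    PING("ping", false),
    PONG("pong", false),
    SUBSCRIBE("subscribe", true),
    NEXT("next", true),
    ERROR("error", true),
    COMPLETE("complete", true);

    private final String wireValue;   // value of the "type" field in the JSON message
    private final boolean requiresId; // true for operation messages

    WsMessageType(String wireValue, boolean requiresId) {
        this.wireValue = wireValue;
        this.requiresId = requiresId;
    }

    boolean requiresId() {
        return requiresId;
    }

    // Resolves an incoming "type" value; empty for unknown types, which a handler would reject.
    static Optional<WsMessageType> fromWire(String value) {
        return Arrays.stream(values()).filter(t -> t.wireValue.equals(value)).findFirst();
    }
}
```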
