Usage of org.springframework.graphql.web.support.GraphQlMessage in the spring-graphql project (spring-projects).
Class GraphQlWebSocketHandler, method handle.
/**
 * Handle a reactive WebSocket session that speaks the graphql-ws protocol.
 * <p>Sessions that negotiated the {@code "graphql-ws"} sub-protocol (the legacy
 * apollographql/subscriptions-transport-ws protocol, per the debug message below)
 * are closed immediately with {@code INVALID_MESSAGE_STATUS}. Otherwise this method:
 * <ul>
 * <li>starts a timer that closes the session with {@code INIT_TIMEOUT_STATUS} if no
 * {@code connection_init} payload has been recorded after {@code initTimeoutDuration};</li>
 * <li>calls {@code webSocketInterceptor.handleConnectionClosed} when the session closes,
 * but only if an init payload was recorded;</li>
 * <li>decodes each inbound message, maps it to zero or more outbound messages, and
 * writes them back through {@code session.send}.</li>
 * </ul>
 * @param session the WebSocket session to handle
 * @return the {@code Mono} returned by {@code session.send(...)} for the outbound stream,
 * or the result of {@code session.close(...)} for a rejected sub-protocol
 */
@Override
public Mono<Void> handle(WebSocketSession session) {
HandshakeInfo handshakeInfo = session.getHandshakeInfo();
// Reject the legacy sub-protocol up front; it is not supported by this handler.
if ("graphql-ws".equalsIgnoreCase(handshakeInfo.getSubProtocol())) {
if (logger.isDebugEnabled()) {
logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. " + "Please, use https://github.com/enisdenjo/graphql-ws.");
}
return session.close(GraphQlStatus.INVALID_MESSAGE_STATUS);
}
// Session state:
// - connectionInitPayloadRef: null until either a connection_init is accepted (CONNECTION_INIT
//   case) or the init timeout fires (which stores an empty map).
// - subscriptions: active subscriptions keyed by message id.
AtomicReference<Map<String, Object>> connectionInitPayloadRef = new AtomicReference<>();
Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();
// Init timeout: if still null after initTimeoutDuration, store an empty map and close the session.
// Storing a non-null value also makes a late connection_init fail the compareAndSet in the
// CONNECTION_INIT case, so it is answered with TOO_MANY_INIT_REQUESTS_STATUS.
// NOTE(review): this subscription is not tied to the session lifetime; if the session closes
// early without an init, the timer still fires and calls session.close(...) again. Confirm that
// this is harmless.
Mono.delay(this.initTimeoutDuration).then(Mono.defer(() -> connectionInitPayloadRef.compareAndSet(null, Collections.emptyMap()) ? session.close(GraphQlStatus.INIT_TIMEOUT_STATUS) : Mono.empty())).subscribe();
// On close, notify the interceptor only if an init payload was recorded. 1005 is used when no
// close status is available (presumably RFC 6455 "No Status Rcvd"; verify).
// NOTE(review): after an init timeout the ref holds an empty map, so handleConnectionClosed is
// also invoked for sessions that never sent connection_init. Confirm that this is intended.
session.closeStatus().doOnSuccess(closeStatus -> {
Map<String, Object> connectionInitPayload = connectionInitPayloadRef.get();
if (connectionInitPayload == null) {
return;
}
int statusCode = (closeStatus != null ? closeStatus.getCode() : 1005);
this.webSocketInterceptor.handleConnectionClosed(session.getId(), statusCode, connectionInitPayload);
}).subscribe();
// Each inbound message maps to a publisher of outbound messages. flatMap merges these publishers,
// so output from different inbound messages (e.g. several subscriptions) may interleave.
return session.send(session.receive().flatMap(webSocketMessage -> {
GraphQlMessage message = this.codecDelegate.decode(webSocketMessage);
String id = message.getId();
Map<String, Object> payload = message.getPayload();
switch(message.resolvedType()) {
case SUBSCRIBE:
// A subscribe requires a prior connection_init and a non-null id.
// NOTE(review): the ref is set in CONNECTION_INIT before the interceptor approves the init, so a
// subscribe that arrives before the ack (or before a rejection) still passes this check.
if (connectionInitPayloadRef.get() == null) {
return GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS);
}
if (id == null) {
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
}
WebInput input = new WebInput(handshakeInfo.getUri(), handshakeInfo.getHeaders(), payload, id, null);
if (logger.isDebugEnabled()) {
logger.debug("Executing: " + input);
}
// The subscriptions map is handed to handleWebOutput (presumably it registers the subscription
// there; verify). doOnTerminate removes the entry on completion or error, but not on
// cancellation. A client COMPLETE removes it explicitly below.
return this.graphQlHandler.handleRequest(input).flatMapMany((output) -> handleWebOutput(session, id, subscriptions, output)).doOnTerminate(() -> subscriptions.remove(id));
case PING:
// Reply with a pong that carries no payload.
return Flux.just(this.codecDelegate.encode(session, GraphQlMessage.pong(null)));
case COMPLETE:
// Client-initiated stop: cancel and forget the subscription, then notify the interceptor.
// Nothing is sent back. A COMPLETE without an id is ignored.
if (id != null) {
Subscription subscription = subscriptions.remove(id);
if (subscription != null) {
subscription.cancel();
}
return this.webSocketInterceptor.handleCancelledSubscription(session.getId(), id).thenMany(Flux.empty());
}
return Flux.empty();
case CONNECTION_INIT:
// Only the first init is accepted. This also fails once the init timeout has stored an empty map.
if (!connectionInitPayloadRef.compareAndSet(null, payload)) {
return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
}
// The interceptor result becomes the connection_ack payload; an empty result falls back to an
// empty map. Any error from the interceptor or the encoding closes with UNAUTHORIZED_STATUS.
return this.webSocketInterceptor.handleConnectionInitialization(session.getId(), payload).defaultIfEmpty(Collections.emptyMap()).map(ackPayload -> this.codecDelegate.encodeConnectionAck(session, ackPayload)).flux().onErrorResume(ex -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
default:
// Any other message type is treated as a protocol violation.
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
}
}));
}
Usage of org.springframework.graphql.web.support.GraphQlMessage in the spring-graphql project (spring-projects).
Class GraphQlWebSocketHandler, method handleTextMessage.
/**
 * Decode one inbound text frame and act on its graphql-ws message type.
 * <ul>
 * <li>{@code subscribe}: requires a prior init and a non-null id. It executes the request
 * and streams the results through the session's scheduler.</li>
 * <li>{@code ping}: answers with a pong.</li>
 * <li>{@code complete}: cancels the matching subscription, if any, and waits up to
 * 10 seconds for the interceptor's cancellation callback.</li>
 * <li>{@code connection_init}: accepts only the first init, sends the ack and waits up
 * to 10 seconds. An interceptor error closes the session as unauthorized.</li>
 * <li>anything else: closes the session as an invalid message.</li>
 * </ul>
 * @param session the servlet WebSocket session the frame arrived on
 * @param webSocketMessage the raw inbound text frame
 * @throws Exception if sending a reply fails
 */
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage webSocketMessage) throws Exception {
	GraphQlMessage graphQlMessage = decode(webSocketMessage);
	String messageId = graphQlMessage.getId();
	Map<String, Object> messagePayload = graphQlMessage.getPayload();
	SessionState state = getSessionInfo(session);

	switch (graphQlMessage.resolvedType()) {
		case SUBSCRIBE: {
			if (state.getConnectionInitPayload() == null) {
				GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
				return;
			}
			if (messageId == null) {
				GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
				return;
			}
			URI handshakeUri = session.getUri();
			Assert.notNull(handshakeUri, "Expected handshake url");
			WebInput webInput = new WebInput(handshakeUri, session.getHandshakeHeaders(), messagePayload, messageId, null);
			if (logger.isDebugEnabled()) {
				logger.debug("Executing: " + webInput);
			}
			this.graphQlHandler.handleRequest(webInput)
					.flatMapMany((output) -> handleWebOutput(session, webInput.getId(), output))
					// Serial blocking send via single thread
					.publishOn(state.getScheduler())
					.subscribe(new SendMessageSubscriber(messageId, session, state));
			return;
		}
		case PING: {
			session.sendMessage(encode(GraphQlMessage.pong(null)));
			return;
		}
		case COMPLETE: {
			if (messageId == null) {
				return;
			}
			Subscription active = state.getSubscriptions().remove(messageId);
			if (active != null) {
				active.cancel();
			}
			this.webSocketInterceptor.handleCancelledSubscription(session.getId(), messageId).block(Duration.ofSeconds(10));
			return;
		}
		case CONNECTION_INIT: {
			boolean firstInit = state.setConnectionInitPayload(messagePayload);
			if (!firstInit) {
				GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
				return;
			}
			this.webSocketInterceptor.handleConnectionInitialization(session.getId(), messagePayload)
					.defaultIfEmpty(Collections.emptyMap())
					// Serial blocking send via single thread
					.publishOn(state.getScheduler())
					.doOnNext(ackPayload -> {
						TextMessage ackMessage = encode(GraphQlMessage.connectionAck(ackPayload));
						try {
							session.sendMessage(ackMessage);
						}
						catch (IOException ex) {
							throw new IllegalStateException(ex);
						}
					})
					.onErrorResume(ex -> {
						GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
						return Mono.empty();
					})
					.block(Duration.ofSeconds(10));
			return;
		}
		default:
			GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
	}
}
Usage of org.springframework.graphql.web.support.GraphQlMessage in the spring-graphql project (spring-projects).
Class GraphQlWebSocketHandlerTests, method subscriptionExists.
/**
 * A second subscribe that reuses the id of an active subscription must close the
 * session with 4409 ("Subscriber for ... already exists"). Before the close, exactly
 * two messages are expected: the connection_ack and one NEXT from the first
 * subscription (presumably ConsumeOneAndNeverCompleteInterceptor keeps the first
 * subscription open after one item; verify).
 */
@Test
void subscriptionExists() {
	Flux<WebSocketMessage> messageFlux = Flux.just(
			toWebSocketMessage("{\"type\":\"connection_init\"}"),
			toWebSocketMessage(BOOK_SUBSCRIPTION),
			toWebSocketMessage(BOOK_SUBSCRIPTION));
	TestWebSocketSession session = handle(messageFlux, new ConsumeOneAndNeverCompleteInterceptor());

	// Collect messages until the session is closed. The output may be emitted on a thread
	// other than the test thread, so the list must be safe to write there and read here.
	List<GraphQlMessage> messages = new java.util.concurrent.CopyOnWriteArrayList<>();
	session.getOutput().subscribe((message) -> messages.add(decode(message)));

	StepVerifier.create(session.closeStatus())
			.expectNext(new CloseStatus(4409, "Subscriber for " + SUBSCRIPTION_ID + " already exists"))
			.expectComplete()
			.verify(TIMEOUT);

	// hasSize reports the actual list contents on failure, unlike comparing size() to an int.
	assertThat(messages).hasSize(2);
	assertThat(messages.get(0).resolvedType()).isEqualTo(GraphQlMessageType.CONNECTION_ACK);
	assertThat(messages.get(1).resolvedType()).isEqualTo(GraphQlMessageType.NEXT);
}
Usage of org.springframework.graphql.web.support.GraphQlMessage in the spring-graphql project (spring-projects).
Class GraphQlWebSocketHandlerTests, method connectionInitHandling.
/**
 * The value returned by the interceptor's handleConnectionInitialization becomes
 * the payload of the connection_ack sent back for a connection_init.
 */
@Test
void connectionInitHandling() {
	Flux<WebSocketMessage> input = Flux.just(toWebSocketMessage("{\"type\":\"connection_init\",\"payload\":{\"key\":\"A\"}}"));

	// Echoes the "key" entry of the init payload with a suffix.
	WebSocketInterceptor ackInterceptor = new WebSocketInterceptor() {

		@Override
		public Mono<Object> handleConnectionInitialization(String sessionId, Map<String, Object> payload) {
			return Mono.just(Collections.singletonMap("key", payload.get("key") + " acknowledged"));
		}
	};

	TestWebSocketSession session = handle(input, ackInterceptor);

	StepVerifier.create(session.getOutput())
			.consumeNextWith((message) -> {
				GraphQlMessage ack = decode(message);
				assertThat(ack.resolvedType()).isEqualTo(GraphQlMessageType.CONNECTION_ACK);
				assertThat(ack.<Map<String, Object>>getPayload()).containsEntry("key", "A acknowledged");
			})
			.expectComplete()
			.verify(TIMEOUT);
}
Usage of org.springframework.graphql.web.support.GraphQlMessage in the spring-graphql project (spring-projects).
Class GraphQlWebSocketHandlerTests, method assertMessageType.
/**
 * Decode a message and assert its type. For every type except connection_ack and
 * pong, also assert that the message carries SUBSCRIPTION_ID as its id.
 */
private void assertMessageType(WebSocketMessage webSocketMessage, GraphQlMessageType messageType) {
	GraphQlMessage decoded = decode(webSocketMessage);
	assertThat(decoded.resolvedType()).isEqualTo(messageType);
	boolean typeWithoutId = (messageType == GraphQlMessageType.CONNECTION_ACK || messageType == GraphQlMessageType.PONG);
	if (!typeWithoutId) {
		assertThat(decoded.getId()).isEqualTo(SUBSCRIPTION_ID);
	}
}
Aggregations