use of jakarta.websocket.Extension in project tomcat by apache.
the class WsWebSocketContainer method connectToServerRecursive.
/**
 * Opens a client WebSocket connection to {@code path} and returns the
 * resulting session.
 * <p>
 * The method validates the {@code ws}/{@code wss} scheme, consults the
 * default {@link ProxySelector} for an HTTP proxy (sending a proxy CONNECT
 * request first when one is selected), performs the HTTP upgrade handshake
 * and then builds the session from the negotiated sub-protocol and
 * extensions. It calls itself again in two situations:
 * <ul>
 * <li>the server answers with a redirect status: the {@code Location} header
 *     is resolved against {@code path}, an {@code http(s)} scheme is mapped
 *     to {@code ws(s)} and the connection is retried at the new location;</li>
 * <li>the server answers 401 and no authorization value has been stored in
 *     the user properties yet: an authorization value is computed from the
 *     first {@code WWW-Authenticate} header, stored in the user properties
 *     and the connection is retried against the same {@code path}.</li>
 * </ul>
 * On any failure the channel (or the raw socket channel when no wrapper was
 * created yet) is closed before the exception propagates.
 *
 * @param clientEndpointHolder        holder of the client endpoint that will
 *                                    receive the session events
 * @param clientEndpointConfiguration configuration supplying the
 *                                    configurator and user properties (I/O
 *                                    timeout, maximum redirections,
 *                                    authorization)
 * @param path                        absolute {@code ws} or {@code wss} URI
 *                                    to connect to
 * @param redirectSet                 redirect locations already followed;
 *                                    used to detect redirect loops and to
 *                                    enforce the maximum number of
 *                                    redirections
 *
 * @return the established session
 *
 * @throws DeploymentException if the URI is invalid, the proxy or server
 *                             rejects the request, the handshake response is
 *                             not acceptable, or an I/O, TLS, timeout or
 *                             authentication error occurs
 */
private Session connectToServerRecursive(ClientEndpointHolder clientEndpointHolder,
        ClientEndpointConfig clientEndpointConfiguration, URI path, Set<URI> redirectSet)
        throws DeploymentException {
    if (log.isDebugEnabled()) {
        log.debug(sm.getString("wsWebSocketContainer.connect.entry", clientEndpointHolder.getClassName(), path));
    }

    boolean secure = false;
    ByteBuffer proxyConnect = null;
    URI proxyPath;

    // Validate scheme (and build proxyPath)
    String scheme = path.getScheme();
    if ("ws".equalsIgnoreCase(scheme)) {
        proxyPath = URI.create("http" + path.toString().substring(2));
    } else if ("wss".equalsIgnoreCase(scheme)) {
        proxyPath = URI.create("https" + path.toString().substring(3));
        secure = true;
    } else {
        throw new DeploymentException(sm.getString("wsWebSocketContainer.pathWrongScheme", scheme));
    }

    // Validate host
    String host = path.getHost();
    if (host == null) {
        throw new DeploymentException(sm.getString("wsWebSocketContainer.pathNoHost"));
    }
    int port = path.getPort();

    SocketAddress sa = null;

    // Check to see if a proxy is configured. Javadoc indicates return value
    // will never be null
    List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath);
    Proxy selectedProxy = null;
    for (Proxy proxy : proxies) {
        if (proxy.type().equals(Proxy.Type.HTTP)) {
            sa = proxy.address();
            if (sa instanceof InetSocketAddress) {
                InetSocketAddress inet = (InetSocketAddress) sa;
                if (inet.isUnresolved()) {
                    // Re-create the address so that name resolution is attempted
                    sa = new InetSocketAddress(inet.getHostName(), inet.getPort());
                }
            }
            // Only the first HTTP proxy offered is used
            selectedProxy = proxy;
            break;
        }
    }

    // If the port is not explicitly specified, compute it based on the
    // scheme
    if (port == -1) {
        if ("ws".equalsIgnoreCase(scheme)) {
            port = 80;
        } else {
            // Must be wss due to scheme validation above
            port = 443;
        }
    }

    // If sa is null, no proxy is configured so need to create sa
    if (sa == null) {
        sa = new InetSocketAddress(host, port);
    } else {
        proxyConnect = createProxyRequest(host, port);
    }

    // Create the initial HTTP request to open the WebSocket connection
    Map<String, List<String>> reqHeaders = createRequestHeaders(host, port, secure, clientEndpointConfiguration);
    clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders);
    // Only add the default origin when the configurator did not supply one
    if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) {
        List<String> originValues = new ArrayList<>(1);
        originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE);
        reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues);
    }
    ByteBuffer request = createRequest(path, reqHeaders);

    AsynchronousSocketChannel socketChannel;
    try {
        socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup());
    } catch (IOException ioe) {
        throw new DeploymentException(sm.getString("wsWebSocketContainer.asynchronousSocketChannelFail"), ioe);
    }

    Map<String, Object> userProperties = clientEndpointConfiguration.getUserProperties();

    // Get the connection timeout
    long timeout = Constants.IO_TIMEOUT_MS_DEFAULT;
    String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY);
    if (timeoutValue != null) {
        // Parse as a long: narrowing to int would wrap (possibly to a
        // negative value) for timeouts above Integer.MAX_VALUE
        timeout = Long.parseLong(timeoutValue);
    }

    // Set-up
    // Same size as the WsFrame input buffer
    ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize());
    String subProtocol;
    boolean success = false;
    List<Extension> extensionsAgreed = new ArrayList<>();
    Transformation transformation = null;
    AsyncChannelWrapper channel = null;

    try {
        // Open the connection
        Future<Void> fConnect = socketChannel.connect(sa);

        if (proxyConnect != null) {
            fConnect.get(timeout, TimeUnit.MILLISECONDS);
            // Proxy CONNECT is clear text
            channel = new AsyncChannelWrapperNonSecure(socketChannel);
            writeRequest(channel, proxyConnect, timeout);
            HttpResponse httpResponse = processResponse(response, channel, timeout);
            if (httpResponse.getStatus() != 200) {
                throw new DeploymentException(sm.getString("wsWebSocketContainer.proxyConnectFail", selectedProxy,
                        Integer.toString(httpResponse.getStatus())));
            }
        }

        if (secure) {
            // Regardless of whether a non-secure wrapper was created for a
            // proxy CONNECT, need to use TLS from this point on so wrap the
            // original AsynchronousSocketChannel
            SSLEngine sslEngine = createSSLEngine(clientEndpointConfiguration, host, port);
            channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine);
        } else if (channel == null) {
            // Only need to wrap as this point if it wasn't wrapped to process a
            // proxy CONNECT
            channel = new AsyncChannelWrapperNonSecure(socketChannel);
        }

        fConnect.get(timeout, TimeUnit.MILLISECONDS);

        Future<Void> fHandshake = channel.handshake();
        fHandshake.get(timeout, TimeUnit.MILLISECONDS);

        if (log.isDebugEnabled()) {
            SocketAddress localAddress = null;
            try {
                localAddress = channel.getLocalAddress();
            } catch (IOException ioe) {
                // Ignore - the address is only used for the debug message
            }
            log.debug(sm.getString("wsWebSocketContainer.connect.write", Integer.valueOf(request.position()),
                    Integer.valueOf(request.limit()), localAddress));
        }
        writeRequest(channel, request, timeout);

        HttpResponse httpResponse = processResponse(response, channel, timeout);

        // Check maximum permitted redirects
        int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT;
        String maxRedirectsValue = (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY);
        if (maxRedirectsValue != null) {
            maxRedirects = Integer.parseInt(maxRedirectsValue);
        }

        if (httpResponse.status != 101) {
            if (isRedirectStatus(httpResponse.status)) {
                List<String> locationHeader =
                        httpResponse.getHandshakeResponse().getHeaders().get(Constants.LOCATION_HEADER_NAME);

                if (locationHeader == null || locationHeader.isEmpty() || locationHeader.get(0) == null
                        || locationHeader.get(0).isEmpty()) {
                    throw new DeploymentException(sm.getString("wsWebSocketContainer.missingLocationHeader",
                            Integer.toString(httpResponse.status)));
                }

                URI redirectLocation = URI.create(locationHeader.get(0)).normalize();

                // Relative locations are resolved against the current URI
                if (!redirectLocation.isAbsolute()) {
                    redirectLocation = path.resolve(redirectLocation);
                }

                String redirectScheme = redirectLocation.getScheme().toLowerCase();

                // Map http -> ws and https -> wss so the recursive call
                // passes scheme validation
                if (redirectScheme.startsWith("http")) {
                    redirectLocation = new URI(redirectScheme.replace("http", "ws"),
                            redirectLocation.getUserInfo(), redirectLocation.getHost(), redirectLocation.getPort(),
                            redirectLocation.getPath(), redirectLocation.getQuery(),
                            redirectLocation.getFragment());
                }

                // Fail on a redirect loop (location already seen) or when
                // the redirect limit is exceeded
                if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) {
                    throw new DeploymentException(sm.getString("wsWebSocketContainer.redirectThreshold",
                            redirectLocation, Integer.toString(redirectSet.size()),
                            Integer.toString(maxRedirects)));
                }

                return connectToServerRecursive(clientEndpointHolder, clientEndpointConfiguration, redirectLocation,
                        redirectSet);

            } else if (httpResponse.status == 401) {

                // An authorization value is already stored, so a further 401
                // means authentication failed
                if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
                    throw new DeploymentException(sm.getString("wsWebSocketContainer.failedAuthentication",
                            Integer.valueOf(httpResponse.status)));
                }

                List<String> wwwAuthenticateHeaders = httpResponse.getHandshakeResponse().getHeaders()
                        .get(Constants.WWW_AUTHENTICATE_HEADER_NAME);

                if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty()
                        || wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) {
                    throw new DeploymentException(sm.getString("wsWebSocketContainer.missingWWWAuthenticateHeader",
                            Integer.toString(httpResponse.status)));
                }

                // First token of the challenge is the authentication scheme
                String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0];
                // Second whitespace separated token of the request line is
                // the request URI
                String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1).split("\\s", 3)[1];

                Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme);

                if (auth == null) {
                    throw new DeploymentException(sm.getString("wsWebSocketContainer.unsupportedAuthScheme",
                            Integer.valueOf(httpResponse.status), authScheme));
                }

                userProperties.put(Constants.AUTHORIZATION_HEADER_NAME,
                        auth.getAuthorization(requestUri, wwwAuthenticateHeaders.get(0), userProperties));

                return connectToServerRecursive(clientEndpointHolder, clientEndpointConfiguration, path,
                        redirectSet);

            } else {
                throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus",
                        Integer.toString(httpResponse.status)));
            }
        }

        HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse();
        clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse);

        // Sub-protocol
        List<String> protocolHeaders = handshakeResponse.getHeaders().get(Constants.WS_PROTOCOL_HEADER_NAME);
        if (protocolHeaders == null || protocolHeaders.size() == 0) {
            subProtocol = null;
        } else if (protocolHeaders.size() == 1) {
            subProtocol = protocolHeaders.get(0);
        } else {
            throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidSubProtocol"));
        }

        // Extensions
        // Should normally only be one header but handle the case of
        // multiple headers
        List<String> extHeaders = handshakeResponse.getHeaders().get(Constants.WS_EXTENSIONS_HEADER_NAME);
        if (extHeaders != null) {
            for (String extHeader : extHeaders) {
                Util.parseExtensionHeader(extensionsAgreed, extHeader);
            }
        }

        // Build the transformations
        TransformationFactory factory = TransformationFactory.getInstance();
        for (Extension extension : extensionsAgreed) {
            List<List<Extension.Parameter>> wrapper = new ArrayList<>(1);
            wrapper.add(extension.getParameters());
            Transformation t = factory.create(extension.getName(), wrapper, false);
            if (t == null) {
                throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidExtensionParameters"));
            }
            if (transformation == null) {
                transformation = t;
            } else {
                transformation.setNext(t);
            }
        }

        success = true;
    } catch (ExecutionException | InterruptedException | SSLException | EOFException | TimeoutException
            | URISyntaxException | AuthenticationException e) {
        if (e instanceof InterruptedException) {
            // Restore the interrupt status so callers further up the stack
            // can still observe that the thread was interrupted
            Thread.currentThread().interrupt();
        }
        throw new DeploymentException(sm.getString("wsWebSocketContainer.httpRequestFailed", path), e);
    } finally {
        // Covers exceptions as well as the redirect / authentication
        // retries, which return without setting success
        if (!success) {
            if (channel != null) {
                channel.close();
            } else {
                try {
                    socketChannel.close();
                } catch (IOException ioe) {
                    // Ignore
                }
            }
        }
    }

    // Switch to WebSocket
    WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel);

    WsSession wsSession = new WsSession(clientEndpointHolder, wsRemoteEndpointClient, this, extensionsAgreed,
            subProtocol, Collections.<String, String>emptyMap(), secure, clientEndpointConfiguration);

    WsFrameClient wsFrameClient = new WsFrameClient(response, channel, wsSession, transformation);
    // WsFrame adds the necessary final transformations. Copy the
    // completed transformation chain to the remote end point.
    wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation());

    wsSession.getLocal().onOpen(wsSession, clientEndpointConfiguration);
    registerSession(wsSession.getLocal(), wsSession);

    /* It is possible that the server sent one or more messages as soon as
     * the WebSocket connection was established. Depending on the exact
     * timing of when those messages were sent they could be sat in the
     * input buffer waiting to be read and will not trigger a "data
     * available to read" event. Therefore, it is necessary to process the
     * input buffer here. Note that this happens on the current thread which
     * means that this thread will be used for any onMessage notifications.
     * This is a special case. Subsequent "data available to read" events
     * will be handled by threads from the AsyncChannelGroup's executor.
     */
    wsFrameClient.startInputProcessing();

    return wsSession;
}
use of jakarta.websocket.Extension in project tomcat by apache.
the class UpgradeUtil method doUpgrade.
/**
 * Validates a WebSocket upgrade request and, if it is acceptable, writes the
 * handshake response headers and upgrades the connection.
 * <p>
 * The request is rejected with 400 when the connection header token or the
 * WebSocket key is missing, with 426 (plus the supported version header)
 * when the version token is missing, and with 403 when the endpoint's
 * configurator refuses the origin. Otherwise the sub-protocol and extensions
 * are negotiated, the transformation pipeline is built and validated, the
 * configurator gets a chance to modify the handshake, and the request is
 * upgraded to a {@code WsHttpUpgradeHandler}.
 *
 * @param sc         the WebSocket server container handling the request
 * @param req        the upgrade request
 * @param resp       the response the handshake headers are written to
 * @param sec        the configuration of the target endpoint
 * @param pathParams the path parameters extracted for this request
 *
 * @throws ServletException if the negotiated extensions make incompatible
 *                          use of the RSV bits or the POJO method mapping
 *                          cannot be created
 * @throws IOException      if writing an error response fails
 */
public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, HttpServletResponse resp,
        ServerEndpointConfig sec, Map<String, String> pathParams) throws ServletException, IOException {

    // Header validation: each failed check answers the request and stops
    if (!headerContainsToken(req, Constants.CONNECTION_HEADER_NAME, Constants.CONNECTION_HEADER_VALUE)) {
        resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
        return;
    }
    if (!headerContainsToken(req, Constants.WS_VERSION_HEADER_NAME, Constants.WS_VERSION_HEADER_VALUE)) {
        // Unsupported version: tell the client which version is expected
        resp.setStatus(426);
        resp.setHeader(Constants.WS_VERSION_HEADER_NAME, Constants.WS_VERSION_HEADER_VALUE);
        return;
    }
    final String wsKey = req.getHeader(Constants.WS_KEY_HEADER_NAME);
    if (wsKey == null) {
        resp.sendError(HttpServletResponse.SC_BAD_REQUEST);
        return;
    }

    // The configurator decides whether the origin is acceptable
    final String originHeader = req.getHeader(Constants.ORIGIN_HEADER_NAME);
    if (!sec.getConfigurator().checkOrigin(originHeader)) {
        resp.sendError(HttpServletResponse.SC_FORBIDDEN);
        return;
    }

    // Sub-protocol negotiation
    List<String> requestedSubProtocols = getTokensFromHeader(req, Constants.WS_PROTOCOL_HEADER_NAME);
    final String subProtocol =
            sec.getConfigurator().getNegotiatedSubprotocol(sec.getSubprotocols(), requestedSubProtocols);

    // Collect the requested extensions. There is normally a single header
    // but several are tolerated.
    List<Extension> requestedExtensions = new ArrayList<>();
    Enumeration<String> extensionHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
    while (extensionHeaders.hasMoreElements()) {
        Util.parseExtensionHeader(requestedExtensions, extensionHeaders.nextElement());
    }

    // Phase 1: let the configurator pick from what is installed. The
    // default only filters out unsupported extensions; custom
    // configurators may do more.
    List<Extension> availableExtensions;
    if (sec.getExtensions().isEmpty()) {
        availableExtensions = Constants.INSTALLED_EXTENSIONS;
    } else {
        // Endpoint specific extensions take precedence over the defaults
        availableExtensions = new ArrayList<>();
        availableExtensions.addAll(sec.getExtensions());
        availableExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
    }
    List<Extension> phaseOneExtensions =
            sec.getConfigurator().getNegotiatedExtensions(availableExtensions, requestedExtensions);

    // Phase 2: create the transformations. An extension is dropped here if
    // the client asked for a configuration the server cannot provide.
    List<Transformation> transformations = createTransformations(phaseOneExtensions);
    List<Extension> phaseTwoExtensions;
    if (transformations.isEmpty()) {
        phaseTwoExtensions = Collections.emptyList();
    } else {
        phaseTwoExtensions = new ArrayList<>(transformations.size());
        for (Transformation created : transformations) {
            phaseTwoExtensions.add(created.getExtensionResponse());
        }
    }

    // Chain the transformations and build the comma separated value of the
    // extensions response header at the same time
    Transformation pipeline = null;
    StringBuilder extensionsHeaderValue = new StringBuilder();
    for (int i = 0; i < transformations.size(); i++) {
        Transformation current = transformations.get(i);
        if (i > 0) {
            extensionsHeaderValue.append(',');
        }
        append(extensionsHeaderValue, current.getExtensionResponse());
        if (pipeline == null) {
            pipeline = current;
        } else {
            pipeline.setNext(current);
        }
    }

    // With the complete pipeline available, check the RSV bit usage
    if (pipeline != null && !pipeline.validateRsvBits(0)) {
        throw new ServletException(sm.getString("upgradeUtil.incompatibleRsv"));
    }

    // Everything checked out: write the handshake response headers
    resp.setHeader(Constants.UPGRADE_HEADER_NAME, Constants.UPGRADE_HEADER_VALUE);
    resp.setHeader(Constants.CONNECTION_HEADER_NAME, Constants.CONNECTION_HEADER_VALUE);
    resp.setHeader(HandshakeResponse.SEC_WEBSOCKET_ACCEPT, getWebSocketAccept(wsKey));
    if (subProtocol != null && !subProtocol.isEmpty()) {
        // RFC6455 4.2.2 explicitly states "" is not valid here
        resp.setHeader(Constants.WS_PROTOCOL_HEADER_NAME, subProtocol);
    }
    if (!transformations.isEmpty()) {
        resp.setHeader(Constants.WS_EXTENSIONS_HEADER_NAME, extensionsHeaderValue.toString());
    }

    // Endpoint classes that do not extend Endpoint need a method mapping in
    // the user properties; create it here when it is not already present
    boolean extendsEndpoint = Endpoint.class.isAssignableFrom(sec.getEndpointClass());
    if (!extendsEndpoint && sec.getUserProperties()
            .get(org.apache.tomcat.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY) == null) {
        try {
            PojoMethodMapping mapping = new PojoMethodMapping(sec.getEndpointClass(), sec.getDecoders(),
                    sec.getPath(), sc.getInstanceManager(Thread.currentThread().getContextClassLoader()));
            // Only store the mapping when it maps at least one method
            if (mapping.getOnClose() != null || mapping.getOnOpen() != null || mapping.getOnError() != null
                    || mapping.hasMessageHandlers()) {
                sec.getUserProperties().put(org.apache.tomcat.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY,
                        mapping);
            }
        } catch (DeploymentException e) {
            throw new ServletException(sm.getString("upgradeUtil.pojoMapFail", sec.getEndpointClass().getName()),
                    e);
        }
    }

    // Give the configurator the opportunity to adjust the handshake
    WsPerSessionServerEndpointConfig perSessionConfig = new WsPerSessionServerEndpointConfig(sec);
    WsHandshakeRequest handshakeRequest = new WsHandshakeRequest(req, pathParams);
    WsHandshakeResponse handshakeResponse = new WsHandshakeResponse();
    sec.getConfigurator().modifyHandshake(perSessionConfig, handshakeRequest, handshakeResponse);
    handshakeRequest.finished();

    // Copy any headers placed on the handshake response to the real response
    for (Entry<String, List<String>> header : handshakeResponse.getHeaders().entrySet()) {
        for (String value : header.getValue()) {
            resp.addHeader(header.getKey(), value);
        }
    }

    // Hand the connection over to the WebSocket upgrade handler
    WsHttpUpgradeHandler upgradeHandler = req.upgrade(WsHttpUpgradeHandler.class);
    upgradeHandler.preInit(perSessionConfig, sc, handshakeRequest, phaseTwoExtensions, subProtocol, pipeline,
            pathParams, req.isSecure());
}
use of jakarta.websocket.Extension in project tomcat by apache.
the class UpgradeUtil method createTransformations.
/**
 * Creates the transformations for the negotiated extensions.
 * <p>
 * Extensions sharing a name are grouped, in the order the names were first
 * seen, and the parameter lists of each group are offered together to the
 * {@link TransformationFactory}. Groups for which the factory returns
 * {@code null} are omitted from the result.
 *
 * @param negotiatedExtensions the extensions to create transformations for
 *
 * @return the transformations that could be created, one per extension name
 *         at most
 */
private static List<Transformation> createTransformations(List<Extension> negotiatedExtensions) {
    TransformationFactory factory = TransformationFactory.getInstance();

    // Upper bound on the size: the result will likely be smaller
    List<Transformation> created = new ArrayList<>(negotiatedExtensions.size());

    // Insertion ordered so the result follows the order of the request
    LinkedHashMap<String, List<List<Extension.Parameter>>> preferencesByName = new LinkedHashMap<>();
    for (Extension negotiated : negotiatedExtensions) {
        preferencesByName.computeIfAbsent(negotiated.getName(), name -> new ArrayList<>())
                .add(negotiated.getParameters());
    }

    for (Map.Entry<String, List<List<Extension.Parameter>>> group : preferencesByName.entrySet()) {
        Transformation candidate = factory.create(group.getKey(), group.getValue(), true);
        if (candidate == null) {
            continue;
        }
        created.add(candidate);
    }
    return created;
}
use of jakarta.websocket.Extension in project tomcat by apache.
the class DefaultServerEndpointConfigurator method getNegotiatedExtensions.
/**
 * Returns the requested extensions whose name matches the name of an
 * installed extension, in the order they were requested.
 *
 * @param installed the extensions that are available
 * @param requested the extensions the client asked for
 *
 * @return a new, modifiable list holding the matching requested extensions
 */
@Override
public List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
    // Matching is by name only
    Set<String> supportedNames = new HashSet<>();
    for (Extension available : installed) {
        supportedNames.add(available.getName());
    }

    // Start from everything requested and drop what is not supported
    List<Extension> negotiated = new ArrayList<>(requested);
    negotiated.removeIf(candidate -> !supportedNames.contains(candidate.getName()));
    return negotiated;
}
use of jakarta.websocket.Extension in project tomcat by apache.
the class TestWsWebSocketContainer method doTestPerMessageDeflateClient.
/**
 * Starts an echo server, connects a client that requests the per-message
 * deflate extension, sends {@code msg} {@code count} times and asserts that
 * the handler's latch reaches zero within ten seconds.
 *
 * @param msg   the text message to send
 * @param count the number of times the message is sent
 *
 * @throws Exception if server start-up, connecting or sending fails
 */
private void doTestPerMessageDeflateClient(String msg, int count) throws Exception {
    // Server side: root context (no file system docBase required) with the
    // echo endpoints and a default servlet
    Tomcat tomcat = getTomcatInstance();
    Context rootContext = tomcat.addContext("", null);
    rootContext.addApplicationListener(TesterEchoServer.Config.class.getName());
    Tomcat.addServlet(rootContext, "default", new DefaultServlet());
    rootContext.addServletMappingDecoded("/", "default");
    tomcat.start();

    // Client side: request per-message deflate
    List<Extension> requestedExtensions = new ArrayList<>(1);
    requestedExtensions.add(new WsExtension(PerMessageDeflate.NAME));
    ClientEndpointConfig clientConfig =
            ClientEndpointConfig.Builder.create().extensions(requestedExtensions).build();

    WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
    URI echoUri = new URI("ws://" + getHostName() + ":" + getPort() + TesterEchoServer.Config.PATH_ASYNC);
    Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientConfig, echoUri);

    // The handler is created with a latch sized to the number of messages sent
    BasicText handler = new BasicText(new CountDownLatch(count), msg);
    wsSession.addMessageHandler(handler);

    int remaining = count;
    while (remaining > 0) {
        wsSession.getBasicRemote().sendText(msg);
        remaining--;
    }

    Assert.assertTrue(handler.getLatch().await(10, TimeUnit.SECONDS));

    ((WsWebSocketContainer) wsContainer).destroy();
}
Aggregations