Usage of com.disney.groovity.servlet.cors.CORSProcessor in the Groovity project by Disney:
the doFilter method of class WebSocketAuthFilter.
/**
 * Authenticates and origin-checks WebSocket upgrade requests before passing them down the chain.
 *
 * <p>A request is treated as a socket request when it carries an {@code Upgrade: websocket}
 * header and its path is {@code ws/<name>} or {@code /ws/<name>}. The named socket is looked up
 * from the script view factory. If the socket has a verifier, the request must authenticate
 * (otherwise 401, with any challenge in {@code WWW-Authenticate}) and be authorized (otherwise
 * 403); an available principal is attached to the request passed to the chain. Cross-origin
 * upgrades are then rejected with 403 unless the socket's CORS processor sets an
 * {@code Access-Control-Allow-Origin} header. Every other request, including non-HTTP ones,
 * continues down the chain unchanged.
 *
 * @throws ServletException wrapping any exception raised while verifying or origin-checking a
 *     socket request
 */
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
    if (factory == null) {
        // Lazily resolved; presumably published to the context by GroovityServlet at startup.
        this.factory = (GroovityScriptViewFactory) req.getServletContext().getAttribute(GroovityServlet.SERVLET_CONTEXT_GROOVITY_VIEW_FACTORY);
    }
    if (!(req instanceof HttpServletRequest) || !(res instanceof HttpServletResponse)) {
        // Nothing to verify outside HTTP, but the chain must still continue.
        chain.doFilter(req, res);
        return;
    }
    HttpServletRequest hreq = (HttpServletRequest) req;
    HttpServletResponse hres = (HttpServletResponse) res;
    String socketName = null;
    if ("websocket".equalsIgnoreCase(hreq.getHeader(UPGRADE_HEADER))) {
        socketName = resolveSocketName(hreq);
    }
    if (socketName != null) {
        if (log.isLoggable(Level.FINE)) {
            // Report only whether credentials were sent; the header value itself is a secret.
            log.fine("VALIDATING WEB SOCKET REQUEST for socket " + socketName
                    + (hreq.getHeader("authorization") != null ? " with" : " without")
                    + " authorization header");
        }
        try {
            GroovityScriptView gsv = factory.getSocketByName(socketName);
            if (gsv != null) {
                if (gsv.getVerifier() != null) {
                    hreq = verify(gsv, hreq, hres);
                    if (hreq == null) {
                        // verify() has already sent a 401/403 response.
                        return;
                    }
                }
                if (!isOriginAllowed(gsv, hreq, hres)) {
                    hres.sendError(403, "Origin not allowed");
                    return;
                }
            }
        } catch (Exception e) {
            throw new ServletException(e);
        }
    }
    chain.doFilter(hreq, hres);
}

/**
 * Extracts the socket name from a {@code ws/<name>} or {@code /ws/<name>} request path.
 *
 * @return the socket name, or null when the path does not address a socket
 */
private String resolveSocketName(HttpServletRequest hreq) {
    String requestPath = hreq.getPathInfo();
    if (requestPath == null) {
        // when running as default servlet fall back
        requestPath = hreq.getServletPath();
    }
    if (requestPath == null) {
        return null;
    }
    if (requestPath.startsWith("ws/")) {
        return requestPath.substring(3);
    }
    if (requestPath.startsWith("/ws/")) {
        return requestPath.substring(4);
    }
    return null;
}

/**
 * Runs the socket's verifier against the request.
 *
 * @return the request to continue with (wrapped with the verified principal when one is
 *     available), or null after a 401/403 error response has been sent
 */
private HttpServletRequest verify(GroovityScriptView gsv, HttpServletRequest hreq, HttpServletResponse hres) throws Exception {
    VerifierResult vf = gsv.getVerifier().verify(new ServletAuthorizationRequest(hreq));
    if (vf.getAuthenticationInfo() != null) {
        hres.setHeader(AuthConstants.AUTHENTICATION_INFO, vf.getAuthenticationInfo());
    }
    if (!vf.isAuthenticated()) {
        if (vf.getChallenge() != null) {
            hres.setHeader(AuthConstants.WWW_AUTHENTICATE_HEADER, vf.getChallenge());
        }
        if (log.isLoggable(Level.FINE)) {
            log.fine("Verification failed 401 " + vf.getMessage() + ", challenge " + vf.getChallenge());
        }
        hres.sendError(401, vf.getMessage());
        return null;
    }
    if (!vf.isAuthorized()) {
        if (log.isLoggable(Level.FINE)) {
            log.fine("Verification failed 403 " + vf.getMessage() + ", challenge " + vf.getChallenge());
        }
        hres.sendError(403, vf.getMessage());
        return null;
    }
    HttpServletRequest verified = hreq;
    if (vf.getPrincipal() != null) {
        verified = new AuthenticatedRequestWrapper(hreq, vf.getPrincipal());
    }
    if (log.isLoggable(Level.FINE)) {
        log.fine("Verification succeeded for " + vf.getPrincipal());
    }
    return verified;
}

/**
 * Decides whether the upgrade request's Origin is acceptable: same-origin requests always pass,
 * otherwise the socket's CORS processor must emit an Access-Control-Allow-Origin header.
 * This method does not send the rejection response; the caller does.
 */
private boolean isOriginAllowed(GroovityScriptView gsv, HttpServletRequest hreq, HttpServletResponse hres) throws Exception {
    String origin = hreq.getHeader(ORIGIN_HEADER);
    String hostHeader = hreq.getHeader(HOST_HEADER);
    // A missing Host header cannot match any origin; fall through to the CORS check instead of
    // failing on String.concat(null).
    String host = null;
    if (hostHeader != null) {
        host = (hreq.isSecure() ? "https://" : "http://").concat(hostHeader);
    }
    if (host != null && host.equals(origin)) {
        // default CORS behavior, allow same-origin requests
        if (log.isLoggable(Level.FINE)) {
            log.fine("WebSocket Origin " + origin + " matches host " + host);
        }
        return true;
    }
    AtomicBoolean allowed = new AtomicBoolean(false);
    CORSProcessor cp = gsv.getCORSProcessor();
    if (cp != null) {
        // Watch for the allow-origin header whether it is set or added. Header names are
        // case-insensitive, so compare them that way.
        cp.process(hreq, new HttpServletResponseWrapper(hres) {
            @Override
            public void setHeader(String name, String value) {
                if (ACCESS_CONTROL_ALLOW_ORIGIN.equalsIgnoreCase(name)) {
                    allowed.set(true);
                }
                super.setHeader(name, value);
            }

            @Override
            public void addHeader(String name, String value) {
                if (ACCESS_CONTROL_ALLOW_ORIGIN.equalsIgnoreCase(name)) {
                    allowed.set(true);
                }
                super.addHeader(name, value);
            }
        });
    }
    if (!allowed.get()) {
        if (log.isLoggable(Level.FINE)) {
            log.fine("Disallowing websocket due to cors violation from " + origin + " to host " + host);
        }
        return false;
    }
    return true;
}
Aggregations