Search in sources:

Example 1 with OAuthSecurityContextImpl

Use of org.springframework.security.oauth.consumer.OAuthSecurityContextImpl in project spring-security-oauth by spring-projects.

Class OAuthConsumerContextFilter, method doFilter.

/**
 * Binds an {@link OAuthSecurityContextImpl} for the current request and runs the rest of the filter chain.
 * If the chain fails, it obtains the OAuth token the failure calls for and retries the chain.
 *
 * <p>Tokens returned by the remember-me services are split into access tokens and request tokens:
 * <ul>
 *   <li>access tokens are placed in the security context and in a request attribute;</li>
 *   <li>request tokens are kept aside for a pending authorization.</li>
 * </ul>
 * When the chain throws, {@code checkForResourceThatNeedsAuthorization} is asked which resource needs a token.
 * The filter then loops until that resource has an access token:
 * <ul>
 *   <li>with no usable request token, it fetches a new unauthorized request token, stores it and redirects
 *       the user to the authorization URL (ending this request);</li>
 *   <li>with a request token (and a verifier, for OAuth 1.0a resources), it exchanges it for an access
 *       token and re-runs the chain.</li>
 * </ul>
 * The thread-bound context is always cleared, and all tokens are handed back to the remember-me services
 * in the {@code finally} block.
 *
 * @param servletRequest the incoming request; cast to {@link HttpServletRequest}
 * @param servletResponse the outgoing response; cast to {@link HttpServletResponse}
 * @param chain the remaining filter chain, possibly invoked more than once
 * @throws IOException if writing the redirect fails
 * @throws ServletException a {@link ServletException} from the authorization handling, rethrown as-is when
 *         no {@link OAuthRequestFailedException} is found in its cause chain (other checked exceptions are
 *         wrapped in a {@link RuntimeException})
 */
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {
    HttpServletRequest request = (HttpServletRequest) servletRequest;
    HttpServletResponse response = (HttpServletResponse) servletResponse;
    OAuthSecurityContextImpl context = new OAuthSecurityContextImpl();
    context.setDetails(request);
    Map<String, OAuthConsumerToken> rememberedTokens = getRememberMeServices().loadRememberedTokens(request, response);
    // Keyed by resource id; access tokens are exposed to the chain, request tokens await authorization.
    Map<String, OAuthConsumerToken> accessTokens = new TreeMap<String, OAuthConsumerToken>();
    Map<String, OAuthConsumerToken> requestTokens = new TreeMap<String, OAuthConsumerToken>();
    if (rememberedTokens != null) {
        for (Map.Entry<String, OAuthConsumerToken> tokenEntry : rememberedTokens.entrySet()) {
            OAuthConsumerToken token = tokenEntry.getValue();
            if (token != null) {
                if (token.isAccessToken()) {
                    accessTokens.put(tokenEntry.getKey(), token);
                } else {
                    requestTokens.put(tokenEntry.getKey(), token);
                }
            }
        }
    }
    // The context holds a reference to accessTokens, so tokens added below are visible through it.
    context.setAccessTokens(accessTokens);
    OAuthSecurityContextHolder.setContext(context);
    if (LOG.isDebugEnabled()) {
        LOG.debug("Storing access tokens in request attribute '" + getAccessTokensRequestAttribute() + "'.");
    }
    try {
        try {
            request.setAttribute(getAccessTokensRequestAttribute(), new ArrayList<OAuthConsumerToken>(accessTokens.values()));
            chain.doFilter(request, response);
        } catch (Exception e) {
            // Any failure from the chain enters the authorization flow.
            // NOTE(review): what checkForResourceThatNeedsAuthorization does with unrelated exceptions is not visible here.
            try {
                ProtectedResourceDetails resourceThatNeedsAuthorization = checkForResourceThatNeedsAuthorization(e);
                String neededResourceId = resourceThatNeedsAuthorization.getId();
                // Keep acquiring tokens and retrying until the resource named by the latest failure has an access token.
                while (!accessTokens.containsKey(neededResourceId)) {
                    // Prefer a request token remembered for this resource; otherwise ask the token services.
                    OAuthConsumerToken token = requestTokens.remove(neededResourceId);
                    if (token == null) {
                        token = getTokenServices().getToken(neededResourceId);
                    }
                    String verifier = request.getParameter(OAuthProviderParameter.oauth_verifier.toString());
                    // A new request token is needed when there is no token at all, or when there is only a
                    // request token and either the resource does not use 1.0a or no verifier was supplied.
                    if (token == null || (!token.isAccessToken() && (!resourceThatNeedsAuthorization.isUse10a() || verifier == null))) {
                        // if there's a request token, but no verifier, we'll assume that a previous oauth request failed and we need to get a new request token.
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Obtaining request token for resource: " + neededResourceId);
                        }
                        // obtain authorization.
                        String callbackURL = response.encodeRedirectURL(getCallbackURL(request));
                        token = getConsumerSupport().getUnauthorizedRequestToken(neededResourceId, callbackURL);
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Request token obtained for resource " + neededResourceId + ": " + token);
                        }
                        // okay, we've got a request token, now we need to authorize it.
                        requestTokens.put(neededResourceId, token);
                        getTokenServices().storeToken(neededResourceId, token);
                        String redirect = getUserAuthorizationRedirectURL(resourceThatNeedsAuthorization, token, callbackURL);
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Redirecting request to " + redirect + " for user authorization of the request token for resource " + neededResourceId + ".");
                        }
                        request.setAttribute("org.springframework.security.oauth.consumer.AccessTokenRequiredException", e);
                        this.redirectStrategy.sendRedirect(request, response, redirect);
                        // The redirect ends this request; the finally block still remembers requestTokens.
                        return;
                    } else if (!token.isAccessToken()) {
                        // we have a presumably authorized request token, let's try to get an access token with it.
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Obtaining access token for resource: " + neededResourceId);
                        }
                        // authorize the request token and store it.
                        try {
                            token = getConsumerSupport().getAccessToken(token, verifier);
                        } finally {
                            // Drop the stored request token whether or not the exchange succeeded.
                            getTokenServices().removeToken(neededResourceId);
                        }
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Access token " + token + " obtained for resource " + neededResourceId + ". Now storing and using.");
                        }
                        getTokenServices().storeToken(neededResourceId, token);
                    }
                    // token is now an access token: either one already held by the token services or one just obtained.
                    accessTokens.put(neededResourceId, token);
                    try {
                        // try again
                        if (!response.isCommitted()) {
                            request.setAttribute(getAccessTokensRequestAttribute(), new ArrayList<OAuthConsumerToken>(accessTokens.values()));
                            chain.doFilter(request, response);
                        } else {
                            // dang. what do we do now?
                            // NOTE(review): this code is OAuth 1.0; the "OAuth2" wording in the message looks like a copy-paste slip.
                            throw new IllegalStateException("Unable to reprocess filter chain with needed OAuth2 resources because the response is already committed.");
                        }
                    } catch (Exception e1) {
                        // Look up which resource the retry failed on. If it is a resource that already has an access
                        // token (e.g. the one just added), the loop condition is satisfied and the loop exits.
                        // NOTE(review): in that case e1 is not rethrown here and is effectively swallowed — confirm intended.
                        resourceThatNeedsAuthorization = checkForResourceThatNeedsAuthorization(e1);
                        neededResourceId = resourceThatNeedsAuthorization.getId();
                    }
                }
            } catch (OAuthRequestFailedException eo) {
                fail(request, response, eo);
            } catch (Exception ex) {
                // The OAuth failure may be wrapped; search the cause chain before giving up.
                Throwable[] causeChain = getThrowableAnalyzer().determineCauseChain(ex);
                OAuthRequestFailedException rfe = (OAuthRequestFailedException) getThrowableAnalyzer().getFirstThrowableOfType(OAuthRequestFailedException.class, causeChain);
                if (rfe != null) {
                    fail(request, response, rfe);
                } else {
                    // Rethrow ServletExceptions and RuntimeExceptions as-is
                    if (ex instanceof ServletException) {
                        throw (ServletException) ex;
                    } else if (ex instanceof RuntimeException) {
                        throw (RuntimeException) ex;
                    }
                    // Wrap other Exceptions. These are not expected to happen
                    throw new RuntimeException(ex);
                }
            }
        }
    } finally {
        // Always unbind the context and persist both token maps, including after a redirect or an exception.
        OAuthSecurityContextHolder.setContext(null);
        HashMap<String, OAuthConsumerToken> tokensToRemember = new HashMap<String, OAuthConsumerToken>();
        tokensToRemember.putAll(requestTokens);
        // Access tokens are added last, so they win on any key present in both maps.
        tokensToRemember.putAll(accessTokens);
        getRememberMeServices().rememberTokens(tokensToRemember, request, response);
    }
}
Also used: HashMap(java.util.HashMap), ArrayList(java.util.ArrayList), HttpServletResponse(javax.servlet.http.HttpServletResponse), TreeMap(java.util.TreeMap), OAuthRequestFailedException(org.springframework.security.oauth.consumer.OAuthRequestFailedException), ServletException(javax.servlet.ServletException), AccessTokenRequiredException(org.springframework.security.oauth.consumer.AccessTokenRequiredException), IOException(java.io.IOException), UnsupportedEncodingException(java.io.UnsupportedEncodingException), OAuthConsumerToken(org.springframework.security.oauth.consumer.OAuthConsumerToken), HttpServletRequest(javax.servlet.http.HttpServletRequest), OAuthSecurityContextImpl(org.springframework.security.oauth.consumer.OAuthSecurityContextImpl), Map(java.util.Map), ProtectedResourceDetails(org.springframework.security.oauth.consumer.ProtectedResourceDetails)

Example 2 with OAuthSecurityContextImpl

Use of org.springframework.security.oauth.consumer.OAuthSecurityContextImpl in project spring-security-oauth by spring-projects.

Class OAuthClientHttpRequestFactory, method createRequest.

/**
 * Creates a request for {@code uri} that carries OAuth credentials for this factory's resource.
 *
 * <p>The access token is taken from the thread's {@link OAuthSecurityContext}. When no context is bound,
 * an empty one is used, so the request is signed without a token. The credentials are placed in one of
 * two ways:
 * <ul>
 *   <li>if the resource does not accept an Authorization header, the OAuth query string is written onto
 *       the URI, replacing anything after the first {@code '?'};</li>
 *   <li>otherwise they are sent in an {@code Authorization} header.</li>
 * </ul>
 * Any additional request headers configured on the resource are then appended.
 *
 * @param uri the target URI
 * @param httpMethod the HTTP method; its name takes part in the OAuth signature
 * @return the request produced by the delegate factory, with OAuth data attached
 * @throws IOException if the URI cannot be converted to a URL or the delegate fails
 */
public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
    OAuthSecurityContext securityContext = OAuthSecurityContextHolder.getContext();
    if (securityContext == null) {
        securityContext = new OAuthSecurityContextImpl();
    }

    Map<String, OAuthConsumerToken> tokensByResource = securityContext.getAccessTokens();
    OAuthConsumerToken token = null;
    if (tokensByResource != null) {
        token = tokensByResource.get(this.resource.getId());
    }

    final boolean sendAsHeader = this.resource.isAcceptsAuthorizationHeader();
    URI target = uri;
    if (!sendAsHeader) {
        String oauthQuery = this.support.getOAuthQueryString(this.resource, token, uri.toURL(), httpMethod.name(), this.additionalOAuthParameters);
        String rawUri = String.valueOf(uri);
        int queryStart = rawUri.indexOf('?');
        // Everything from the first '?' onward is replaced by the OAuth query string.
        // NOTE(review): presumably getOAuthQueryString re-emits the original query parameters; verify.
        // NOTE(review): a '#fragment' on a URI without '?' ends up before the appended query; confirm acceptable.
        String withoutQuery = (queryStart < 0) ? rawUri : rawUri.substring(0, queryStart);
        target = URI.create(withoutQuery + "?" + oauthQuery);
    }

    ClientHttpRequest request = delegate.createRequest(target, httpMethod);
    if (sendAsHeader) {
        // target is still the caller's uri on this path.
        String headerValue = this.support.getAuthorizationHeader(this.resource, token, target.toURL(), httpMethod.name(), this.additionalOAuthParameters);
        request.getHeaders().add("Authorization", headerValue);
    }

    Map<String, String> extraHeaders = this.resource.getAdditionalRequestHeaders();
    if (extraHeaders == null) {
        return request;
    }
    for (Map.Entry<String, String> extra : extraHeaders.entrySet()) {
        String headerName = extra.getKey();
        String headerVal = extra.getValue();
        request.getHeaders().add(headerName, headerVal);
    }
    return request;
}
Also used: OAuthSecurityContextImpl(org.springframework.security.oauth.consumer.OAuthSecurityContextImpl), OAuthSecurityContext(org.springframework.security.oauth.consumer.OAuthSecurityContext), ClientHttpRequest(org.springframework.http.client.ClientHttpRequest), HashMap(java.util.HashMap), Map(java.util.Map), OAuthConsumerToken(org.springframework.security.oauth.consumer.OAuthConsumerToken)

Aggregations

HashMap (java.util.HashMap)2 Map (java.util.Map)2 OAuthConsumerToken (org.springframework.security.oauth.consumer.OAuthConsumerToken)2 OAuthSecurityContextImpl (org.springframework.security.oauth.consumer.OAuthSecurityContextImpl)2 IOException (java.io.IOException)1 UnsupportedEncodingException (java.io.UnsupportedEncodingException)1 ArrayList (java.util.ArrayList)1 TreeMap (java.util.TreeMap)1 ServletException (javax.servlet.ServletException)1 HttpServletRequest (javax.servlet.http.HttpServletRequest)1 HttpServletResponse (javax.servlet.http.HttpServletResponse)1 ClientHttpRequest (org.springframework.http.client.ClientHttpRequest)1 AccessTokenRequiredException (org.springframework.security.oauth.consumer.AccessTokenRequiredException)1 OAuthRequestFailedException (org.springframework.security.oauth.consumer.OAuthRequestFailedException)1 OAuthSecurityContext (org.springframework.security.oauth.consumer.OAuthSecurityContext)1 ProtectedResourceDetails (org.springframework.security.oauth.consumer.ProtectedResourceDetails)1