Usage of com.nike.riposte.server.http.filter.RequestAndResponseFilter in project riposte by Nike-Inc.
Class HttpChannelInitializerTest, method constructor_works_with_valid_args:
@Test
public void constructor_works_with_valid_args() {
    // given: a distinct mock or value for every constructor argument, so each one can be traced to its field
    SslContext sslContextMock = mock(SslContext.class);
    int maxRequestSize = 42;
    Collection<Endpoint<?>> endpointCollection = Arrays.asList(getMockEndpoint("/some/path", HttpMethod.GET));

    RequestAndResponseFilter preSecurityFilter = mock(RequestAndResponseFilter.class);
    doReturn(true).when(preSecurityFilter).shouldExecuteBeforeSecurityValidation();
    RequestAndResponseFilter postSecurityFilter = mock(RequestAndResponseFilter.class);
    doReturn(false).when(postSecurityFilter).shouldExecuteBeforeSecurityValidation();
    List<RequestAndResponseFilter> allFilters = Arrays.asList(preSecurityFilter, postSecurityFilter);

    Executor longRunningExecutorMock = mock(Executor.class);
    RiposteErrorHandler errorHandlerMock = mock(RiposteErrorHandler.class);
    RiposteUnhandledErrorHandler unhandledErrorHandlerMock = mock(RiposteUnhandledErrorHandler.class);
    RequestValidator requestValidatorMock = mock(RequestValidator.class);
    ObjectMapper deserializerMock = mock(ObjectMapper.class);
    ResponseSender responseSenderMock = mock(ResponseSender.class);
    MetricsListener metricsListenerMock = mock(MetricsListener.class);
    long futureTimeoutMillis = 4242L;
    AccessLogger accessLoggerMock = mock(AccessLogger.class);
    // Mocking the raw List class and assigning to a parameterized type is an unchecked conversion.
    @SuppressWarnings("unchecked")
    List<PipelineCreateHook> pipelineHooksMock = mock(List.class);
    RequestSecurityValidator securityValidatorMock = mock(RequestSecurityValidator.class);
    long idleTimeoutMillis = 121000;
    long proxyConnectTimeoutMillis = 4200;
    long incompleteCallTimeoutMillis = 1234;
    int maxOpenChannels = 1000;
    boolean lifecycleLoggingEnabled = true;
    @SuppressWarnings("unchecked")
    List<String> userIdHeaderKeysMock = mock(List.class);

    // when
    HttpChannelInitializer initializer = new HttpChannelInitializer(
        sslContextMock,
        maxRequestSize,
        endpointCollection,
        allFilters,
        longRunningExecutorMock,
        errorHandlerMock,
        unhandledErrorHandlerMock,
        requestValidatorMock,
        deserializerMock,
        responseSenderMock,
        metricsListenerMock,
        futureTimeoutMillis,
        accessLoggerMock,
        pipelineHooksMock,
        securityValidatorMock,
        idleTimeoutMillis,
        proxyConnectTimeoutMillis,
        incompleteCallTimeoutMillis,
        maxOpenChannels,
        lifecycleLoggingEnabled,
        userIdHeaderKeysMock
    );

    // then: plain arguments are stored on the matching fields
    assertThat(extractField(initializer, "sslCtx"), is(sslContextMock));
    assertThat(extractField(initializer, "maxRequestSizeInBytes"), is(maxRequestSize));
    assertThat(extractField(initializer, "endpoints"), is(endpointCollection));
    assertThat(extractField(initializer, "longRunningTaskExecutor"), is(longRunningExecutorMock));
    assertThat(extractField(initializer, "riposteErrorHandler"), is(errorHandlerMock));
    assertThat(extractField(initializer, "riposteUnhandledErrorHandler"), is(unhandledErrorHandlerMock));
    assertThat(extractField(initializer, "validationService"), is(requestValidatorMock));
    assertThat(extractField(initializer, "requestContentDeserializer"), is(deserializerMock));
    assertThat(extractField(initializer, "responseSender"), is(responseSenderMock));
    assertThat(extractField(initializer, "metricsListener"), is(metricsListenerMock));
    assertThat(extractField(initializer, "defaultCompletableFutureTimeoutMillis"), is(futureTimeoutMillis));
    assertThat(extractField(initializer, "accessLogger"), is(accessLoggerMock));
    assertThat(extractField(initializer, "pipelineCreateHooks"), is(pipelineHooksMock));
    assertThat(extractField(initializer, "requestSecurityValidator"), is(securityValidatorMock));
    assertThat(extractField(initializer, "workerChannelIdleTimeoutMillis"), is(idleTimeoutMillis));
    assertThat(extractField(initializer, "incompleteHttpCallTimeoutMillis"), is(incompleteCallTimeoutMillis));
    assertThat(extractField(initializer, "maxOpenChannelsThreshold"), is(maxOpenChannels));
    assertThat(extractField(initializer, "debugChannelLifecycleLoggingEnabled"), is(lifecycleLoggingEnabled));
    assertThat(extractField(initializer, "userIdHeaderKeys"), is(userIdHeaderKeysMock));

    // then: the proxy-router HTTP client picks up the idle timeout, connect timeout (as int) and logging flag
    StreamingAsyncHttpClient proxyClient = extractField(initializer, "streamingAsyncHttpClientForProxyRouterEndpoints");
    assertThat(extractField(proxyClient, "idleChannelTimeoutMillis"), is(idleTimeoutMillis));
    assertThat(extractField(proxyClient, "downstreamConnectionTimeoutMillis"), is((int) proxyConnectTimeoutMillis));
    assertThat(extractField(proxyClient, "debugChannelLifecycleLoggingEnabled"), is(lifecycleLoggingEnabled));

    // then: each filter ends up only in the request handler matching its shouldExecuteBeforeSecurityValidation()
    RequestFilterHandler preSecurityHandler = extractField(initializer, "beforeSecurityRequestFilterHandler");
    assertThat(extractField(preSecurityHandler, "filters"), is(Collections.singletonList(preSecurityFilter)));
    RequestFilterHandler postSecurityHandler = extractField(initializer, "afterSecurityRequestFilterHandler");
    assertThat(extractField(postSecurityHandler, "filters"), is(Collections.singletonList(postSecurityFilter)));

    // then: the response handler holds all filters in the reverse of the order they were supplied
    ResponseFilterHandler responseHandler = extractField(initializer, "cachedResponseFilterHandler");
    List<RequestAndResponseFilter> expectedResponseOrder = Arrays.asList(postSecurityFilter, preSecurityFilter);
    assertThat(extractField(responseHandler, "filtersInResponseProcessingOrder"), is(expectedResponseOrder));
}
Usage of com.nike.riposte.server.http.filter.RequestAndResponseFilter in project riposte by Nike-Inc.
Class HttpChannelInitializerTest, method initChannel_adds_before_and_after_RequestFilterHandler_appropriately_before_and_after_security_filter:
@Test
public void initChannel_adds_before_and_after_RequestFilterHandler_appropriately_before_and_after_security_filter() {
    // given: one filter that runs before security validation and one that runs after it
    RequestAndResponseFilter preSecurityFilter = mock(RequestAndResponseFilter.class);
    doReturn(true).when(preSecurityFilter).shouldExecuteBeforeSecurityValidation();
    RequestAndResponseFilter postSecurityFilter = mock(RequestAndResponseFilter.class);
    doReturn(false).when(postSecurityFilter).shouldExecuteBeforeSecurityValidation();
    List<RequestAndResponseFilter> configuredFilters = Arrays.asList(preSecurityFilter, postSecurityFilter);
    HttpChannelInitializer initializer = basicHttpChannelInitializer(null, 0, 42, false, null, configuredFilters);

    // when
    initializer.initChannel(socketChannelMock);

    // then: collect every handler appended to the pipeline, in the order it was added
    ArgumentCaptor<ChannelHandler> addedHandlerCaptor = ArgumentCaptor.forClass(ChannelHandler.class);
    verify(channelPipelineMock, atLeastOnce()).addLast(anyString(), addedHandlerCaptor.capture());
    List<ChannelHandler> addedHandlers = addedHandlerCaptor.getAllValues();

    // Two RequestFilterHandlers are expected. NOTE(review): the trailing `true` presumably makes
    // findChannelHandler return the last match instead of the first - confirm against the helper.
    Pair<Integer, RequestInfoSetterHandler> requestInfoSetter = findChannelHandler(addedHandlers, RequestInfoSetterHandler.class);
    Pair<Integer, RequestFilterHandler> preSecurityFilterHandler = findChannelHandler(addedHandlers, RequestFilterHandler.class);
    Pair<Integer, RequestFilterHandler> postSecurityFilterHandler = findChannelHandler(addedHandlers, RequestFilterHandler.class, true);
    Pair<Integer, RoutingHandler> routing = findChannelHandler(addedHandlers, RoutingHandler.class);
    Pair<Integer, SecurityValidationHandler> securityValidation = findChannelHandler(addedHandlers, SecurityValidationHandler.class);
    Pair<Integer, RequestContentDeserializerHandler> contentDeserializer = findChannelHandler(addedHandlers, RequestContentDeserializerHandler.class);

    assertThat(requestInfoSetter, notNullValue());
    assertThat(preSecurityFilterHandler, notNullValue());
    assertThat(routing, notNullValue());
    assertThat(postSecurityFilterHandler, notNullValue());
    assertThat(securityValidation, notNullValue());
    assertThat(contentDeserializer, notNullValue());

    // The before-security filter handler must sit after the RequestInfoSetterHandler and before routing...
    Assertions.assertThat(preSecurityFilterHandler.getLeft())
              .isGreaterThan(requestInfoSetter.getLeft())
              .isLessThan(routing.getLeft());
    // ...and the after-security one after security validation but before content deserialization.
    Assertions.assertThat(postSecurityFilterHandler.getLeft())
              .isGreaterThan(securityValidation.getLeft())
              .isLessThan(contentDeserializer.getLeft());

    // and then: the handlers placed in the pipeline are the very instances cached on the initializer
    RequestFilterHandler cachedPreSecurityHandler = extractField(initializer, "beforeSecurityRequestFilterHandler");
    Assertions.assertThat(preSecurityFilterHandler.getRight()).isSameAs(cachedPreSecurityHandler);
    RequestFilterHandler cachedPostSecurityHandler = extractField(initializer, "afterSecurityRequestFilterHandler");
    Assertions.assertThat(postSecurityFilterHandler.getRight()).isSameAs(cachedPostSecurityHandler);
}
Usage of com.nike.riposte.server.http.filter.RequestAndResponseFilter in project riposte by Nike-Inc.
Class HttpChannelInitializerTest, method constructor_handles_empty_after_security_request_handlers:
@Test
public void constructor_handles_empty_after_security_request_handlers() {
    // given: a single filter that runs before security validation, leaving nothing to run after it
    RequestAndResponseFilter preSecurityFilter = mock(RequestAndResponseFilter.class);
    doReturn(true).when(preSecurityFilter).shouldExecuteBeforeSecurityValidation();
    List<RequestAndResponseFilter> onlyPreSecurityFilters = Arrays.asList(preSecurityFilter);

    // when: collaborators this test does not care about are left null
    HttpChannelInitializer initializer = new HttpChannelInitializer(
        null,                                          // sslCtx
        42,                                            // maxRequestSizeInBytes
        Arrays.asList(getMockEndpoint("/some/path")),  // endpoints
        onlyPreSecurityFilters,                        // request/response filters
        null,                                          // longRunningTaskExecutor
        mock(RiposteErrorHandler.class),
        mock(RiposteUnhandledErrorHandler.class),
        null,                                          // validationService
        null,                                          // requestContentDeserializer
        mock(ResponseSender.class),
        null,                                          // metricsListener
        4242L,                                         // defaultCompletableFutureTimeoutMillis
        null,                                          // accessLogger
        null,                                          // pipelineCreateHooks
        null,                                          // requestSecurityValidator
        121,                                           // workerChannelIdleTimeoutMillis
        42,                                            // proxyRouterConnectTimeoutMillis
        321,                                           // incompleteHttpCallTimeoutMillis
        100,                                           // maxOpenChannelsThreshold
        false,                                         // debugChannelLifecycleLoggingEnabled
        null                                           // userIdHeaderKeys
    );

    // then: the lone filter lands in the before-security handler...
    RequestFilterHandler preSecurityHandler = extractField(initializer, "beforeSecurityRequestFilterHandler");
    assertThat(extractField(preSecurityHandler, "filters"), is(Collections.singletonList(preSecurityFilter)));
    // ...no after-security handler is created at all...
    assertThat(extractField(initializer, "afterSecurityRequestFilterHandler"), nullValue());
    // ...and with only one filter, response-processing order equals the supplied order.
    ResponseFilterHandler responseHandler = extractField(initializer, "cachedResponseFilterHandler");
    assertThat(extractField(responseHandler, "filtersInResponseProcessingOrder"), is(onlyPreSecurityFilters));
}
Usage of com.nike.riposte.server.http.filter.RequestAndResponseFilter in project riposte by Nike-Inc.
Class RequestFilterHandler, method handleFilterLogic:
/**
 * Passes the current request through every filter in {@code filters}, in iteration order, and reports whether
 * the pipeline should keep going.
 *
 * <p>A filter whose {@code isShortCircuitRequestFilter()} returns false is run through {@code normalFilterCall}.
 * Otherwise it is run through {@code shortCircuitFilterCall}, which may also supply a {@link ResponseInfo}. When
 * such a response is present and not chunked, the processing state receives the latest RequestInfo and that
 * response, a {@link LastOutboundMessageSendFullResponseInfo} event is fired, and processing stops right away.
 *
 * <p>Any {@code Throwable} raised while handling a single filter is logged, and that filter is skipped. This
 * includes the {@link IllegalStateException} raised for a chunked short-circuit response, so such a response is
 * effectively ignored. A RequestInfo returned by that same filter call is still kept.
 *
 * @param ctx the channel context; used to look up the {@link HttpProcessingState} and to fire the short-circuit
 *            event
 * @param msg not read by this method
 * @param normalFilterCall invokes a non-short-circuiting filter; its result is passed to
 *                         {@code requestInfoUpdateNoNulls}
 * @param shortCircuitFilterCall invokes a short-circuit-capable filter; a {@code null} result is ignored
 * @return {@code DO_NOT_FIRE_CONTINUE_EVENT} if a filter short circuited, otherwise {@code CONTINUE}
 */
protected PipelineContinuationBehavior handleFilterLogic(ChannelHandlerContext ctx, Object msg, BiFunction<RequestAndResponseFilter, RequestInfo, RequestInfo> normalFilterCall, BiFunction<RequestAndResponseFilter, RequestInfo, Pair<RequestInfo, Optional<ResponseInfo<?>>>> shortCircuitFilterCall) {
    HttpProcessingState processingState = ChannelAttributes.getHttpProcessingStateForChannel(ctx).get();
    RequestInfo<?> latestRequestInfo = processingState.getRequestInfo();

    for (RequestAndResponseFilter filter : filters) {
        try {
            if (!filter.isShortCircuitRequestFilter()) {
                // Plain filter: it can only adjust the RequestInfo.
                latestRequestInfo = requestInfoUpdateNoNulls(latestRequestInfo, normalFilterCall.apply(filter, latestRequestInfo));
                continue;
            }

            Pair<RequestInfo, Optional<ResponseInfo<?>>> filterOutcome = shortCircuitFilterCall.apply(filter, latestRequestInfo);
            if (filterOutcome == null) {
                continue;
            }

            // Apply the filter's RequestInfo first, so it is kept even if the response below is rejected.
            latestRequestInfo = requestInfoUpdateNoNulls(latestRequestInfo, filterOutcome.getLeft());

            Optional<ResponseInfo<?>> maybeResponse = filterOutcome.getRight();
            ResponseInfo<?> shortCircuitResponse = (maybeResponse != null) ? maybeResponse.orElse(null) : null;
            if (shortCircuitResponse == null) {
                continue;
            }

            if (shortCircuitResponse.isChunkedResponse()) {
                // Only full responses may short circuit. This is thrown inside the try on purpose: the catch
                // below logs it and moves on to the next filter instead of short circuiting.
                throw new IllegalStateException("RequestAndResponseFilter should never return a " + "chunked ResponseInfo when short circuiting.");
            }

            processingState.setRequestInfo(latestRequestInfo);
            processingState.setResponseInfo(shortCircuitResponse);
            // Send the filter-provided response to the caller and stop this event here.
            ctx.fireChannelRead(LastOutboundMessageSendFullResponseInfo.INSTANCE);
            return PipelineContinuationBehavior.DO_NOT_FIRE_CONTINUE_EVENT;
        } catch (Throwable ex) {
            logger.error("An error occurred while processing a request filter. This error will be ignored and the " + "filtering/processing will continue normally, however this error should be fixed (filters should " + "never throw errors). filter_class={}", filter.getClass().getName(), ex);
        }
    }

    // No filter short circuited: publish the final RequestInfo and let the pipeline continue.
    processingState.setRequestInfo(latestRequestInfo);
    return PipelineContinuationBehavior.CONTINUE;
}
Usage of com.nike.riposte.server.http.filter.RequestAndResponseFilter in project riposte by Nike-Inc.
Class RequestFilterHandlerTest, method handleFilterLogic_does_not_short_circuit_if_responseInfo_is_chunked:
@DataProvider(value = {
    "true | 0 | true",
    "true | 0 | false",
    "true | 1 | true",
    "true | 1 | false",
    "false | 0 | true",
    "false | 0 | false",
    "false | 1 | true",
    "false | 1 | false"
}, splitBy = "\\|")
@Test
public void handleFilterLogic_does_not_short_circuit_if_responseInfo_is_chunked(boolean isFirstChunk, int shortCircuitingFilterIndex, boolean filterReturnsModifiedRequestInfo) {
    // given: one filter opts into short circuiting but hands back a *chunked* ResponseInfo
    HandleFilterLogicMethodCallArgs callArgs = new HandleFilterLogicMethodCallArgs(isFirstChunk);
    RequestAndResponseFilter offendingFilter = filtersList.get(shortCircuitingFilterIndex);
    doReturn(true).when(offendingFilter).isShortCircuitRequestFilter();

    RequestInfo<?> modifiedReqInfo = mock(RequestInfo.class);
    RequestInfo<?> reqInfoFromFilter;
    if (filterReturnsModifiedRequestInfo) {
        reqInfoFromFilter = modifiedReqInfo;
    } else {
        reqInfoFromFilter = null;
    }

    ResponseInfo<?> chunkedResponse = mock(ResponseInfo.class);
    doReturn(true).when(chunkedResponse).isChunkedResponse();
    // Both chunk variants return the same (RequestInfo, chunked response) pair.
    Pair<?, ?> filterOutcome = Pair.of(reqInfoFromFilter, Optional.of(chunkedResponse));
    doReturn(filterOutcome).when(offendingFilter).filterRequestFirstChunkWithOptionalShortCircuitResponse(any(), any());
    doReturn(filterOutcome).when(offendingFilter).filterRequestLastChunkWithOptionalShortCircuitResponse(any(), any());

    // when
    PipelineContinuationBehavior behavior = handlerSpy.handleFilterLogic(ctxMock, callArgs.msg, callArgs.normalFilterCall, callArgs.shortCircuitFilterCall);

    // then: the pipeline carries on normally instead of short circuiting...
    assertThat(behavior).isEqualTo(CONTINUE);

    // ...the short-circuit-capable method for this chunk was still invoked...
    if (isFirstChunk) {
        verify(offendingFilter).filterRequestFirstChunkWithOptionalShortCircuitResponse(requestInfoMock, ctxMock);
    } else {
        verify(offendingFilter).filterRequestLastChunkWithOptionalShortCircuitResponse(requestInfoMock, ctxMock);
    }

    // ...any RequestInfo the filter returned is still applied to the state...
    RequestInfo<?> expectedReqInfo = filterReturnsModifiedRequestInfo ? modifiedReqInfo : requestInfoMock;
    assertThat(state.getRequestInfo()).isSameAs(expectedReqInfo);

    // ...but the chunked ResponseInfo is dropped and no "send full response" event goes down the pipeline.
    assertThat(state.getResponseInfo()).isNull();
    verify(ctxMock, never()).fireChannelRead(LastOutboundMessageSendFullResponseInfo.INSTANCE);
}
Aggregations