Use of org.springframework.cloud.gateway.filter.ratelimit.KeyResolver in the spring-cloud-gateway project (by spring-cloud).
Class RequestRateLimiterGatewayFilterFactory, method apply:
/**
 * Builds a gateway filter that rate-limits requests on the matched route.
 *
 * <p>The per-request key comes from {@code config.keyResolver}, falling back to
 * {@code defaultKeyResolver} when none is configured. The limiter likewise falls back to
 * {@code defaultRateLimiter}. The limiter is queried with the matched route's id and the
 * resolved key. Allowed requests continue down the filter chain. Rejected requests get
 * {@code 429 Too Many Requests}, and the response is completed without invoking the chain.
 *
 * @param config filter configuration; {@code keyResolver} and {@code rateLimiter} may be null
 * @return the rate-limiting filter
 */
@SuppressWarnings("unchecked")
@Override
public GatewayFilter apply(Config config) {
    KeyResolver resolver = (config.keyResolver == null) ? defaultKeyResolver : config.keyResolver;
    // NOTE(review): the "unchecked" suppression presumably covers this assignment of the
    // default limiter to RateLimiter<Object> — confirm against its declared type.
    RateLimiter<Object> limiter = (config.rateLimiter == null) ? defaultRateLimiter : config.rateLimiter;
    return (exchange, chain) -> {
        Route route = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
        // NOTE(review): if the resolver completes without emitting a key, this flatMap never
        // runs. Neither the chain nor the 429 branch executes, so the response completes
        // untouched. Confirm that is the intended policy for keyless requests.
        return resolver.resolve(exchange).flatMap(key -> {
            if (route == null) {
                // Thrown inside flatMap, so it surfaces as an error signal at the same point the
                // former NullPointerException did, but with a message that names the cause.
                throw new IllegalStateException("No route found in exchange attribute "
                        + ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR
                        + "; it must be populated before the request rate limiter filter runs");
            }
            return limiter.isAllowed(route.getId(), key).flatMap(response -> {
                if (response.isAllowed()) {
                    return chain.filter(exchange);
                }
                exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
                return exchange.getResponse().setComplete();
            });
        });
    };
}
Use of org.springframework.cloud.gateway.filter.ratelimit.KeyResolver in the spring-cloud-gateway project (by spring-cloud).
Class RequestRateLimiterGatewayFilterFactoryTests, method assertFilterFactory:
/**
 * Runs a {@link RequestRateLimiterGatewayFilterFactory} filter against a mock exchange for
 * route {@code "myroute"} and asserts the resulting response status.
 *
 * @param keyResolver key resolver configured on the filter under test
 * @param key key the mocked rate limiter is stubbed to receive for route {@code "myroute"}
 * @param allowed whether the mocked rate limiter admits the request
 * @param expectedStatus status the response must carry once the filter has completed
 */
private void assertFilterFactory(KeyResolver keyResolver, String key, boolean allowed, HttpStatus expectedStatus) {
    when(rateLimiter.isAllowed("myroute", key)).thenReturn(Mono.just(new Response(allowed, 1)));
    MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
    MockServerWebExchange exchange = MockServerWebExchange.from(request);
    // Pre-set 200 so the allowed path has a defined status to assert.
    // The mocked chain below sets no status itself.
    exchange.getResponse().setStatusCode(HttpStatus.OK);
    exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR,
            Route.builder().id("myroute").predicate(ex -> true).uri("http://localhost").build());
    when(this.filterChain.filter(exchange)).thenReturn(Mono.empty());
    RequestRateLimiterGatewayFilterFactory factory = this.context.getBean(RequestRateLimiterGatewayFilterFactory.class);
    GatewayFilter filter = factory.apply(config -> config.setKeyResolver(keyResolver));
    Mono<Void> response = filter.filter(exchange, this.filterChain);
    // A Mono<Void> never emits a value, so an onNext callback passed to subscribe() would never
    // run and the assertion would be silently skipped. Instead, wait for completion here
    // (rethrowing any error), then assert.
    response.block();
    assertThat(exchange.getResponse().getStatusCode()).isEqualTo(expectedStatus);
}
Aggregations