Example 1 with KeyResolver

Use of org.springframework.cloud.gateway.filter.ratelimit.KeyResolver in project spring-cloud-gateway by spring-cloud.

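From the class RequestRateLimiterGatewayFilterFactory, method apply. The returned filter resolves a key for each request, asks the rate limiter whether that key may proceed on the matched route, and either continues the filter chain or answers 429 Too Many Requests.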

@SuppressWarnings("unchecked")
@Override
public GatewayFilter apply(Config config) {
    // Fall back to the factory-wide defaults when the route config does not set them.
    KeyResolver resolver = (config.keyResolver == null) ? defaultKeyResolver : config.keyResolver;
    RateLimiter<Object> limiter = (config.rateLimiter == null) ? defaultRateLimiter : config.rateLimiter;
    return (exchange, chain) -> {
        // The route matched for this exchange; its id scopes the rate limit.
        Route route = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
        return resolver.resolve(exchange).flatMap(key -> limiter.isAllowed(route.getId(), key).flatMap(response -> {
            if (response.isAllowed()) {
                // Under the limit: continue down the filter chain.
                return chain.filter(exchange);
            }
            // Over the limit: reply 429 Too Many Requests and end the exchange.
            exchange.getResponse().setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
            return exchange.getResponse().setComplete();
        }));
    };
}
Also used: HttpStatus(org.springframework.http.HttpStatus) ServerWebExchangeUtils(org.springframework.cloud.gateway.support.ServerWebExchangeUtils) RateLimiter(org.springframework.cloud.gateway.filter.ratelimit.RateLimiter) GatewayFilter(org.springframework.cloud.gateway.filter.GatewayFilter) KeyResolver(org.springframework.cloud.gateway.filter.ratelimit.KeyResolver) Route(org.springframework.cloud.gateway.route.Route)
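To follow the control flow of apply without a running gateway, here is a minimal, stdlib-only Java sketch of the same decision: use the configured key resolver and limiter or fall back to defaults, resolve a key, ask the limiter about the (route id, key) pair, and either continue the chain or answer 429. It is synchronous and assumes hypothetical types (Request, Decision, Limiter, RateLimitFilterModel) that are not part of Spring Cloud Gateway; the real filter works on a ServerWebExchange and returns a reactive Mono<Void>.

```java
import java.util.function.Function;

// Hypothetical stand-in for an incoming request: the matched route id and an optional user.
record Request(String routeId, String user) {}

// Hypothetical limiter answer: allowed or not, plus tokens left (mirrors RateLimiter.Response).
record Decision(boolean allowed, long tokensRemaining) {}

// Hypothetical synchronous counterpart of RateLimiter.isAllowed(routeId, key).
interface Limiter {
    Decision isAllowed(String routeId, String key);
}

// Stdlib-only model of the filter built by apply(Config); not a Spring API.
final class RateLimitFilterModel {
    static final int OK = 200;
    static final int TOO_MANY_REQUESTS = 429;

    private final Function<Request, String> resolver;
    private final Limiter limiter;

    // Like apply(config): a null configured value falls back to the default.
    RateLimitFilterModel(Function<Request, String> configuredResolver,
                         Function<Request, String> defaultResolver,
                         Limiter configuredLimiter,
                         Limiter defaultLimiter) {
        this.resolver = configuredResolver == null ? defaultResolver : configuredResolver;
        this.limiter = configuredLimiter == null ? defaultLimiter : configuredLimiter;
    }

    // Returns the final status; `chain` stands in for chain.filter(exchange).
    int filter(Request request, Function<Request, Integer> chain) {
        String key = resolver.apply(request);
        Decision decision = limiter.isAllowed(request.routeId(), key);
        if (decision.allowed()) {
            return chain.apply(request);   // under the limit: continue the chain
        }
        return TOO_MANY_REQUESTS;          // over the limit: short-circuit with 429
    }
}
```

Because the limiter is consulted per (route id, key) pair, the key resolver decides what is being limited: a resolver that returns the user name gives each user a separate budget, while one that returns a constant makes every request on the route share a single budget.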

Example 2 with KeyResolver

Use of org.springframework.cloud.gateway.filter.ratelimit.KeyResolver in project spring-cloud-gateway by spring-cloud.

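From the class RequestRateLimiterGatewayFilterFactoryTests, method assertFilterFactory. The helper stubs the rate limiter's answer for the route myroute and the given key, applies the factory's filter to a mock exchange, and checks the resulting HTTP status.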

private void assertFilterFactory(KeyResolver keyResolver, String key, boolean allowed, HttpStatus expectedStatus) {
    // Stub the limiter: for route "myroute" and this key, answer with the given decision.
    when(rateLimiter.isAllowed("myroute", key)).thenReturn(Mono.just(new Response(allowed, 1)));
    MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
    MockServerWebExchange exchange = MockServerWebExchange.from(request);
    exchange.getResponse().setStatusCode(HttpStatus.OK);
    // The filter reads the matched route (and its id) from this exchange attribute.
    exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR, Route.builder().id("myroute").predicate(ex -> true).uri("http://localhost").build());
    when(this.filterChain.filter(exchange)).thenReturn(Mono.empty());
    RequestRateLimiterGatewayFilterFactory factory = this.context.getBean(RequestRateLimiterGatewayFilterFactory.class);
    GatewayFilter filter = factory.apply(config -> config.setKeyResolver(keyResolver));
    Mono<Void> response = filter.filter(exchange, this.filterChain);
    // A Mono<Void> never emits a value, so a subscribe(Consumer) callback would never run.
    // Wait for the filter to complete, then check the status.
    response.block();
    assertThat(exchange.getResponse().getStatusCode()).isEqualTo(expectedStatus);
}
Also used: Response(org.springframework.cloud.gateway.filter.ratelimit.RateLimiter.Response) DirtiesContext(org.springframework.test.annotation.DirtiesContext) TupleBuilder.tuple(org.springframework.tuple.TupleBuilder.tuple) GatewayFilterChain(org.springframework.cloud.gateway.filter.GatewayFilterChain) ServerWebExchangeUtils(org.springframework.cloud.gateway.support.ServerWebExchangeUtils) RateLimiter(org.springframework.cloud.gateway.filter.ratelimit.RateLimiter) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) RunWith(org.junit.runner.RunWith) BaseWebClientTests(org.springframework.cloud.gateway.test.BaseWebClientTests) Autowired(org.springframework.beans.factory.annotation.Autowired) Qualifier(org.springframework.beans.factory.annotation.Qualifier) KeyResolver(org.springframework.cloud.gateway.filter.ratelimit.KeyResolver) SpringRunner(org.springframework.test.context.junit4.SpringRunner) RANDOM_PORT(org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT) MockBean(org.springframework.boot.test.mock.mockito.MockBean) EnableAutoConfiguration(org.springframework.boot.autoconfigure.EnableAutoConfiguration) MockServerHttpRequest(org.springframework.mock.http.server.reactive.MockServerHttpRequest) Import(org.springframework.context.annotation.Import) Test(org.junit.Test) Mono(reactor.core.publisher.Mono) Mockito.when(org.mockito.Mockito.when) ApplicationContext(org.springframework.context.ApplicationContext) HttpStatus(org.springframework.http.HttpStatus) SpringBootTest(org.springframework.boot.test.context.SpringBootTest) Tuple(org.springframework.tuple.Tuple) SpringBootConfiguration(org.springframework.boot.SpringBootConfiguration) GatewayFilter(org.springframework.cloud.gateway.filter.GatewayFilter) MockServerWebExchange(org.springframework.mock.web.server.MockServerWebExchange) Bean(org.springframework.context.annotation.Bean) Route(org.springframework.cloud.gateway.route.Route)
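The test helper follows a common pattern: stub the limiter's answer for one (route id, key) pair, run the filter once, then assert on the final status. The following stdlib-only Java sketch reproduces that pattern without Spring, Reactor, or Mockito. StubLimiter, Outcome, and FilterTestSketch are hypothetical names, and the stub deliberately throws on an unexpected call instead of silently returning a default value as an unstubbed Mockito mock would.

```java
import java.util.Objects;
import java.util.function.Function;

// Hypothetical stand-in for a Mockito-stubbed RateLimiter: it answers only the
// (routeId, key) pair it was configured with.
final class StubLimiter {
    private final String routeId;
    private final String key;
    private final boolean allowed;

    StubLimiter(String routeId, String key, boolean allowed) {
        this.routeId = routeId;
        this.key = key;
        this.allowed = allowed;
    }

    boolean isAllowed(String routeId, String key) {
        if (!this.routeId.equals(routeId) || !Objects.equals(this.key, key)) {
            // Fail loudly on calls the test did not stub.
            throw new IllegalStateException("unstubbed call: " + routeId + "/" + key);
        }
        return allowed;
    }
}

// Final HTTP status and whether the downstream chain was invoked.
record Outcome(int status, boolean chainCalled) {}

final class FilterTestSketch {
    // Simplified filter: resolve a key from the path, then forward or reject with 429.
    static Outcome runFilter(String routeId, Function<String, String> keyResolver,
                             String path, StubLimiter limiter) {
        String key = keyResolver.apply(path);
        if (limiter.isAllowed(routeId, key)) {
            return new Outcome(200, true);   // chain runs, preset 200 OK is kept
        }
        return new Outcome(429, false);      // status overwritten, chain skipped
    }

    // Counterpart of assertFilterFactory: run to completion, then assert.
    static void assertFilter(Function<String, String> keyResolver, String key,
                             boolean allowed, int expectedStatus) {
        StubLimiter limiter = new StubLimiter("myroute", key, allowed);
        Outcome outcome = runFilter("myroute", keyResolver, "/", limiter);
        if (outcome.status() != expectedStatus) {
            throw new AssertionError("expected status " + expectedStatus + " but got " + outcome.status());
        }
        if (outcome.chainCalled() != allowed) {
            throw new AssertionError("chain should run only when the limiter allows the request");
        }
    }
}
```

The sketch finishes the filter before asserting. In the Reactor-based test the same ordering needs an explicit wait, such as block() or reactor-test's StepVerifier, because a Mono<Void> completes without emitting a value and a subscribe(Consumer) callback is therefore never invoked.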
