"""Tests for ``RateLimiterMiddleware`` from ``middleware_starlette``."""

from time import sleep

from cachetools import TTLCache
from starlette.applications import Starlette
from starlette.testclient import TestClient

from middleware_starlette import RateLimiterMiddleware, create_app, routes


def test_app_runs():
    """The demo app serves its greeting on ``/``."""
    app = create_app()
    client = TestClient(app)

    response = client.get("/")

    assert response.status_code == 200
    assert "Hello there" in response.text


def test_below_rate_limit():
    """Requests paced at roughly one per 0.1 s are all served."""
    app = create_app()
    client = TestClient(app)

    # Limit is 10/s; one request every ~0.1 s should never exceed it.
    for _ in range(20):
        response = client.get("/")
        assert response.status_code == 200
        assert "Hello there" in response.text
        sleep(0.1)


def test_above_rate_limit():
    """A burst of requests is partly rejected with 429 and a Retry-After header."""
    app = create_app()
    client = TestClient(app)
    successful = 0
    rate_limited = 0

    # Limit is 10/s; 30 back-to-back requests must trip it.
    for _ in range(30):
        response = client.get("/")
        if response.status_code == 200:
            successful += 1
            continue
        rate_limited += 1
        assert response.status_code == 429
        assert "Rate limit exceeded" in response.text
        assert response.headers.get("retry-after") is not None

    assert successful != 0
    assert rate_limited != 0


def find_middleware(stack, middleware_cls):
    """Return the first ``middleware_cls`` instance in an ASGI ``.app`` chain.

    Starting at ``stack``, follow each object's ``.app`` attribute and return
    the first object that is an instance of ``middleware_cls``. Every node is
    checked, including the last one in the chain (the one without ``.app``).

    Callers issue a request before calling this so that the app's middleware
    stack has been built (recent Starlette versions build it lazily).

    Returns:
        The matching middleware instance, or ``None`` if none is found
        (also when ``stack`` itself is ``None``).
    """
    current = stack
    while current is not None:
        if isinstance(current, middleware_cls):
            return current
        current = getattr(current, "app", None)
    return None


def test_rate_limiter_cache():
    """A single request from the test client creates exactly one cache entry."""
    app = create_app()
    client = TestClient(app)
    client.get("/")

    middleware = find_middleware(app.middleware_stack, RateLimiterMiddleware)

    assert middleware is not None
    # Starlette's TestClient reports its peer host as "testclient".
    assert "testclient" in middleware.clients
    assert len(middleware.clients) == 1


def test_cache_ttl():
    """Cache entries expire once their TTL has elapsed.

    A fake clock is injected through ``TTLCache(timer=...)`` instead of
    sleeping, so the test is fast and independent of wall-clock jitter.
    """
    clock = [0.0]  # current fake time; mutated below to simulate elapsed time

    class ShortTTLLimiterMiddleware(RateLimiterMiddleware):
        def __init__(self, app):
            super().__init__(app)
            # Entries expire 2 fake-clock units after insertion.
            self.clients = TTLCache(maxsize=10, ttl=2, timer=lambda: clock[0])

    app = Starlette(routes=routes)
    app.add_middleware(ShortTTLLimiterMiddleware)
    client = TestClient(app)
    client.get("/")  # first request; also ensures the middleware stack exists

    middleware = find_middleware(app.middleware_stack, RateLimiterMiddleware)
    assert middleware is not None

    client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
    assert middleware.clients.get("1.1.1.1") is not None

    clock[0] = 1.9  # just before the TTL elapses: entry still present
    assert middleware.clients.get("1.1.1.1") is not None

    clock[0] = 2.0  # TTLCache treats an entry as expired once timer() >= insert time + ttl
    assert middleware.clients.get("1.1.1.1") is None


def test_cache_eviction():
    """The least-recently-used client is evicted when the cache is full."""

    class TinyLimiterMiddleware(RateLimiterMiddleware):
        def __init__(self, app):
            super().__init__(app)
            self.clients = TTLCache(maxsize=2, ttl=10)  # Only space for 2 IPs

    app = Starlette(routes=routes)
    app.add_middleware(TinyLimiterMiddleware)
    client = TestClient(app)
    client.get("/")  # takes one slot under the "testclient" key

    middleware = find_middleware(app.middleware_stack, RateLimiterMiddleware)
    assert middleware is not None

    client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
    client.get("/", headers={"X-Forwarded-For": "2.2.2.2"})
    # Three distinct clients have hit a 2-slot cache, so the first is gone.
    assert middleware.clients.get("testclient") is None
    assert middleware.clients.get("1.1.1.1") is not None
    assert middleware.clients.get("2.2.2.2") is not None
    assert middleware.clients.get("3.3.3.3") is None

    # 1.1.1.1 is now the least recently used entry, so it is evicted next.
    client.get("/", headers={"X-Forwarded-For": "3.3.3.3"})
    assert middleware.clients.get("1.1.1.1") is None
    assert middleware.clients.get("2.2.2.2") is not None
    assert middleware.clients.get("3.3.3.3") is not None