2026-05-09 15:35:03 +00:00
|
|
|
from time import sleep
|
|
|
|
|
|
2026-05-09 20:17:23 +00:00
|
|
|
from cachetools import TTLCache
|
|
|
|
|
from starlette.applications import Starlette
|
2026-05-09 15:35:03 +00:00
|
|
|
from starlette.testclient import TestClient
|
|
|
|
|
|
2026-05-09 20:17:23 +00:00
|
|
|
from middleware_starlette import RateLimiterMiddleware, create_app, routes
|
|
|
|
|
|
2026-05-09 15:35:03 +00:00
|
|
|
|
|
|
|
|
def test_app_runs():
    """A GET on ``/`` succeeds and returns the greeting body."""
    test_client = TestClient(create_app())

    resp = test_client.get("/")

    assert resp.status_code == 200
    assert "Hello there" in resp.text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_below_rate_limit():
    """Requests paced at the configured rate are never rejected."""
    test_client = TestClient(create_app())

    # Limit is 10/s: pause 0.1 s after every request to stay within it.
    for _ in range(20):
        resp = test_client.get("/")
        assert resp.status_code == 200
        assert "Hello there" in resp.text
        sleep(0.1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_above_rate_limit():
    """A rapid burst over the limit yields both successes and rejections.

    Every non-200 response must be a 429 that carries the
    "Rate limit exceeded" message and a ``Retry-After`` header.
    """
    test_client = TestClient(create_app())
    passed = 0
    throttled = 0

    # Limit is 10/s; 30 back-to-back requests should overshoot it.
    for _ in range(30):
        resp = test_client.get("/")
        if resp.status_code == 200:
            passed += 1
            continue

        throttled += 1
        assert resp.status_code == 429
        assert "Rate limit exceeded" in resp.text
        assert resp.headers.get("retry-after") is not None

    assert passed > 0
    assert throttled > 0
|
|
|
|
|
|
2026-05-09 20:17:23 +00:00
|
|
|
|
|
|
|
|
def find_middleware(stack, middleware_cls):
|
|
|
|
|
current = stack
|
|
|
|
|
|
|
|
|
|
while hasattr(current, "app"):
|
|
|
|
|
if isinstance(current, middleware_cls):
|
|
|
|
|
return current
|
|
|
|
|
current = current.app
|
|
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_rate_limiter_cache():
    """One request from the test client leaves exactly one cache entry, ``"testclient"``."""
    app = create_app()
    test_client = TestClient(app)

    # Warm-up request: records the client and makes sure app.middleware_stack
    # has been built before we walk it.
    test_client.get("/")

    limiter = find_middleware(app.middleware_stack, RateLimiterMiddleware)
    assert limiter is not None

    assert "testclient" in limiter.clients
    assert len(limiter.clients) == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cache_ttl():
    """Entries in the limiter's client cache expire once their TTL has elapsed.

    The cache is driven by a fake clock through ``TTLCache``'s ``timer`` hook.
    The test is therefore deterministic and does not sleep through the TTL in
    real time. That avoids a 2 s stall and a race at the exact expiry boundary.
    """
    ttl = 2  # seconds of (fake) time before an entry is dropped
    clock = [0.0]  # mutable cell holding the fake "current time"

    class ShortTTLLimiterMiddleware(RateLimiterMiddleware):
        def __init__(self, app):
            super().__init__(app)
            # Swap in a short-TTL cache whose notion of time is the fake clock.
            self.clients = TTLCache(maxsize=10, ttl=ttl, timer=lambda: clock[0])

    app = Starlette(routes=routes)
    app.add_middleware(ShortTTLLimiterMiddleware)
    test_client = TestClient(app)

    # Warm-up request so app.middleware_stack has been built before we walk it.
    test_client.get("/")
    middleware = find_middleware(app.middleware_stack, RateLimiterMiddleware)
    assert middleware is not None

    test_client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
    assert middleware.clients.get("1.1.1.1") is not None

    # Step strictly past the TTL rather than exactly onto it, so the check
    # does not depend on whether expiry at the boundary is inclusive.
    clock[0] += ttl + 1
    assert middleware.clients.get("1.1.1.1") is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_cache_eviction():
    """With room for only two clients, a third client evicts an older entry."""

    class TinyLimiterMiddleware(RateLimiterMiddleware):
        def __init__(self, app):
            super().__init__(app)
            # Tiny cache for testing: only space for 2 IPs.
            self.clients = TTLCache(maxsize=2, ttl=10)

    app = Starlette(routes=routes)
    app.add_middleware(TinyLimiterMiddleware)
    test_client = TestClient(app)

    # Warm-up request so app.middleware_stack has been built before we walk it.
    # NOTE: this request also stores a "testclient" entry, which takes one of
    # the two slots until a later insert pushes it out.
    test_client.get("/")
    limiter = find_middleware(app.middleware_stack, RateLimiterMiddleware)
    assert limiter is not None

    for ip in ("1.1.1.1", "2.2.2.2"):
        test_client.get("/", headers={"X-Forwarded-For": ip})

    # NOTE(review): TTLCache reads refresh recency order, so the order of these
    # .get() checks can influence which entry is evicted next. Keep it as-is.
    assert limiter.clients.get("1.1.1.1") is not None
    assert limiter.clients.get("2.2.2.2") is not None
    assert limiter.clients.get("3.3.3.3") is None

    test_client.get("/", headers={"X-Forwarded-For": "3.3.3.3"})

    assert limiter.clients.get("1.1.1.1") is None
    assert limiter.clients.get("2.2.2.2") is not None
    assert limiter.clients.get("3.3.3.3") is not None
|