python-learning/rate_limiters/middleware_starlette.py

63 lines
1.9 KiB
Python
Raw Permalink Normal View History

from typing import Optional
from cachetools import TTLCache
from starlette.applications import Starlette
2026-05-09 20:17:23 +00:00
from starlette.datastructures import Headers
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
2026-05-09 20:17:23 +00:00
class RateLimiterMiddleware:
    """ASGI middleware that rate-limits HTTP requests per client.

    Every client key (normally the caller's IP address) gets its own
    ``AsyncTokenBucketLimiter``, stored in a ``TTLCache``. When entering the
    limiter raises ``RateLimitExceeded``, the middleware replies with a
    ``429`` response carrying a ``Retry-After`` header instead of letting
    the exception propagate. Non-HTTP scopes are forwarded untouched.
    """

    def __init__(self, app):
        """Wrap *app*, the next ASGI application in the chain."""
        self.app = app
        # One limiter per client key, built with maxsize=1000 and ttl=10
        # (see the cachetools TTLCache docs for the eviction rules).
        self.clients = TTLCache(maxsize=1000, ttl=10)

    @staticmethod
    def _client_key(scope):
        """Return the string that identifies the caller of an HTTP *scope*."""
        # Check for "X-Forwarded-For" header - as passed by reverse proxies.
        # Need it for testing
        forwarded = Headers(raw=scope["headers"]).get("X-Forwarded-For")
        if forwarded is not None:
            # The header may hold a comma-separated proxy chain; the first
            # entry is the originating client. A single value is unchanged.
            return forwarded.split(",")[0].strip()
        # "client" is an optional scope key and may also be None.
        client_data = scope.get("client")
        return client_data[0] if client_data else "UNKNOWN"

    async def __call__(self, scope, receive, send):
        """Handle one ASGI connection.

        Args:
            scope: ASGI connection scope.
            receive: ASGI receive callable, forwarded to the wrapped app.
            send: ASGI send callable, forwarded to the wrapped app.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            # Stop here: falling through would read HTTP-only scope keys and
            # invoke the wrapped app a second time for the same connection.
            return

        client_ip = self._client_key(scope)
        limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
        if limiter is None:
            # setdefault keeps an entry that appeared since the lookup above.
            limiter = self.clients.setdefault(
                client_ip, AsyncTokenBucketLimiter(10, 10)
            )

        try:
            async with limiter:
                await self.app(scope, receive, send)
        except RateLimitExceeded as e:
            retry_after = str(e.retry_after)
            response = Response(
                content=f"Rate limit exceeded! Retry after: {retry_after}",
                headers={"Retry-After": retry_after},
                status_code=429,
            )
            await response(scope, receive, send)
def main(request: Request):
    """Answer any request with a fixed greeting; *request* is not inspected."""
    greeting = "Hello there!"
    return Response(content=greeting)
# URL table: the root path is served by ``main``.
routes = [
    Route("/", main),
]

# Middleware stack: every connection passes through the rate limiter.
middleware = [
    Middleware(RateLimiterMiddleware),
]
2026-05-09 15:50:14 +00:00
def create_app() -> Starlette:
    """Build a Starlette application from the module-level routes and middleware."""
    application = Starlette(middleware=middleware, routes=routes)
    return application