from typing import Callable, Optional

from cachetools import TTLCache
from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded


def _default_limiter() -> AsyncTokenBucketLimiter:
    """Build the per-client limiter used when no factory is supplied."""
    # NOTE(review): the meaning of the two positional arguments is defined by
    # the token_bucket package (presumably capacity and refill rate) — confirm.
    return AsyncTokenBucketLimiter(10, 10)


def _client_ip(scope) -> str:
    """Return the address a request should be rate-limited under.

    Prefers the left-most entry of the ``X-Forwarded-For`` header (the
    originating client when running behind a reverse proxy, and what tests
    set to simulate different clients). Falls back to the ASGI ``client``
    tuple, and finally to ``"UNKNOWN"`` when neither is available.
    """
    headers = Headers(raw=scope["headers"])
    forwarded = headers.get("X-Forwarded-For")
    if forwarded:
        # The header is a comma-separated hop list: "client, proxy1, proxy2".
        # Only the first hop identifies the client; using the whole string
        # would give every distinct proxy chain its own bucket.
        # NOTE(review): the header is client-controlled, so callers that can
        # reach the app directly can spoof it to dodge the limit. Only trust
        # it behind a proxy that overwrites it.
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return first_hop

    # "client" is optional in the ASGI scope and may also be None.
    client = scope.get("client")
    return client[0] if client else "UNKNOWN"


class RateLimiterMiddleware:
    """Pure-ASGI middleware that rate-limits HTTP requests per client IP.

    Each client IP gets its own limiter, kept in a ``TTLCache``. If entering
    the limiter raises ``RateLimitExceeded``, a 429 response carrying a
    ``Retry-After`` header is sent instead. Non-HTTP scopes (``lifespan``,
    ``websocket``) are passed through to the wrapped app untouched.

    Args:
        app: The wrapped ASGI application.
        max_clients: Maximum number of per-client limiters kept at once.
        ttl: Lifetime, in seconds, of a cached per-client limiter.
        limiter_factory: Zero-argument callable creating a new limiter for a
            client that has none cached.
    """

    def __init__(
        self,
        app,
        *,
        max_clients: int = 1000,
        ttl: float = 10,
        limiter_factory: Callable[[], AsyncTokenBucketLimiter] = _default_limiter,
    ):
        self.app = app
        # cachetools' TTLCache drops an entry once its ttl elapses, and evicts
        # the least-recently-used entry when full. Either way the client then
        # starts over with a fresh limiter.
        self.clients = TTLCache(maxsize=max_clients, ttl=ttl)
        self.limiter_factory = limiter_factory

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            # Without this return, non-HTTP scopes fell through into the
            # rate-limit path. Websockets ran the app twice, and lifespan
            # crashed on the missing "headers" key.
            await self.app(scope, receive, send)
            return

        client_ip = _client_ip(scope)

        limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
        if limiter is None:
            limiter = self.limiter_factory()
            self.clients[client_ip] = limiter

        try:
            async with limiter:
                await self.app(scope, receive, send)
        except RateLimitExceeded as e:
            # NOTE(review): RFC 9110 requires Retry-After to be an integer
            # number of seconds (or an HTTP-date). Confirm e.retry_after is
            # never a float here.
            retry_after = str(e.retry_after)
            response = Response(
                content=f"Rate limit exceeded! Retry after: {retry_after}",
                headers={"Retry-After": retry_after},
                status_code=429,
            )
            await response(scope, receive, send)


def main(request: Request):
    """Root endpoint: return a plain greeting."""
    return Response(content="Hello there!")


routes = [Route("/", main)]
middleware = [Middleware(RateLimiterMiddleware)]


def create_app() -> Starlette:
    """Build the Starlette app with the root route and per-IP rate limiting."""
    return Starlette(routes=routes, middleware=middleware)