# python-learning/rate_limiters/middleware_starlette.py
# (source web view: 53 lines, 1.6 KiB, Python)

import contextlib
import math
from typing import Optional

from cachetools import TTLCache
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route

from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
class RateLimiterMiddleware:
    """Pure ASGI middleware that rate-limits HTTP requests per client address.

    Every client gets its own limiter, kept in a bounded ``TTLCache``. The
    key is the host element of the ASGI ``client`` tuple, or ``"UNKNOWN"``
    when the server supplies none. A request whose limiter raises
    ``RateLimitExceeded`` is answered with ``429`` and a ``Retry-After``
    header. Scopes whose type is not ``"http"`` are forwarded unchanged.
    """

    def __init__(self, app, *, max_clients: int = 1000, ttl: float = 10,
                 limiter_factory=None):
        """Wrap the downstream ASGI *app*.

        Args:
            app: Downstream ASGI application.
            max_clients: ``maxsize`` of the per-client limiter cache.
            ttl: ``ttl`` (seconds) of the per-client limiter cache.
            limiter_factory: Optional zero-argument callable returning a new
                limiter (an async context manager) for an unseen client.
                When ``None``, ``AsyncTokenBucketLimiter(10, 10)`` is used.
        """
        self.app = app
        # NOTE(review): cachetools' TTLCache drops an entry `ttl` seconds
        # after it was stored (lookups do not extend it) and evicts the
        # least-recently-used entry when full. A client's bucket is therefore
        # presumably replaced by a fresh one at least every `ttl` seconds.
        # Confirm this reset interval is intended.
        self.clients = TTLCache(maxsize=max_clients, ttl=ttl)
        self._limiter_factory = limiter_factory

    async def __call__(self, scope, receive, send):
        """Rate-limit HTTP requests and pass every other scope type through."""
        if scope["type"] != "http":
            # Forward exactly once and stop. Non-HTTP scopes must not reach
            # the rate-limiting path below, and may carry no "client" key.
            await self.app(scope, receive, send)
            return

        # ASGI allows "client" to be missing or None.
        client = scope.get("client")
        client_ip = client[0] if client else "UNKNOWN"

        limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
        # `is None` rather than truthiness: a limiter object may be falsy.
        if limiter is None:
            limiter = self._new_limiter()
            self.clients[client_ip] = limiter

        async with contextlib.AsyncExitStack() as stack:
            try:
                # Only acquiring the limiter is guarded. Exceptions raised by
                # the downstream app are never turned into a 429, which could
                # be sent after that app had already started a response.
                # Assumes the limiter rejects from __aenter__ — confirm
                # against token_bucket.
                await stack.enter_async_context(limiter)
            except RateLimitExceeded as e:
                retry_after = self._retry_after_header(e.retry_after)
                response = Response(
                    content=f"Rate limit exceeded! Retry after: {retry_after}",
                    headers={"Retry-After": retry_after},
                    status_code=429,
                )
                await response(scope, receive, send)
                return
            await self.app(scope, receive, send)

    def _new_limiter(self):
        """Create the limiter for a client that has none cached yet."""
        if self._limiter_factory is not None:
            return self._limiter_factory()
        # NOTE(review): the meaning of the two arguments is defined in
        # token_bucket (capacity / refill rate?) — confirm.
        return AsyncTokenBucketLimiter(10, 10)

    @staticmethod
    def _retry_after_header(retry_after) -> str:
        """Format *retry_after* as an RFC 9110 ``Retry-After`` delay-seconds value.

        The header only allows a non-negative whole number of seconds. Numeric
        values are therefore rounded up, so a client is never told to retry
        early, and clamped at 0. Values ``math.ceil`` cannot handle are
        formatted with ``str()``.
        """
        # Assumes retry_after is a delay in seconds — confirm against token_bucket.
        try:
            return str(max(0, math.ceil(retry_after)))
        except (TypeError, ValueError, OverflowError):
            return str(retry_after)
def main(request: Request):
    """Serve the greeting at ``/``.

    The request itself is not inspected. The body is sent with an explicit
    ``text/plain`` media type so clients receive a ``Content-Type`` header.
    """
    return Response(content="Hello there!", media_type="text/plain")
# URL table: the only endpoint is the greeting at the site root.
routes = [
    Route(path="/", endpoint=main),
]

# Middleware stack: every request passes through the per-client rate limiter.
middleware = [
    Middleware(RateLimiterMiddleware),
]
# Source web-view timestamp: 2026-05-09 15:50:14 +00:00
def create_app() -> Starlette:
    """Build the demo Starlette app from the module-level route and middleware lists."""
    application = Starlette(routes=routes, middleware=middleware)
    return application