# python-learning/rate_limiters/middleware_starlette.py
# (file-viewer metadata: 46 lines, 1.3 KiB, Python)

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from starlette.middleware import Middleware
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
class RateLimiterMiddleware:
    """ASGI middleware that applies a separate rate limiter to each client IP.

    A limiter is created lazily the first time an HTTP request arrives from a
    given client address and is stored in ``self.clients``. If the limiter
    raises ``RateLimitExceeded``, the client receives a 429 response with a
    ``Retry-After`` header. Non-HTTP scopes (``lifespan``, ``websocket``) are
    forwarded to the wrapped app unchanged and are not rate limited.

    NOTE(review): ``self.clients`` is never pruned, so memory grows with the
    number of distinct client addresses seen over the process lifetime.
    """

    def __init__(self, app, limiter_factory=None):
        """Wrap *app* with per-client rate limiting.

        Args:
            app: The downstream ASGI application.
            limiter_factory: Optional zero-argument callable returning a new
                async-context-manager limiter for one client. When omitted,
                ``AsyncTokenBucketLimiter(10, 10)`` is used (the original
                hard-coded configuration).
        """
        self.app = app
        self.clients = {}
        self._limiter_factory = limiter_factory

    def _new_limiter(self):
        """Build a fresh limiter for a client not seen before."""
        if self._limiter_factory is None:
            return AsyncTokenBucketLimiter(10, 10)
        return self._limiter_factory()

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            # Only HTTP requests are rate limited. Forward everything else and
            # stop here, so the app is not invoked a second time below.
            await self.app(scope, receive, send)
            return

        # ASGI allows "client" to be missing or None; all such requests share
        # one "UNKNOWN" bucket.
        client_data = scope.get("client")
        client_ip = client_data[0] if client_data else "UNKNOWN"

        limiter = self.clients.get(client_ip)
        if limiter is None:
            # Construct only on a miss. setdefault(ip, Limiter(...)) would
            # build and discard a new limiter on every request.
            limiter = self.clients[client_ip] = self._new_limiter()

        try:
            async with limiter:
                await self.app(scope, receive, send)
        except RateLimitExceeded as e:
            retry_after = str(e.retry_after)
            response = Response(
                content=f"Rate limit exceeded! Retry after: {retry_after}",
                headers={"Retry-After": retry_after},
                status_code=429,
            )
            await response(scope, receive, send)
def main(request: Request):
    """Handle the root route by replying with a fixed greeting.

    The incoming *request* is accepted to match Starlette's endpoint
    signature but is not inspected.
    """
    greeting = "Hello there!"
    return Response(content=greeting)
# Application wiring: one root route, with every request passing through the
# per-client rate-limiting middleware.
routes = [
    Route("/", main),
]
middleware = [
    Middleware(RateLimiterMiddleware),
]
app = Starlette(
    routes=routes,
    middleware=middleware,
)