# python-learning/rate_limiters/middleware_starlette.py (63 lines, 1.9 KiB, Python)

from typing import Optional
from cachetools import TTLCache
from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
class RateLimiterMiddleware:
    """Pure-ASGI middleware applying a per-client token-bucket rate limit.

    Each client, identified by IP address, gets its own
    ``AsyncTokenBucketLimiter(10, 10)`` kept in a ``cachetools.TTLCache``.
    When the limiter raises ``RateLimitExceeded``, the client receives a
    429 response carrying a ``Retry-After`` header. Non-HTTP scopes
    (e.g. lifespan, websocket) are passed straight through.
    """

    def __init__(self, app, *, max_clients: int = 1000, ttl: float = 10):
        """Wrap *app* with per-client rate limiting.

        Args:
            app: The downstream ASGI application.
            max_clients: Maximum number of per-client limiters cached at once
                (``TTLCache`` maxsize).
            ttl: Seconds a client's limiter stays cached (``TTLCache`` ttl).
                NOTE(review): cachetools measures TTL from insertion, not
                from last access, so an active client's bucket is recreated
                (i.e. refilled) every ``ttl`` seconds -- confirm intended.
        """
        self.app = app
        self.clients = TTLCache(maxsize=max_clients, ttl=ttl)

    @staticmethod
    def _client_key(scope) -> str:
        """Return the identifier used to look up the client's limiter."""
        headers = Headers(raw=scope["headers"])
        # X-Forwarded-For is set by reverse proxies (and is used for testing).
        # It may hold a chain "client, proxy1, proxy2"; the left-most entry
        # is the originating client.
        # NOTE(review): unless a trusted proxy overwrites this header, clients
        # can set it themselves and so pick their own rate-limit key.
        forwarded = headers.get("X-Forwarded-For")
        if forwarded:
            first_hop = forwarded.split(",", 1)[0].strip()
            if first_hop:
                return first_hop
        # "client" is optional in the ASGI scope, hence .get().
        client = scope.get("client")
        return client[0] if client else "UNKNOWN"

    async def __call__(self, scope, receive, send):
        """Handle one ASGI connection, rate-limiting HTTP requests only."""
        if scope["type"] != "http":
            # Pass other scope types through untouched and stop here:
            # lifespan scopes have no "headers" key to inspect.
            await self.app(scope, receive, send)
            return

        client_ip = self._client_key(scope)
        limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
        if limiter is None:
            limiter = self.clients.setdefault(
                client_ip, AsyncTokenBucketLimiter(10, 10)
            )
        try:
            async with limiter:
                await self.app(scope, receive, send)
        except RateLimitExceeded as e:
            retry_after = str(e.retry_after)
            response = Response(
                content=f"Rate limit exceeded! Retry after: {retry_after}",
                headers={"Retry-After": retry_after},
                status_code=429,
            )
            await response(scope, receive, send)
def main(request: Request):
    """Endpoint for the root route: reply with a fixed greeting.

    The incoming *request* is accepted but not inspected.
    """
    greeting = "Hello there!"
    return Response(content=greeting)
# Routing table: serve `main` at the site root.
routes = [
    Route("/", main),
]
# Middleware stack: every connection goes through the per-client limiter.
middleware = [
    Middleware(RateLimiterMiddleware),
]
def create_app() -> Starlette:
    """Build the Starlette app from the module-level routes and middleware."""
    application = Starlette(routes=routes, middleware=middleware)
    return application