# Listing metadata from the original paste: 46 lines, 1.3 KiB, Python.
import math
from contextlib import AsyncExitStack

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route

from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded


class RateLimiterMiddleware:
    """ASGI middleware that rate-limits HTTP requests per client address.

    Each distinct client host gets its own limiter, created the first time that
    host is seen and kept in ``self.clients``. Non-HTTP scopes (e.g.
    ``lifespan``, ``websocket``) are forwarded to the wrapped app untouched.
    When entering a client's limiter raises ``RateLimitExceeded``, the request
    is answered with a plain-text 429 response carrying a ``Retry-After``
    header instead of reaching the app.
    """

    def __init__(self, app, limiter_factory=None):
        """Wrap *app*.

        Args:
            app: Downstream ASGI application.
            limiter_factory: Optional zero-argument callable returning a new
                limiter for a newly seen client. The limiter must support
                ``async with`` and reject a request by raising
                ``RateLimitExceeded`` on entry. Defaults to
                ``AsyncTokenBucketLimiter(10, 10)``, the original setting.
        """
        self.app = app
        self.limiter_factory = (
            self._default_limiter if limiter_factory is None else limiter_factory
        )
        # client host -> limiter.
        # NOTE(review): one entry per distinct client, never evicted; consider
        # bounding this map for internet-facing deployments.
        self.clients = {}

    @staticmethod
    def _default_limiter():
        """Build the limiter used when no ``limiter_factory`` is supplied."""
        # Arguments kept exactly as originally hard-coded; their meaning is
        # defined by token_bucket.AsyncTokenBucketLimiter.
        return AsyncTokenBucketLimiter(10, 10)

    @staticmethod
    def _client_key(scope):
        """Return the client host from *scope*, or ``"UNKNOWN"`` if absent."""
        # The ASGI spec makes "client" optional: it may be missing or None
        # (e.g. Unix-socket servers), so never index it directly.
        client = scope.get("client")
        return client[0] if client else "UNKNOWN"

    def _limiter_for(self, key):
        """Return the limiter for *key*, creating it on first use."""
        limiter = self.clients.get(key)
        if limiter is None:
            # Construct only for new clients; dict.setdefault() would build a
            # throwaway limiter on every request.
            limiter = self.clients[key] = self.limiter_factory()
        return limiter

    @staticmethod
    def _format_retry_after(value):
        """Render *value* as a ``Retry-After`` delay.

        RFC 9110 requires delay-seconds to be a non-negative integer, so
        numeric values are rounded up (never telling a client to retry too
        early) and clamped at zero. Non-numeric values fall back to ``str()``.
        """
        try:
            seconds = math.ceil(float(value))
        except (TypeError, ValueError, OverflowError):
            return str(value)
        return str(max(seconds, 0))

    async def _reject(self, exc, scope, receive, send):
        """Send the 429 response describing the rate-limit rejection *exc*."""
        retry_after = self._format_retry_after(exc.retry_after)
        response = Response(
            content=f"Rate limit exceeded! Retry after: {retry_after}",
            headers={"Retry-After": retry_after},
            status_code=429,
        )
        await response(scope, receive, send)

    async def __call__(self, scope, receive, send):
        """ASGI entry point: rate-limit HTTP requests, pass everything else through."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            # Return here, otherwise the app would be invoked a second time below.
            return

        limiter = self._limiter_for(self._client_key(scope))

        async with AsyncExitStack() as stack:
            try:
                # Only acquiring the limiter may turn into a 429; exceptions from
                # the downstream app must propagate untouched, since the response
                # may already have started by then.
                await stack.enter_async_context(limiter)
            except RateLimitExceeded as exc:
                await self._reject(exc, scope, receive, send)
                return
            await self.app(scope, receive, send)


def main(request: Request):
    """Endpoint mounted at ``/``: reply with a fixed greeting, ignoring *request*."""
    greeting = "Hello there!"
    return Response(content=greeting)


# URL table: the site root is served by ``main``.
routes = [
    Route("/", main),
]

# Every HTTP request passes through the per-client rate limiter before routing.
middleware = [
    Middleware(RateLimiterMiddleware),
]

app = Starlette(
    routes=routes,
    middleware=middleware,
)