Switched to TTL cache for client rate limiter tracking

This commit is contained in:
V 2026-05-09 16:45:36 +01:00
parent d536dde6c9
commit 72348ca5fe
4 changed files with 23 additions and 5 deletions

View File

@ -1,3 +1,6 @@
from typing import Optional
from cachetools import TTLCache
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
@ -6,11 +9,10 @@ from starlette.middleware import Middleware
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
class RateLimiterMiddleware: class RateLimiterMiddleware:
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
self.clients = {} self.clients = TTLCache(maxsize=1000, ttl=10)
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
if scope["type"] != "http": if scope["type"] != "http":
@ -19,7 +21,11 @@ class RateLimiterMiddleware:
client_data = scope["client"] client_data = scope["client"]
client_ip = client_data[0] if client_data else "UNKNOWN" client_ip = client_data[0] if client_data else "UNKNOWN"
limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
if not limiter:
limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10)) limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10))
self.clients[client_ip] = limiter
try: try:
async with limiter: async with limiter:

View File

@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.11" requires-python = ">=3.11"
dependencies = [ dependencies = [
"cachetools>=7.1.1",
"httpx>=0.28.1", "httpx>=0.28.1",
"pytest>=9.0.3", "pytest>=9.0.3",
"pytest-asyncio>=1.3.0", "pytest-asyncio>=1.3.0",

View File

@ -1,8 +1,9 @@
from time import sleep from time import sleep
from starlette.applications import Starlette
from starlette.testclient import TestClient from starlette.testclient import TestClient
from middleware_starlette import app from middleware_starlette import app, RateLimiterMiddleware, routes
def test_app_runs(): def test_app_runs():
client = TestClient(app) client = TestClient(app)
@ -46,4 +47,3 @@ def test_above_rate_limit():
assert successful != 0 assert successful != 0
assert rate_limited != 0 assert rate_limited != 0

View File

@ -15,6 +15,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" },
] ]
[[package]]
name = "cachetools"
version = "7.1.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/ff/e2/85f227594656000ff4d8adadae91a21f536d4a84c6c716a86bd6685874be/cachetools-7.1.1.tar.gz", hash = "sha256:27bdf856d68fd3c71c26c01b5edc312124ed427524d1ddb31aa2b7746fe20d4b", size = 40202, upload-time = "2026-05-03T20:00:29.391Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bf/0f/f897abe4ea0a8c408ae65c8c83bffab4936ad65d6032d4fb4cd35bbdc3ee/cachetools-7.1.1-py3-none-any.whl", hash = "sha256:0335cd7a0952d2b22327441fb0628139e234c565559eeb91a8a4ac7551c5353d", size = 16775, upload-time = "2026-05-03T20:00:27.857Z" },
]
[[package]] [[package]]
name = "certifi" name = "certifi"
version = "2026.4.22" version = "2026.4.22"
@ -161,6 +170,7 @@ name = "rate-limiters"
version = "0.1.0" version = "0.1.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "cachetools" },
{ name = "httpx" }, { name = "httpx" },
{ name = "pytest" }, { name = "pytest" },
{ name = "pytest-asyncio" }, { name = "pytest-asyncio" },
@ -170,6 +180,7 @@ dependencies = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "cachetools", specifier = ">=7.1.1" },
{ name = "httpx", specifier = ">=0.28.1" }, { name = "httpx", specifier = ">=0.28.1" },
{ name = "pytest", specifier = ">=9.0.3" }, { name = "pytest", specifier = ">=9.0.3" },
{ name = "pytest-asyncio", specifier = ">=1.3.0" }, { name = "pytest-asyncio", specifier = ">=1.3.0" },