From 72348ca5fe2e125c7c538c858b700738a90bef84 Mon Sep 17 00:00:00 2001 From: V Date: Sat, 9 May 2026 16:45:36 +0100 Subject: [PATCH] Switched to TTL cache for client rate limiter tracking --- rate_limiters/middleware_starlette.py | 12 +++++++++--- rate_limiters/pyproject.toml | 1 + rate_limiters/tests/test_middleware_starlette.py | 4 ++-- rate_limiters/uv.lock | 11 +++++++++++ 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/rate_limiters/middleware_starlette.py b/rate_limiters/middleware_starlette.py index beeacc1..dfef511 100644 --- a/rate_limiters/middleware_starlette.py +++ b/rate_limiters/middleware_starlette.py @@ -1,3 +1,6 @@ +from typing import Optional + +from cachetools import TTLCache from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -6,11 +9,10 @@ from starlette.middleware import Middleware from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded - class RateLimiterMiddleware: def __init__(self, app): self.app = app - self.clients = {} + self.clients = TTLCache(maxsize=1000, ttl=10) async def __call__(self, scope, receive, send): if scope["type"] != "http": @@ -19,7 +21,11 @@ class RateLimiterMiddleware: client_data = scope["client"] client_ip = client_data[0] if client_data else "UNKNOWN" - limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10)) + limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip) + + if not limiter: + limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10)) + self.clients[client_ip] = limiter try: async with limiter: diff --git a/rate_limiters/pyproject.toml b/rate_limiters/pyproject.toml index 99952cb..35397be 100644 --- a/rate_limiters/pyproject.toml +++ b/rate_limiters/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.11" dependencies = [ + "cachetools>=7.1.1", "httpx>=0.28.1", "pytest>=9.0.3", "pytest-asyncio>=1.3.0", diff --git a/rate_limiters/tests/test_middleware_starlette.py b/rate_limiters/tests/test_middleware_starlette.py index f147698..e52c5a2 100644 --- a/rate_limiters/tests/test_middleware_starlette.py +++ b/rate_limiters/tests/test_middleware_starlette.py @@ -1,8 +1,9 @@ from time import sleep +from starlette.applications import Starlette from starlette.testclient import TestClient -from middleware_starlette import app +from middleware_starlette import app, RateLimiterMiddleware, routes def test_app_runs(): client = TestClient(app) @@ -46,4 +47,3 @@ def test_above_rate_limit(): assert successful != 0 assert rate_limited != 0 - \ No newline at end of file diff --git a/rate_limiters/uv.lock b/rate_limiters/uv.lock index 6e5ea33..402f32d 100644 --- a/rate_limiters/uv.lock +++ b/rate_limiters/uv.lock @@ -15,6 +15,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, ] +[[package]] +name = "cachetools" +version = "7.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/e2/85f227594656000ff4d8adadae91a21f536d4a84c6c716a86bd6685874be/cachetools-7.1.1.tar.gz", hash = "sha256:27bdf856d68fd3c71c26c01b5edc312124ed427524d1ddb31aa2b7746fe20d4b", size = 40202, upload-time = "2026-05-03T20:00:29.391Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/0f/f897abe4ea0a8c408ae65c8c83bffab4936ad65d6032d4fb4cd35bbdc3ee/cachetools-7.1.1-py3-none-any.whl", hash = "sha256:0335cd7a0952d2b22327441fb0628139e234c565559eeb91a8a4ac7551c5353d", size = 16775, upload-time = "2026-05-03T20:00:27.857Z" }, +] + [[package]] name = "certifi" version = "2026.4.22" @@ -161,6 +170,7 @@ name = "rate-limiters" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "cachetools" }, { name = "httpx" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -170,6 +180,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "cachetools", specifier = ">=7.1.1" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-asyncio", specifier = ">=1.3.0" },