Switched to TTL cache for client rate limiter tracking
This commit is contained in:
parent
d536dde6c9
commit
72348ca5fe
@ -1,3 +1,6 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
@ -6,11 +9,10 @@ from starlette.middleware import Middleware
|
|||||||
|
|
||||||
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
|
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
|
||||||
|
|
||||||
|
|
||||||
class RateLimiterMiddleware:
|
class RateLimiterMiddleware:
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.clients = {}
|
self.clients = TTLCache(maxsize=1000, ttl=10)
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send):
|
async def __call__(self, scope, receive, send):
|
||||||
if scope["type"] != "http":
|
if scope["type"] != "http":
|
||||||
@ -19,7 +21,11 @@ class RateLimiterMiddleware:
|
|||||||
client_data = scope["client"]
|
client_data = scope["client"]
|
||||||
client_ip = client_data[0] if client_data else "UNKNOWN"
|
client_ip = client_data[0] if client_data else "UNKNOWN"
|
||||||
|
|
||||||
limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10))
|
limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
|
||||||
|
|
||||||
|
if not limiter:
|
||||||
|
limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10))
|
||||||
|
self.clients[client_ip] = limiter
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with limiter:
|
async with limiter:
|
||||||
|
|||||||
@ -5,6 +5,7 @@ description = "Add your description here"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"cachetools>=7.1.1",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"pytest>=9.0.3",
|
"pytest>=9.0.3",
|
||||||
"pytest-asyncio>=1.3.0",
|
"pytest-asyncio>=1.3.0",
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
|
from starlette.applications import Starlette
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from middleware_starlette import app
|
from middleware_starlette import app, RateLimiterMiddleware, routes
|
||||||
|
|
||||||
def test_app_runs():
|
def test_app_runs():
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
@ -46,4 +47,3 @@ def test_above_rate_limit():
|
|||||||
assert successful != 0
|
assert successful != 0
|
||||||
assert rate_limited != 0
|
assert rate_limited != 0
|
||||||
|
|
||||||
|
|
||||||
@ -15,6 +15,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" },
|
{ url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cachetools"
|
||||||
|
version = "7.1.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/ff/e2/85f227594656000ff4d8adadae91a21f536d4a84c6c716a86bd6685874be/cachetools-7.1.1.tar.gz", hash = "sha256:27bdf856d68fd3c71c26c01b5edc312124ed427524d1ddb31aa2b7746fe20d4b", size = 40202, upload-time = "2026-05-03T20:00:29.391Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bf/0f/f897abe4ea0a8c408ae65c8c83bffab4936ad65d6032d4fb4cd35bbdc3ee/cachetools-7.1.1-py3-none-any.whl", hash = "sha256:0335cd7a0952d2b22327441fb0628139e234c565559eeb91a8a4ac7551c5353d", size = 16775, upload-time = "2026-05-03T20:00:27.857Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "certifi"
|
name = "certifi"
|
||||||
version = "2026.4.22"
|
version = "2026.4.22"
|
||||||
@ -161,6 +170,7 @@ name = "rate-limiters"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "cachetools" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
@ -170,6 +180,7 @@ dependencies = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "cachetools", specifier = ">=7.1.1" },
|
||||||
{ name = "httpx", specifier = ">=0.28.1" },
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
{ name = "pytest", specifier = ">=9.0.3" },
|
{ name = "pytest", specifier = ">=9.0.3" },
|
||||||
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
{ name = "pytest-asyncio", specifier = ">=1.3.0" },
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user