Added some more tests + formatted code

This commit is contained in:
V 2026-05-09 21:17:23 +01:00
parent 8fccddc815
commit be71e97956
10 changed files with 159 additions and 37 deletions

View File

@ -39,6 +39,7 @@ def make_requests(number: int, limiter: FixedWindowLimiter) -> None:
print("Request limit reached!") print("Request limit reached!")
sleep(0.5) sleep(0.5)
# lim_1 = FixedWindowLimiter(window_duration=10, window_size=10) # lim_1 = FixedWindowLimiter(window_duration=10, window_size=10)
# make_requests(100, lim_1) # make_requests(100, lim_1)

View File

@ -15,19 +15,16 @@ class LeakyBucketLimiter:
self.metrics = { self.metrics = {
"processed_count": 0, "processed_count": 0,
"total_dwell_time": 0.0, # Total time tasks waited in queue "total_dwell_time": 0.0, # Total time tasks waited in queue
"dropped_tasks": 0 "dropped_tasks": 0,
} }
def _calculate_delay(self): def _calculate_delay(self):
self.delay = 1.0 / self.rate self.delay = 1.0 / self.rate
@property @property
def queue_depth(self) -> int: def queue_depth(self) -> int:
return self._queue.qsize() return self._queue.qsize()
async def start(self): async def start(self):
if not self._running: if not self._running:
self._running = True self._running = True
@ -35,14 +32,12 @@ class LeakyBucketLimiter:
self._monitor_task = asyncio.create_task(self._monitor()) self._monitor_task = asyncio.create_task(self._monitor())
print("Limiter started!") print("Limiter started!")
async def stop(self): async def stop(self):
self._running = False self._running = False
if self._worker_task: if self._worker_task:
await self._worker_task await self._worker_task
print("Limiter stopped!") print("Limiter stopped!")
async def set_rate(self, new_rate: float): async def set_rate(self, new_rate: float):
if new_rate <= 0: if new_rate <= 0:
raise ValueError("Rate must be a positive number!") raise ValueError("Rate must be a positive number!")
@ -51,12 +46,14 @@ class LeakyBucketLimiter:
self.rate = new_rate self.rate = new_rate
self._calculate_delay() self._calculate_delay()
async def _monitor(self): async def _monitor(self):
while self._running: while self._running:
await asyncio.sleep(2) await asyncio.sleep(2)
avg_dwell = (self.metrics["total_dwell_time"] / self.metrics["processed_count"] avg_dwell = (
if self.metrics["processed_count"] > 0 else 0) self.metrics["total_dwell_time"] / self.metrics["processed_count"]
if self.metrics["processed_count"] > 0
else 0
)
print("\n--- [METRICS REPORT] ---") print("\n--- [METRICS REPORT] ---")
print(f" Queue Depth: {self.queue_depth} tasks") print(f" Queue Depth: {self.queue_depth} tasks")
@ -65,7 +62,6 @@ class LeakyBucketLimiter:
print(f" Drops: {self.metrics['dropped_tasks']}") print(f" Drops: {self.metrics['dropped_tasks']}")
print("-------------------------\n") print("-------------------------\n")
async def _worker(self): async def _worker(self):
while self._running or not self._queue.empty(): while self._running or not self._queue.empty():
try: try:
@ -132,7 +128,6 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
end_time = time() end_time = time()
success = 0 success = 0
failures = 0 failures = 0
@ -149,7 +144,9 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
f"Total execution time: {execution_time:.2f}\nEffective rate: {effective_rate:.2f}" f"Total execution time: {execution_time:.2f}\nEffective rate: {effective_rate:.2f}"
) )
print(f"Operation complete with {success} successful requests and {failures} failures.") print(
f"Operation complete with {success} successful requests and {failures} failures."
)
await limiter.stop() await limiter.stop()

View File

@ -0,0 +1,4 @@
#!/bin/bash
ruff format .
ruff check --select I --fix .

View File

@ -2,13 +2,15 @@ from typing import Optional
from cachetools import TTLCache from cachetools import TTLCache
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.middleware import Middleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
from starlette.routing import Route from starlette.routing import Route
from starlette.middleware import Middleware
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
class RateLimiterMiddleware: class RateLimiterMiddleware:
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
@ -18,13 +20,21 @@ class RateLimiterMiddleware:
if scope["type"] != "http": if scope["type"] != "http":
await self.app(scope, receive, send) await self.app(scope, receive, send)
# Check for "X-Forwarded-For" header - as passed by reverse proxies.
# Need it for testing
headers = Headers(raw=scope["headers"])
if (client_ip := headers.get("X-Forwarded-For")) is None:
client_data = scope["client"] client_data = scope["client"]
client_ip = client_data[0] if client_data else "UNKNOWN" client_ip = client_data[0] if client_data else "UNKNOWN"
limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip) limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
if not limiter: if not limiter:
limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10)) limiter = self.clients.setdefault(
client_ip, AsyncTokenBucketLimiter(10, 10)
)
try: try:
async with limiter: async with limiter:
@ -34,7 +44,7 @@ class RateLimiterMiddleware:
response = Response( response = Response(
content=f"Rate limit exceeded! Retry after: {retry_after}", content=f"Rate limit exceeded! Retry after: {retry_after}",
headers={"Retry-After": retry_after}, headers={"Retry-After": retry_after},
status_code=429 status_code=429,
) )
return await response(scope, receive, send) return await response(scope, receive, send)

View File

@ -12,3 +12,6 @@ dependencies = [
"starlette>=1.0.0", "starlette>=1.0.0",
"uvicorn>=0.46.0", "uvicorn>=0.46.0",
] ]
[tool.isort]
profile = "ruff"

View File

@ -2,6 +2,7 @@ import pytest
from fixed_window import FixedWindowLimiter from fixed_window import FixedWindowLimiter
def test_limit_respected(): def test_limit_respected():
limiter = FixedWindowLimiter(10, 10) limiter = FixedWindowLimiter(10, 10)
try: try:
@ -11,6 +12,7 @@ def test_limit_respected():
except RuntimeError as e: except RuntimeError as e:
assert False, f"Limiter raised an exception -> {e}" assert False, f"Limiter raised an exception -> {e}"
def test_limit_enforced(): def test_limit_enforced():
limiter = FixedWindowLimiter(10, 10) limiter = FixedWindowLimiter(10, 10)
with pytest.raises(RuntimeError) as e: with pytest.raises(RuntimeError) as e:

View File

@ -5,12 +5,12 @@ import pytest
from leaky_bucket import LeakyBucketLimiter from leaky_bucket import LeakyBucketLimiter
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_duration() -> None: async def test_duration() -> None:
async def test_function(): async def test_function():
pass pass
limiter = LeakyBucketLimiter(10, 100) limiter = LeakyBucketLimiter(10, 100)
await limiter.start() await limiter.start()

View File

@ -1,10 +1,15 @@
from time import sleep from time import sleep
from cachetools import TTLCache
from starlette.applications import Starlette
from starlette.testclient import TestClient from starlette.testclient import TestClient
from middleware_starlette import app from middleware_starlette import RateLimiterMiddleware, create_app, routes
def test_app_runs(): def test_app_runs():
app = create_app()
client = TestClient(app) client = TestClient(app)
response = client.get("/") response = client.get("/")
@ -14,6 +19,8 @@ def test_app_runs():
def test_below_rate_limit(): def test_below_rate_limit():
app = create_app()
client = TestClient(app) client = TestClient(app)
# Limit is 10/s # Limit is 10/s
@ -26,6 +33,8 @@ def test_below_rate_limit():
def test_above_rate_limit(): def test_above_rate_limit():
app = create_app()
client = TestClient(app) client = TestClient(app)
successful = 0 successful = 0
@ -46,3 +55,96 @@ def test_above_rate_limit():
assert successful != 0 assert successful != 0
assert rate_limited != 0 assert rate_limited != 0
def find_middleware(stack, middleware_cls):
current = stack
while hasattr(current, "app"):
if isinstance(current, middleware_cls):
return current
current = current.app
return None
def test_rate_limiter_cache():
app = create_app()
client = TestClient(app)
client.get("/")
middleware = find_middleware(
app.middleware_stack,
RateLimiterMiddleware,
)
assert middleware is not None
assert "testclient" in middleware.clients
assert len(middleware.clients) == 1
def test_cache_ttl():
# Setup app with a short TTL
class ShortTTLLimiterMiddleware(RateLimiterMiddleware):
def __init__(self, app):
super().__init__(app)
self.clients = TTLCache(maxsize=10, ttl=2) # Clear after 2 seconds
app = Starlette(routes=routes)
app.add_middleware(ShortTTLLimiterMiddleware)
client = TestClient(app)
client.get("/")
middleware = find_middleware(
app.middleware_stack,
RateLimiterMiddleware,
)
assert middleware is not None
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
assert middleware.clients.get("1.1.1.1") is not None
sleep(2)
assert middleware.clients.get("1.1.1.1") is None
def test_cache_eviction():
# Setup app with a tiny cache for testing
class TinyLimiterMiddleware(RateLimiterMiddleware):
def __init__(self, app):
super().__init__(app)
self.clients = TTLCache(maxsize=2, ttl=10) # Only space for 2 IPs
app = Starlette(routes=routes)
app.add_middleware(TinyLimiterMiddleware)
client = TestClient(app)
client.get("/")
middleware = find_middleware(
app.middleware_stack,
RateLimiterMiddleware,
)
assert middleware is not None
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
client.get("/", headers={"X-Forwarded-For": "2.2.2.2"})
assert middleware.clients.get("1.1.1.1") is not None
assert middleware.clients.get("2.2.2.2") is not None
assert middleware.clients.get("3.3.3.3") is None
client.get("/", headers={"X-Forwarded-For": "3.3.3.3"})
assert middleware.clients.get("1.1.1.1") is None
assert middleware.clients.get("2.2.2.2") is not None
assert middleware.clients.get("3.3.3.3") is not None

View File

@ -1,6 +1,9 @@
import pytest
from time import sleep from time import sleep
from token_bucket import TokenBucketLimiter, RateLimitExceeded
import pytest
from token_bucket import RateLimitExceeded, TokenBucketLimiter
def test_limit_respected() -> None: def test_limit_respected() -> None:
limiter = TokenBucketLimiter(10, 10) limiter = TokenBucketLimiter(10, 10)
@ -22,4 +25,3 @@ def test_limit_exceeded() -> None:
continue continue
assert "Rate exceeded" in str(e.value) assert "Rate exceeded" in str(e.value)

View File

@ -1,6 +1,6 @@
import asyncio
import threading import threading
from time import sleep, time from time import sleep, time
import asyncio
class RateLimiterException(Exception): class RateLimiterException(Exception):
@ -42,6 +42,7 @@ class TokenBucketLimiter:
def __exit__(self, exc_type, exc_value, exc_traceback) -> bool: def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
return False return False
class AsyncTokenBucketLimiter: class AsyncTokenBucketLimiter:
def __init__(self, capacity: int, refill_rate: int): def __init__(self, capacity: int, refill_rate: int):
self.capacity = float(capacity) self.capacity = float(capacity)