Added some more tests + formatted code

This commit is contained in:
V 2026-05-09 21:17:23 +01:00
parent 8fccddc815
commit be71e97956
10 changed files with 159 additions and 37 deletions

View File

@ -39,6 +39,7 @@ def make_requests(number: int, limiter: FixedWindowLimiter) -> None:
print("Request limit reached!")
sleep(0.5)
# lim_1 = FixedWindowLimiter(window_duration=10, window_size=10)
# make_requests(100, lim_1)

View File

@ -15,19 +15,16 @@ class LeakyBucketLimiter:
self.metrics = {
"processed_count": 0,
"total_dwell_time": 0.0, # Total time tasks waited in queue
"dropped_tasks": 0
"dropped_tasks": 0,
}
def _calculate_delay(self):
self.delay = 1.0 / self.rate
@property
def queue_depth(self) -> int:
return self._queue.qsize()
async def start(self):
if not self._running:
self._running = True
@ -35,14 +32,12 @@ class LeakyBucketLimiter:
self._monitor_task = asyncio.create_task(self._monitor())
print("Limiter started!")
async def stop(self):
self._running = False
if self._worker_task:
await self._worker_task
print("Limiter stopped!")
async def set_rate(self, new_rate: float):
if new_rate <= 0:
raise ValueError("Rate must be a positive number!")
@ -51,12 +46,14 @@ class LeakyBucketLimiter:
self.rate = new_rate
self._calculate_delay()
async def _monitor(self):
while self._running:
await asyncio.sleep(2)
avg_dwell = (self.metrics["total_dwell_time"] / self.metrics["processed_count"]
if self.metrics["processed_count"] > 0 else 0)
avg_dwell = (
self.metrics["total_dwell_time"] / self.metrics["processed_count"]
if self.metrics["processed_count"] > 0
else 0
)
print("\n--- [METRICS REPORT] ---")
print(f" Queue Depth: {self.queue_depth} tasks")
@ -65,7 +62,6 @@ class LeakyBucketLimiter:
print(f" Drops: {self.metrics['dropped_tasks']}")
print("-------------------------\n")
async def _worker(self):
while self._running or not self._queue.empty():
try:
@ -132,7 +128,6 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
end_time = time()
success = 0
failures = 0
@ -149,7 +144,9 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
f"Total execution time: {execution_time:.2f}\nEffective rate: {effective_rate:.2f}"
)
print(f"Operation complete with {success} successful requests and {failures} failures.")
print(
f"Operation complete with {success} successful requests and {failures} failures."
)
await limiter.stop()

View File

@ -0,0 +1,4 @@
#!/bin/bash
ruff format .
ruff check --select I --fix .

View File

@ -2,13 +2,15 @@ from typing import Optional
from cachetools import TTLCache
from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
from starlette.middleware import Middleware
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
class RateLimiterMiddleware:
def __init__(self, app):
self.app = app
@ -18,13 +20,21 @@ class RateLimiterMiddleware:
if scope["type"] != "http":
await self.app(scope, receive, send)
# Check for "X-Forwarded-For" header - as passed by reverse proxies.
# Need it for testing
headers = Headers(raw=scope["headers"])
if (client_ip := headers.get("X-Forwarded-For")) is None:
client_data = scope["client"]
client_ip = client_data[0] if client_data else "UNKNOWN"
limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
if not limiter:
limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10))
limiter = self.clients.setdefault(
client_ip, AsyncTokenBucketLimiter(10, 10)
)
try:
async with limiter:
@ -34,7 +44,7 @@ class RateLimiterMiddleware:
response = Response(
content=f"Rate limit exceeded! Retry after: {retry_after}",
headers={"Retry-After": retry_after},
status_code=429
status_code=429,
)
return await response(scope, receive, send)

View File

@ -12,3 +12,6 @@ dependencies = [
"starlette>=1.0.0",
"uvicorn>=0.46.0",
]
[tool.isort]
profile = "ruff"

View File

@ -2,6 +2,7 @@ import pytest
from fixed_window import FixedWindowLimiter
def test_limit_respected():
limiter = FixedWindowLimiter(10, 10)
try:
@ -11,6 +12,7 @@ def test_limit_respected():
except RuntimeError as e:
assert False, f"Limiter raised an exception -> {e}"
def test_limit_enforced():
limiter = FixedWindowLimiter(10, 10)
with pytest.raises(RuntimeError) as e:

View File

@ -5,12 +5,12 @@ import pytest
from leaky_bucket import LeakyBucketLimiter
@pytest.mark.asyncio
async def test_duration() -> None:
async def test_function():
pass
limiter = LeakyBucketLimiter(10, 100)
await limiter.start()

View File

@ -1,10 +1,15 @@
from time import sleep
from cachetools import TTLCache
from starlette.applications import Starlette
from starlette.testclient import TestClient
from middleware_starlette import app
from middleware_starlette import RateLimiterMiddleware, create_app, routes
def test_app_runs():
app = create_app()
client = TestClient(app)
response = client.get("/")
@ -14,6 +19,8 @@ def test_app_runs():
def test_below_rate_limit():
app = create_app()
client = TestClient(app)
# Limit is 10/s
@ -26,6 +33,8 @@ def test_below_rate_limit():
def test_above_rate_limit():
app = create_app()
client = TestClient(app)
successful = 0
@ -46,3 +55,96 @@ def test_above_rate_limit():
assert successful != 0
assert rate_limited != 0
def find_middleware(stack, middleware_cls):
current = stack
while hasattr(current, "app"):
if isinstance(current, middleware_cls):
return current
current = current.app
return None
def test_rate_limiter_cache():
app = create_app()
client = TestClient(app)
client.get("/")
middleware = find_middleware(
app.middleware_stack,
RateLimiterMiddleware,
)
assert middleware is not None
assert "testclient" in middleware.clients
assert len(middleware.clients) == 1
def test_cache_ttl():
# Setup app with a short TTL
class ShortTTLLimiterMiddleware(RateLimiterMiddleware):
def __init__(self, app):
super().__init__(app)
self.clients = TTLCache(maxsize=10, ttl=2) # Clear after 2 seconds
app = Starlette(routes=routes)
app.add_middleware(ShortTTLLimiterMiddleware)
client = TestClient(app)
client.get("/")
middleware = find_middleware(
app.middleware_stack,
RateLimiterMiddleware,
)
assert middleware is not None
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
assert middleware.clients.get("1.1.1.1") is not None
sleep(2)
assert middleware.clients.get("1.1.1.1") is None
def test_cache_eviction():
# Setup app with a tiny cache for testing
class TinyLimiterMiddleware(RateLimiterMiddleware):
def __init__(self, app):
super().__init__(app)
self.clients = TTLCache(maxsize=2, ttl=10) # Only space for 2 IPs
app = Starlette(routes=routes)
app.add_middleware(TinyLimiterMiddleware)
client = TestClient(app)
client.get("/")
middleware = find_middleware(
app.middleware_stack,
RateLimiterMiddleware,
)
assert middleware is not None
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
client.get("/", headers={"X-Forwarded-For": "2.2.2.2"})
assert middleware.clients.get("1.1.1.1") is not None
assert middleware.clients.get("2.2.2.2") is not None
assert middleware.clients.get("3.3.3.3") is None
client.get("/", headers={"X-Forwarded-For": "3.3.3.3"})
assert middleware.clients.get("1.1.1.1") is None
assert middleware.clients.get("2.2.2.2") is not None
assert middleware.clients.get("3.3.3.3") is not None

View File

@ -1,6 +1,9 @@
import pytest
from time import sleep
from token_bucket import TokenBucketLimiter, RateLimitExceeded
import pytest
from token_bucket import RateLimitExceeded, TokenBucketLimiter
def test_limit_respected() -> None:
limiter = TokenBucketLimiter(10, 10)
@ -22,4 +25,3 @@ def test_limit_exceeded() -> None:
continue
assert "Rate exceeded" in str(e.value)

View File

@ -1,6 +1,6 @@
import asyncio
import threading
from time import sleep, time
import asyncio
class RateLimiterException(Exception):
@ -42,6 +42,7 @@ class TokenBucketLimiter:
def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
return False
class AsyncTokenBucketLimiter:
def __init__(self, capacity: int, refill_rate: int):
self.capacity = float(capacity)