Added some more tests + formatted code
This commit is contained in:
parent
8fccddc815
commit
be71e97956
@ -39,6 +39,7 @@ def make_requests(number: int, limiter: FixedWindowLimiter) -> None:
|
||||
print("Request limit reached!")
|
||||
sleep(0.5)
|
||||
|
||||
|
||||
# lim_1 = FixedWindowLimiter(window_duration=10, window_size=10)
|
||||
|
||||
# make_requests(100, lim_1)
|
||||
|
||||
@ -14,27 +14,23 @@ class LeakyBucketLimiter:
|
||||
|
||||
self.metrics = {
|
||||
"processed_count": 0,
|
||||
"total_dwell_time": 0.0, # Total time tasks waited in queue
|
||||
"dropped_tasks": 0
|
||||
"total_dwell_time": 0.0, # Total time tasks waited in queue
|
||||
"dropped_tasks": 0,
|
||||
}
|
||||
|
||||
|
||||
def _calculate_delay(self):
|
||||
self.delay = 1.0 / self.rate
|
||||
|
||||
|
||||
@property
|
||||
def queue_depth(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
|
||||
|
||||
async def start(self):
|
||||
if not self._running:
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker())
|
||||
self._monitor_task = asyncio.create_task(self._monitor())
|
||||
print("Limiter started!")
|
||||
|
||||
|
||||
async def stop(self):
|
||||
self._running = False
|
||||
@ -42,22 +38,23 @@ class LeakyBucketLimiter:
|
||||
await self._worker_task
|
||||
print("Limiter stopped!")
|
||||
|
||||
|
||||
async def set_rate(self, new_rate: float):
|
||||
if new_rate <= 0:
|
||||
raise ValueError("Rate must be a positive number!")
|
||||
|
||||
|
||||
print(f"Changing rate from {self.rate} to {new_rate}...")
|
||||
self.rate = new_rate
|
||||
self._calculate_delay()
|
||||
|
||||
|
||||
async def _monitor(self):
|
||||
while self._running:
|
||||
await asyncio.sleep(2)
|
||||
avg_dwell = (self.metrics["total_dwell_time"] / self.metrics["processed_count"]
|
||||
if self.metrics["processed_count"] > 0 else 0)
|
||||
|
||||
avg_dwell = (
|
||||
self.metrics["total_dwell_time"] / self.metrics["processed_count"]
|
||||
if self.metrics["processed_count"] > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
print("\n--- [METRICS REPORT] ---")
|
||||
print(f" Queue Depth: {self.queue_depth} tasks")
|
||||
print(f" Throughput: {self.metrics['processed_count']} tasks total")
|
||||
@ -65,7 +62,6 @@ class LeakyBucketLimiter:
|
||||
print(f" Drops: {self.metrics['dropped_tasks']}")
|
||||
print("-------------------------\n")
|
||||
|
||||
|
||||
async def _worker(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
@ -86,7 +82,7 @@ class LeakyBucketLimiter:
|
||||
finally:
|
||||
self.metrics["processed_count"] += 1
|
||||
self._queue.task_done()
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
@ -95,10 +91,10 @@ class LeakyBucketLimiter:
|
||||
async def execute(self, func, *args, **kwargs):
|
||||
if not self._running:
|
||||
raise RuntimeError("Limiter is not running! Call .start() first.")
|
||||
|
||||
|
||||
# if self._queue.full():
|
||||
# raise RuntimeError(f"Queue is full! Retry in {self.delay:.2f} seconds.")
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
|
||||
@ -132,7 +128,6 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
|
||||
|
||||
end_time = time()
|
||||
|
||||
|
||||
success = 0
|
||||
failures = 0
|
||||
|
||||
@ -148,11 +143,13 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
|
||||
print(
|
||||
f"Total execution time: {execution_time:.2f}\nEffective rate: {effective_rate:.2f}"
|
||||
)
|
||||
|
||||
print(f"Operation complete with {success} successful requests and {failures} failures.")
|
||||
|
||||
|
||||
print(
|
||||
f"Operation complete with {success} successful requests and {failures} failures."
|
||||
)
|
||||
|
||||
await limiter.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(make_requests(1000, 20, 60))
|
||||
asyncio.run(make_requests(1000, 20, 60))
|
||||
|
||||
4
rate_limiters/lint_and_format.sh
Executable file
4
rate_limiters/lint_and_format.sh
Executable file
@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
ruff format .
|
||||
ruff check --select I --fix .
|
||||
@ -2,13 +2,15 @@ from typing import Optional
|
||||
|
||||
from cachetools import TTLCache
|
||||
from starlette.applications import Starlette
|
||||
from starlette.datastructures import Headers
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Route
|
||||
from starlette.middleware import Middleware
|
||||
|
||||
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
|
||||
|
||||
|
||||
class RateLimiterMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
@ -18,13 +20,21 @@ class RateLimiterMiddleware:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
client_data = scope["client"]
|
||||
client_ip = client_data[0] if client_data else "UNKNOWN"
|
||||
# Check for "X-Forwarded-For" header - as passed by reverse proxies.
|
||||
# Need it for testing
|
||||
|
||||
headers = Headers(raw=scope["headers"])
|
||||
|
||||
if (client_ip := headers.get("X-Forwarded-For")) is None:
|
||||
client_data = scope["client"]
|
||||
client_ip = client_data[0] if client_data else "UNKNOWN"
|
||||
|
||||
limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
|
||||
|
||||
if not limiter:
|
||||
limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10))
|
||||
limiter = self.clients.setdefault(
|
||||
client_ip, AsyncTokenBucketLimiter(10, 10)
|
||||
)
|
||||
|
||||
try:
|
||||
async with limiter:
|
||||
@ -34,7 +44,7 @@ class RateLimiterMiddleware:
|
||||
response = Response(
|
||||
content=f"Rate limit exceeded! Retry after: {retry_after}",
|
||||
headers={"Retry-After": retry_after},
|
||||
status_code=429
|
||||
status_code=429,
|
||||
)
|
||||
return await response(scope, receive, send)
|
||||
|
||||
|
||||
@ -12,3 +12,6 @@ dependencies = [
|
||||
"starlette>=1.0.0",
|
||||
"uvicorn>=0.46.0",
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
profile = "ruff"
|
||||
@ -2,6 +2,7 @@ import pytest
|
||||
|
||||
from fixed_window import FixedWindowLimiter
|
||||
|
||||
|
||||
def test_limit_respected():
|
||||
limiter = FixedWindowLimiter(10, 10)
|
||||
try:
|
||||
@ -10,7 +11,8 @@ def test_limit_respected():
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
assert False, f"Limiter raised an exception -> {e}"
|
||||
|
||||
|
||||
|
||||
def test_limit_enforced():
|
||||
limiter = FixedWindowLimiter(10, 10)
|
||||
with pytest.raises(RuntimeError) as e:
|
||||
|
||||
@ -5,11 +5,11 @@ import pytest
|
||||
|
||||
from leaky_bucket import LeakyBucketLimiter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duration() -> None:
|
||||
async def test_function():
|
||||
pass
|
||||
|
||||
|
||||
limiter = LeakyBucketLimiter(10, 100)
|
||||
await limiter.start()
|
||||
@ -26,4 +26,4 @@ async def test_duration() -> None:
|
||||
assert duration < 2.1
|
||||
|
||||
for item in results:
|
||||
assert isinstance(item, Exception) is False
|
||||
assert isinstance(item, Exception) is False
|
||||
|
||||
@ -1,10 +1,15 @@
|
||||
from time import sleep
|
||||
|
||||
from cachetools import TTLCache
|
||||
from starlette.applications import Starlette
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from middleware_starlette import app
|
||||
from middleware_starlette import RateLimiterMiddleware, create_app, routes
|
||||
|
||||
|
||||
def test_app_runs():
|
||||
app = create_app()
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/")
|
||||
@ -14,6 +19,8 @@ def test_app_runs():
|
||||
|
||||
|
||||
def test_below_rate_limit():
|
||||
app = create_app()
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Limit is 10/s
|
||||
@ -26,6 +33,8 @@ def test_below_rate_limit():
|
||||
|
||||
|
||||
def test_above_rate_limit():
|
||||
app = create_app()
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
successful = 0
|
||||
@ -46,3 +55,96 @@ def test_above_rate_limit():
|
||||
assert successful != 0
|
||||
assert rate_limited != 0
|
||||
|
||||
|
||||
def find_middleware(stack, middleware_cls):
|
||||
current = stack
|
||||
|
||||
while hasattr(current, "app"):
|
||||
if isinstance(current, middleware_cls):
|
||||
return current
|
||||
current = current.app
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def test_rate_limiter_cache():
|
||||
app = create_app()
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
client.get("/")
|
||||
|
||||
middleware = find_middleware(
|
||||
app.middleware_stack,
|
||||
RateLimiterMiddleware,
|
||||
)
|
||||
|
||||
assert middleware is not None
|
||||
|
||||
assert "testclient" in middleware.clients
|
||||
assert len(middleware.clients) == 1
|
||||
|
||||
|
||||
def test_cache_ttl():
|
||||
# Setup app with a short TTL
|
||||
class ShortTTLLimiterMiddleware(RateLimiterMiddleware):
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
self.clients = TTLCache(maxsize=10, ttl=2) # Clear after 2 seconds
|
||||
|
||||
app = Starlette(routes=routes)
|
||||
app.add_middleware(ShortTTLLimiterMiddleware)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
client.get("/")
|
||||
|
||||
middleware = find_middleware(
|
||||
app.middleware_stack,
|
||||
RateLimiterMiddleware,
|
||||
)
|
||||
|
||||
assert middleware is not None
|
||||
|
||||
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
|
||||
|
||||
assert middleware.clients.get("1.1.1.1") is not None
|
||||
|
||||
sleep(2)
|
||||
|
||||
assert middleware.clients.get("1.1.1.1") is None
|
||||
|
||||
|
||||
def test_cache_eviction():
|
||||
# Setup app with a tiny cache for testing
|
||||
class TinyLimiterMiddleware(RateLimiterMiddleware):
|
||||
def __init__(self, app):
|
||||
super().__init__(app)
|
||||
self.clients = TTLCache(maxsize=2, ttl=10) # Only space for 2 IPs
|
||||
|
||||
app = Starlette(routes=routes)
|
||||
app.add_middleware(TinyLimiterMiddleware)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
client.get("/")
|
||||
|
||||
middleware = find_middleware(
|
||||
app.middleware_stack,
|
||||
RateLimiterMiddleware,
|
||||
)
|
||||
|
||||
assert middleware is not None
|
||||
|
||||
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
|
||||
client.get("/", headers={"X-Forwarded-For": "2.2.2.2"})
|
||||
|
||||
assert middleware.clients.get("1.1.1.1") is not None
|
||||
assert middleware.clients.get("2.2.2.2") is not None
|
||||
assert middleware.clients.get("3.3.3.3") is None
|
||||
|
||||
client.get("/", headers={"X-Forwarded-For": "3.3.3.3"})
|
||||
|
||||
assert middleware.clients.get("1.1.1.1") is None
|
||||
assert middleware.clients.get("2.2.2.2") is not None
|
||||
assert middleware.clients.get("3.3.3.3") is not None
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import pytest
|
||||
from time import sleep
|
||||
from token_bucket import TokenBucketLimiter, RateLimitExceeded
|
||||
|
||||
import pytest
|
||||
|
||||
from token_bucket import RateLimitExceeded, TokenBucketLimiter
|
||||
|
||||
|
||||
def test_limit_respected() -> None:
|
||||
limiter = TokenBucketLimiter(10, 10)
|
||||
@ -20,6 +23,5 @@ def test_limit_exceeded() -> None:
|
||||
for i in range(20):
|
||||
with limiter:
|
||||
continue
|
||||
|
||||
assert "Rate exceeded" in str(e.value)
|
||||
|
||||
assert "Rate exceeded" in str(e.value)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from time import sleep, time
|
||||
import asyncio
|
||||
|
||||
|
||||
class RateLimiterException(Exception):
|
||||
@ -41,7 +41,8 @@ class TokenBucketLimiter:
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
class AsyncTokenBucketLimiter:
|
||||
def __init__(self, capacity: int, refill_rate: int):
|
||||
self.capacity = float(capacity)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user