Added some more tests + formatted code
This commit is contained in:
parent
8fccddc815
commit
be71e97956
@ -39,6 +39,7 @@ def make_requests(number: int, limiter: FixedWindowLimiter) -> None:
|
|||||||
print("Request limit reached!")
|
print("Request limit reached!")
|
||||||
sleep(0.5)
|
sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
# lim_1 = FixedWindowLimiter(window_duration=10, window_size=10)
|
# lim_1 = FixedWindowLimiter(window_duration=10, window_size=10)
|
||||||
|
|
||||||
# make_requests(100, lim_1)
|
# make_requests(100, lim_1)
|
||||||
|
|||||||
@ -14,20 +14,17 @@ class LeakyBucketLimiter:
|
|||||||
|
|
||||||
self.metrics = {
|
self.metrics = {
|
||||||
"processed_count": 0,
|
"processed_count": 0,
|
||||||
"total_dwell_time": 0.0, # Total time tasks waited in queue
|
"total_dwell_time": 0.0, # Total time tasks waited in queue
|
||||||
"dropped_tasks": 0
|
"dropped_tasks": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _calculate_delay(self):
|
def _calculate_delay(self):
|
||||||
self.delay = 1.0 / self.rate
|
self.delay = 1.0 / self.rate
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def queue_depth(self) -> int:
|
def queue_depth(self) -> int:
|
||||||
return self._queue.qsize()
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
if not self._running:
|
if not self._running:
|
||||||
self._running = True
|
self._running = True
|
||||||
@ -35,14 +32,12 @@ class LeakyBucketLimiter:
|
|||||||
self._monitor_task = asyncio.create_task(self._monitor())
|
self._monitor_task = asyncio.create_task(self._monitor())
|
||||||
print("Limiter started!")
|
print("Limiter started!")
|
||||||
|
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self._running = False
|
self._running = False
|
||||||
if self._worker_task:
|
if self._worker_task:
|
||||||
await self._worker_task
|
await self._worker_task
|
||||||
print("Limiter stopped!")
|
print("Limiter stopped!")
|
||||||
|
|
||||||
|
|
||||||
async def set_rate(self, new_rate: float):
|
async def set_rate(self, new_rate: float):
|
||||||
if new_rate <= 0:
|
if new_rate <= 0:
|
||||||
raise ValueError("Rate must be a positive number!")
|
raise ValueError("Rate must be a positive number!")
|
||||||
@ -51,12 +46,14 @@ class LeakyBucketLimiter:
|
|||||||
self.rate = new_rate
|
self.rate = new_rate
|
||||||
self._calculate_delay()
|
self._calculate_delay()
|
||||||
|
|
||||||
|
|
||||||
async def _monitor(self):
|
async def _monitor(self):
|
||||||
while self._running:
|
while self._running:
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
avg_dwell = (self.metrics["total_dwell_time"] / self.metrics["processed_count"]
|
avg_dwell = (
|
||||||
if self.metrics["processed_count"] > 0 else 0)
|
self.metrics["total_dwell_time"] / self.metrics["processed_count"]
|
||||||
|
if self.metrics["processed_count"] > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
print("\n--- [METRICS REPORT] ---")
|
print("\n--- [METRICS REPORT] ---")
|
||||||
print(f" Queue Depth: {self.queue_depth} tasks")
|
print(f" Queue Depth: {self.queue_depth} tasks")
|
||||||
@ -65,7 +62,6 @@ class LeakyBucketLimiter:
|
|||||||
print(f" Drops: {self.metrics['dropped_tasks']}")
|
print(f" Drops: {self.metrics['dropped_tasks']}")
|
||||||
print("-------------------------\n")
|
print("-------------------------\n")
|
||||||
|
|
||||||
|
|
||||||
async def _worker(self):
|
async def _worker(self):
|
||||||
while self._running or not self._queue.empty():
|
while self._running or not self._queue.empty():
|
||||||
try:
|
try:
|
||||||
@ -132,7 +128,6 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
|
|||||||
|
|
||||||
end_time = time()
|
end_time = time()
|
||||||
|
|
||||||
|
|
||||||
success = 0
|
success = 0
|
||||||
failures = 0
|
failures = 0
|
||||||
|
|
||||||
@ -149,7 +144,9 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size
|
|||||||
f"Total execution time: {execution_time:.2f}\nEffective rate: {effective_rate:.2f}"
|
f"Total execution time: {execution_time:.2f}\nEffective rate: {effective_rate:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Operation complete with {success} successful requests and {failures} failures.")
|
print(
|
||||||
|
f"Operation complete with {success} successful requests and {failures} failures."
|
||||||
|
)
|
||||||
|
|
||||||
await limiter.stop()
|
await limiter.stop()
|
||||||
|
|
||||||
|
|||||||
4
rate_limiters/lint_and_format.sh
Executable file
4
rate_limiters/lint_and_format.sh
Executable file
@ -0,0 +1,4 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
ruff format .
|
||||||
|
ruff check --select I --fix .
|
||||||
@ -2,13 +2,15 @@ from typing import Optional
|
|||||||
|
|
||||||
from cachetools import TTLCache
|
from cachetools import TTLCache
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
|
from starlette.datastructures import Headers
|
||||||
|
from starlette.middleware import Middleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
from starlette.middleware import Middleware
|
|
||||||
|
|
||||||
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
|
from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded
|
||||||
|
|
||||||
|
|
||||||
class RateLimiterMiddleware:
|
class RateLimiterMiddleware:
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.app = app
|
self.app = app
|
||||||
@ -18,13 +20,21 @@ class RateLimiterMiddleware:
|
|||||||
if scope["type"] != "http":
|
if scope["type"] != "http":
|
||||||
await self.app(scope, receive, send)
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
client_data = scope["client"]
|
# Check for "X-Forwarded-For" header - as passed by reverse proxies.
|
||||||
client_ip = client_data[0] if client_data else "UNKNOWN"
|
# Need it for testing
|
||||||
|
|
||||||
|
headers = Headers(raw=scope["headers"])
|
||||||
|
|
||||||
|
if (client_ip := headers.get("X-Forwarded-For")) is None:
|
||||||
|
client_data = scope["client"]
|
||||||
|
client_ip = client_data[0] if client_data else "UNKNOWN"
|
||||||
|
|
||||||
limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
|
limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip)
|
||||||
|
|
||||||
if not limiter:
|
if not limiter:
|
||||||
limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10))
|
limiter = self.clients.setdefault(
|
||||||
|
client_ip, AsyncTokenBucketLimiter(10, 10)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with limiter:
|
async with limiter:
|
||||||
@ -34,7 +44,7 @@ class RateLimiterMiddleware:
|
|||||||
response = Response(
|
response = Response(
|
||||||
content=f"Rate limit exceeded! Retry after: {retry_after}",
|
content=f"Rate limit exceeded! Retry after: {retry_after}",
|
||||||
headers={"Retry-After": retry_after},
|
headers={"Retry-After": retry_after},
|
||||||
status_code=429
|
status_code=429,
|
||||||
)
|
)
|
||||||
return await response(scope, receive, send)
|
return await response(scope, receive, send)
|
||||||
|
|
||||||
|
|||||||
@ -12,3 +12,6 @@ dependencies = [
|
|||||||
"starlette>=1.0.0",
|
"starlette>=1.0.0",
|
||||||
"uvicorn>=0.46.0",
|
"uvicorn>=0.46.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
profile = "ruff"
|
||||||
@ -2,6 +2,7 @@ import pytest
|
|||||||
|
|
||||||
from fixed_window import FixedWindowLimiter
|
from fixed_window import FixedWindowLimiter
|
||||||
|
|
||||||
|
|
||||||
def test_limit_respected():
|
def test_limit_respected():
|
||||||
limiter = FixedWindowLimiter(10, 10)
|
limiter = FixedWindowLimiter(10, 10)
|
||||||
try:
|
try:
|
||||||
@ -11,6 +12,7 @@ def test_limit_respected():
|
|||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
assert False, f"Limiter raised an exception -> {e}"
|
assert False, f"Limiter raised an exception -> {e}"
|
||||||
|
|
||||||
|
|
||||||
def test_limit_enforced():
|
def test_limit_enforced():
|
||||||
limiter = FixedWindowLimiter(10, 10)
|
limiter = FixedWindowLimiter(10, 10)
|
||||||
with pytest.raises(RuntimeError) as e:
|
with pytest.raises(RuntimeError) as e:
|
||||||
|
|||||||
@ -5,12 +5,12 @@ import pytest
|
|||||||
|
|
||||||
from leaky_bucket import LeakyBucketLimiter
|
from leaky_bucket import LeakyBucketLimiter
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_duration() -> None:
|
async def test_duration() -> None:
|
||||||
async def test_function():
|
async def test_function():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
limiter = LeakyBucketLimiter(10, 100)
|
limiter = LeakyBucketLimiter(10, 100)
|
||||||
await limiter.start()
|
await limiter.start()
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,15 @@
|
|||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
|
from cachetools import TTLCache
|
||||||
|
from starlette.applications import Starlette
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from middleware_starlette import app
|
from middleware_starlette import RateLimiterMiddleware, create_app, routes
|
||||||
|
|
||||||
|
|
||||||
def test_app_runs():
|
def test_app_runs():
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
response = client.get("/")
|
response = client.get("/")
|
||||||
@ -14,6 +19,8 @@ def test_app_runs():
|
|||||||
|
|
||||||
|
|
||||||
def test_below_rate_limit():
|
def test_below_rate_limit():
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
# Limit is 10/s
|
# Limit is 10/s
|
||||||
@ -26,6 +33,8 @@ def test_below_rate_limit():
|
|||||||
|
|
||||||
|
|
||||||
def test_above_rate_limit():
|
def test_above_rate_limit():
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
|
|
||||||
successful = 0
|
successful = 0
|
||||||
@ -46,3 +55,96 @@ def test_above_rate_limit():
|
|||||||
assert successful != 0
|
assert successful != 0
|
||||||
assert rate_limited != 0
|
assert rate_limited != 0
|
||||||
|
|
||||||
|
|
||||||
|
def find_middleware(stack, middleware_cls):
|
||||||
|
current = stack
|
||||||
|
|
||||||
|
while hasattr(current, "app"):
|
||||||
|
if isinstance(current, middleware_cls):
|
||||||
|
return current
|
||||||
|
current = current.app
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limiter_cache():
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
client.get("/")
|
||||||
|
|
||||||
|
middleware = find_middleware(
|
||||||
|
app.middleware_stack,
|
||||||
|
RateLimiterMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert middleware is not None
|
||||||
|
|
||||||
|
assert "testclient" in middleware.clients
|
||||||
|
assert len(middleware.clients) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_ttl():
|
||||||
|
# Setup app with a short TTL
|
||||||
|
class ShortTTLLimiterMiddleware(RateLimiterMiddleware):
|
||||||
|
def __init__(self, app):
|
||||||
|
super().__init__(app)
|
||||||
|
self.clients = TTLCache(maxsize=10, ttl=2) # Clear after 2 seconds
|
||||||
|
|
||||||
|
app = Starlette(routes=routes)
|
||||||
|
app.add_middleware(ShortTTLLimiterMiddleware)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
client.get("/")
|
||||||
|
|
||||||
|
middleware = find_middleware(
|
||||||
|
app.middleware_stack,
|
||||||
|
RateLimiterMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert middleware is not None
|
||||||
|
|
||||||
|
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
|
||||||
|
|
||||||
|
assert middleware.clients.get("1.1.1.1") is not None
|
||||||
|
|
||||||
|
sleep(2)
|
||||||
|
|
||||||
|
assert middleware.clients.get("1.1.1.1") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_eviction():
|
||||||
|
# Setup app with a tiny cache for testing
|
||||||
|
class TinyLimiterMiddleware(RateLimiterMiddleware):
|
||||||
|
def __init__(self, app):
|
||||||
|
super().__init__(app)
|
||||||
|
self.clients = TTLCache(maxsize=2, ttl=10) # Only space for 2 IPs
|
||||||
|
|
||||||
|
app = Starlette(routes=routes)
|
||||||
|
app.add_middleware(TinyLimiterMiddleware)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
client.get("/")
|
||||||
|
|
||||||
|
middleware = find_middleware(
|
||||||
|
app.middleware_stack,
|
||||||
|
RateLimiterMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert middleware is not None
|
||||||
|
|
||||||
|
client.get("/", headers={"X-Forwarded-For": "1.1.1.1"})
|
||||||
|
client.get("/", headers={"X-Forwarded-For": "2.2.2.2"})
|
||||||
|
|
||||||
|
assert middleware.clients.get("1.1.1.1") is not None
|
||||||
|
assert middleware.clients.get("2.2.2.2") is not None
|
||||||
|
assert middleware.clients.get("3.3.3.3") is None
|
||||||
|
|
||||||
|
client.get("/", headers={"X-Forwarded-For": "3.3.3.3"})
|
||||||
|
|
||||||
|
assert middleware.clients.get("1.1.1.1") is None
|
||||||
|
assert middleware.clients.get("2.2.2.2") is not None
|
||||||
|
assert middleware.clients.get("3.3.3.3") is not None
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
import pytest
|
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from token_bucket import TokenBucketLimiter, RateLimitExceeded
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from token_bucket import RateLimitExceeded, TokenBucketLimiter
|
||||||
|
|
||||||
|
|
||||||
def test_limit_respected() -> None:
|
def test_limit_respected() -> None:
|
||||||
limiter = TokenBucketLimiter(10, 10)
|
limiter = TokenBucketLimiter(10, 10)
|
||||||
@ -22,4 +25,3 @@ def test_limit_exceeded() -> None:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
assert "Rate exceeded" in str(e.value)
|
assert "Rate exceeded" in str(e.value)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
from time import sleep, time
|
from time import sleep, time
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiterException(Exception):
|
class RateLimiterException(Exception):
|
||||||
@ -42,6 +42,7 @@ class TokenBucketLimiter:
|
|||||||
def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
|
def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class AsyncTokenBucketLimiter:
|
class AsyncTokenBucketLimiter:
|
||||||
def __init__(self, capacity: int, refill_rate: int):
|
def __init__(self, capacity: int, refill_rate: int):
|
||||||
self.capacity = float(capacity)
|
self.capacity = float(capacity)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user