From be71e97956de20ba99be361b59000e966fedaf17 Mon Sep 17 00:00:00 2001 From: V Date: Sat, 9 May 2026 21:17:23 +0100 Subject: [PATCH] Added some more tests + formatted code --- rate_limiters/fixed_window.py | 1 + rate_limiters/leaky_bucket.py | 41 ++++--- rate_limiters/lint_and_format.sh | 4 + rate_limiters/middleware_starlette.py | 20 +++- rate_limiters/pyproject.toml | 3 + rate_limiters/tests/test_fixed_window.py | 4 +- rate_limiters/tests/test_leaky_bucket.py | 4 +- .../tests/test_middleware_starlette.py | 104 +++++++++++++++++- rate_limiters/tests/test_token_bucket.py | 10 +- rate_limiters/token_bucket.py | 5 +- 10 files changed, 159 insertions(+), 37 deletions(-) create mode 100755 rate_limiters/lint_and_format.sh diff --git a/rate_limiters/fixed_window.py b/rate_limiters/fixed_window.py index 6b363e0..d73d221 100644 --- a/rate_limiters/fixed_window.py +++ b/rate_limiters/fixed_window.py @@ -39,6 +39,7 @@ def make_requests(number: int, limiter: FixedWindowLimiter) -> None: print("Request limit reached!") sleep(0.5) + # lim_1 = FixedWindowLimiter(window_duration=10, window_size=10) # make_requests(100, lim_1) diff --git a/rate_limiters/leaky_bucket.py b/rate_limiters/leaky_bucket.py index ec0dbb8..9813e03 100644 --- a/rate_limiters/leaky_bucket.py +++ b/rate_limiters/leaky_bucket.py @@ -14,27 +14,23 @@ class LeakyBucketLimiter: self.metrics = { "processed_count": 0, - "total_dwell_time": 0.0, # Total time tasks waited in queue - "dropped_tasks": 0 + "total_dwell_time": 0.0, # Total time tasks waited in queue + "dropped_tasks": 0, } - def _calculate_delay(self): self.delay = 1.0 / self.rate - @property def queue_depth(self) -> int: return self._queue.qsize() - - + async def start(self): if not self._running: self._running = True self._worker_task = asyncio.create_task(self._worker()) self._monitor_task = asyncio.create_task(self._monitor()) print("Limiter started!") - async def stop(self): self._running = False @@ -42,22 +38,23 @@ class LeakyBucketLimiter: await self._worker_task print("Limiter stopped!") - async def set_rate(self, new_rate: float): if new_rate <= 0: raise ValueError("Rate must be a positive number!") - + print(f"Changing rate from {self.rate} to {new_rate}...") self.rate = new_rate self._calculate_delay() - async def _monitor(self): while self._running: await asyncio.sleep(2) - avg_dwell = (self.metrics["total_dwell_time"] / self.metrics["processed_count"] - if self.metrics["processed_count"] > 0 else 0) - + avg_dwell = ( + self.metrics["total_dwell_time"] / self.metrics["processed_count"] + if self.metrics["processed_count"] > 0 + else 0 + ) + print("\n--- [METRICS REPORT] ---") print(f" Queue Depth: {self.queue_depth} tasks") print(f" Throughput: {self.metrics['processed_count']} tasks total") @@ -65,7 +62,6 @@ class LeakyBucketLimiter: print(f" Drops: {self.metrics['dropped_tasks']}") print("-------------------------\n") - async def _worker(self): while self._running or not self._queue.empty(): try: @@ -86,7 +82,7 @@ class LeakyBucketLimiter: finally: self.metrics["processed_count"] += 1 self._queue.task_done() - + except asyncio.TimeoutError: continue except Exception as e: @@ -95,10 +91,10 @@ class LeakyBucketLimiter: async def execute(self, func, *args, **kwargs): if not self._running: raise RuntimeError("Limiter is not running! Call .start() first.") - + # if self._queue.full(): # raise RuntimeError(f"Queue is full! Retry in {self.delay:.2f} seconds.") - + loop = asyncio.get_running_loop() future = loop.create_future() @@ -132,7 +128,6 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size end_time = time() - success = 0 failures = 0 @@ -148,11 +143,13 @@ async def make_requests(number: int, limiter_rate: float, limiter_max_queue_size print( f"Total execution time: {execution_time:.2f}\nEffective rate: {effective_rate:.2f}" ) - - print(f"Operation complete with {success} successful requests and {failures} failures.") - + + print( + f"Operation complete with {success} successful requests and {failures} failures." + ) + await limiter.stop() if __name__ == "__main__": - asyncio.run(make_requests(1000, 20, 60)) \ No newline at end of file + asyncio.run(make_requests(1000, 20, 60)) diff --git a/rate_limiters/lint_and_format.sh b/rate_limiters/lint_and_format.sh new file mode 100755 index 0000000..6a4d582 --- /dev/null +++ b/rate_limiters/lint_and_format.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +ruff format . +ruff check --select I --fix . \ No newline at end of file diff --git a/rate_limiters/middleware_starlette.py b/rate_limiters/middleware_starlette.py index e14ea5e..43ed05c 100644 --- a/rate_limiters/middleware_starlette.py +++ b/rate_limiters/middleware_starlette.py @@ -2,13 +2,15 @@ from typing import Optional from cachetools import TTLCache from starlette.applications import Starlette +from starlette.datastructures import Headers +from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route -from starlette.middleware import Middleware from token_bucket import AsyncTokenBucketLimiter, RateLimitExceeded + class RateLimiterMiddleware: def __init__(self, app): self.app = app @@ -18,13 +20,21 @@ class RateLimiterMiddleware: if scope["type"] != "http": await self.app(scope, receive, send) - client_data = scope["client"] - client_ip = client_data[0] if client_data else "UNKNOWN" + # Check for "X-Forwarded-For" header - as passed by reverse proxies. + # Need it for testing + + headers = Headers(raw=scope["headers"]) + + if (client_ip := headers.get("X-Forwarded-For")) is None: + client_data = scope["client"] + client_ip = client_data[0] if client_data else "UNKNOWN" limiter: Optional[AsyncTokenBucketLimiter] = self.clients.get(client_ip) if not limiter: - limiter = self.clients.setdefault(client_ip, AsyncTokenBucketLimiter(10, 10)) + limiter = self.clients.setdefault( + client_ip, AsyncTokenBucketLimiter(10, 10) + ) try: async with limiter: @@ -34,7 +44,7 @@ class RateLimiterMiddleware: response = Response( content=f"Rate limit exceeded! Retry after: {retry_after}", headers={"Retry-After": retry_after}, - status_code=429 + status_code=429, ) return await response(scope, receive, send) diff --git a/rate_limiters/pyproject.toml b/rate_limiters/pyproject.toml index 35397be..2a969ec 100644 --- a/rate_limiters/pyproject.toml +++ b/rate_limiters/pyproject.toml @@ -12,3 +12,6 @@ dependencies = [ "starlette>=1.0.0", "uvicorn>=0.46.0", ] + +[tool.isort] +profile = "ruff" \ No newline at end of file diff --git a/rate_limiters/tests/test_fixed_window.py b/rate_limiters/tests/test_fixed_window.py index 7f60ace..183cb30 100644 --- a/rate_limiters/tests/test_fixed_window.py +++ b/rate_limiters/tests/test_fixed_window.py @@ -2,6 +2,7 @@ import pytest from fixed_window import FixedWindowLimiter + def test_limit_respected(): limiter = FixedWindowLimiter(10, 10) try: @@ -10,7 +11,8 @@ def test_limit_respected(): pass except RuntimeError as e: assert False, f"Limiter raised an exception -> {e}" - + + def test_limit_enforced(): limiter = FixedWindowLimiter(10, 10) with pytest.raises(RuntimeError) as e: diff --git a/rate_limiters/tests/test_leaky_bucket.py b/rate_limiters/tests/test_leaky_bucket.py index 6912f15..4d91fae 100644 --- a/rate_limiters/tests/test_leaky_bucket.py +++ b/rate_limiters/tests/test_leaky_bucket.py @@ -5,11 +5,11 @@ import pytest from leaky_bucket import LeakyBucketLimiter + @pytest.mark.asyncio async def test_duration() -> None: async def test_function(): pass - limiter = LeakyBucketLimiter(10, 100) await limiter.start() @@ -26,4 +26,4 @@ async def test_duration() -> None: assert duration < 2.1 for item in results: - assert isinstance(item, Exception) is False \ No newline at end of file + assert isinstance(item, Exception) is False diff --git a/rate_limiters/tests/test_middleware_starlette.py b/rate_limiters/tests/test_middleware_starlette.py index 04818de..e65beab 100644 --- a/rate_limiters/tests/test_middleware_starlette.py +++ b/rate_limiters/tests/test_middleware_starlette.py @@ -1,10 +1,15 @@ from time import sleep +from cachetools import TTLCache +from starlette.applications import Starlette from starlette.testclient import TestClient -from middleware_starlette import app +from middleware_starlette import RateLimiterMiddleware, create_app, routes + def test_app_runs(): + app = create_app() + client = TestClient(app) response = client.get("/") @@ -14,6 +19,8 @@ def test_app_runs(): def test_below_rate_limit(): + app = create_app() + client = TestClient(app) # Limit is 10/s @@ -26,6 +33,8 @@ def test_below_rate_limit(): def test_above_rate_limit(): + app = create_app() + client = TestClient(app) successful = 0 @@ -46,3 +55,96 @@ def test_above_rate_limit(): assert successful != 0 assert rate_limited != 0 + +def find_middleware(stack, middleware_cls): + current = stack + + while hasattr(current, "app"): + if isinstance(current, middleware_cls): + return current + current = current.app + + return None + + +def test_rate_limiter_cache(): + app = create_app() + + client = TestClient(app) + + client.get("/") + + middleware = find_middleware( + app.middleware_stack, + RateLimiterMiddleware, + ) + + assert middleware is not None + + assert "testclient" in middleware.clients + assert len(middleware.clients) == 1 + + +def test_cache_ttl(): + # Setup app with a short TTL + class ShortTTLLimiterMiddleware(RateLimiterMiddleware): + def __init__(self, app): + super().__init__(app) + self.clients = TTLCache(maxsize=10, ttl=2) # Clear after 2 seconds + + app = Starlette(routes=routes) + app.add_middleware(ShortTTLLimiterMiddleware) + + client = TestClient(app) + + client.get("/") + + middleware = find_middleware( + app.middleware_stack, + RateLimiterMiddleware, + ) + + assert middleware is not None + + client.get("/", headers={"X-Forwarded-For": "1.1.1.1"}) + + assert middleware.clients.get("1.1.1.1") is not None + + sleep(2) + + assert middleware.clients.get("1.1.1.1") is None + + +def test_cache_eviction(): + # Setup app with a tiny cache for testing + class TinyLimiterMiddleware(RateLimiterMiddleware): + def __init__(self, app): + super().__init__(app) + self.clients = TTLCache(maxsize=2, ttl=10) # Only space for 2 IPs + + app = Starlette(routes=routes) + app.add_middleware(TinyLimiterMiddleware) + + client = TestClient(app) + + client.get("/") + + middleware = find_middleware( + app.middleware_stack, + RateLimiterMiddleware, + ) + + assert middleware is not None + + client.get("/", headers={"X-Forwarded-For": "1.1.1.1"}) + client.get("/", headers={"X-Forwarded-For": "2.2.2.2"}) + + assert middleware.clients.get("1.1.1.1") is not None + assert middleware.clients.get("2.2.2.2") is not None + assert middleware.clients.get("3.3.3.3") is None + + client.get("/", headers={"X-Forwarded-For": "3.3.3.3"}) + + assert middleware.clients.get("1.1.1.1") is None + assert middleware.clients.get("2.2.2.2") is not None + assert middleware.clients.get("3.3.3.3") is not None diff --git a/rate_limiters/tests/test_token_bucket.py b/rate_limiters/tests/test_token_bucket.py index ba2a68f..bf19dc5 100644 --- a/rate_limiters/tests/test_token_bucket.py +++ b/rate_limiters/tests/test_token_bucket.py @@ -1,6 +1,9 @@ -import pytest from time import sleep -from token_bucket import TokenBucketLimiter, RateLimitExceeded + +import pytest + +from token_bucket import RateLimitExceeded, TokenBucketLimiter + def test_limit_respected() -> None: limiter = TokenBucketLimiter(10, 10) @@ -20,6 +23,5 @@ def test_limit_exceeded() -> None: for i in range(20): with limiter: continue - - assert "Rate exceeded" in str(e.value) + assert "Rate exceeded" in str(e.value) diff --git a/rate_limiters/token_bucket.py b/rate_limiters/token_bucket.py index 5ebc8e7..f107995 100644 --- a/rate_limiters/token_bucket.py +++ b/rate_limiters/token_bucket.py @@ -1,6 +1,6 @@ +import asyncio import threading from time import sleep, time -import asyncio class RateLimiterException(Exception): @@ -41,7 +41,8 @@ class TokenBucketLimiter: def __exit__(self, exc_type, exc_value, exc_traceback) -> bool: return False - + + class AsyncTokenBucketLimiter: def __init__(self, capacity: int, refill_rate: int): self.capacity = float(capacity)