from __future__ import annotations from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import JSONResponse, Response from app.db.redis import get_redis class RateLimitMiddleware(BaseHTTPMiddleware): def __init__(self, app, max_requests: int = 100, window_seconds: int = 60): # type: ignore[no-untyped-def] super().__init__(app) self.max_requests = max_requests self.window_seconds = window_seconds async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: if request.url.path.startswith("/docs") or request.url.path.startswith("/redoc"): return await call_next(request) client_ip = request.client.host if request.client else "unknown" key = f"rate_limit:{client_ip}" try: redis = get_redis() current = await redis.incr(key) if current == 1: await redis.expire(key, self.window_seconds) if current > self.max_requests: return JSONResponse( status_code=429, content={"detail": "Too many requests"}, ) except Exception: pass return await call_next(request)