38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
from __future__ import annotations

import logging

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from app.db.redis import get_redis
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Fixed-window, per-client-IP request limiter backed by Redis counters.

    Each client IP gets a Redis counter ``rate_limit:<ip>``. The counter is
    incremented on every request and expires ``window_seconds`` after the
    first request of the window. Once the counter exceeds ``max_requests``,
    the middleware responds with ``429`` instead of calling the downstream app.

    The limiter fails open: if Redis cannot be reached, the request is let
    through and a warning is logged, so a limiter outage does not turn into an
    API outage. Paths starting with ``/docs`` or ``/redoc`` are never limited.

    NOTE(review): the key is built from ``request.client.host``. Behind a
    reverse proxy this is presumably the proxy's address, so all clients would
    share one bucket. Confirm the deployment topology.
    """

    # Path prefixes that bypass limiting (interactive API documentation pages).
    _EXEMPT_PATH_PREFIXES = ("/docs", "/redoc")

    _logger = logging.getLogger(__name__)

    def __init__(self, app, max_requests: int = 100, window_seconds: int = 60):  # type: ignore[no-untyped-def]
        """Wrap ``app`` with the rate limiter.

        Args:
            app: The downstream application passed to ``BaseHTTPMiddleware``.
            max_requests: Requests allowed per client IP within one window.
            window_seconds: Lifetime of a client's counter, in seconds,
                measured from the first request of the window.
        """
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Count this request against its client IP and reject it once over the limit.

        Returns:
            A ``429`` JSON response when the client has exceeded
            ``max_requests`` in the current window. Otherwise, the downstream
            app's response.
        """
        if request.url.path.startswith(self._EXEMPT_PATH_PREFIXES):
            return await call_next(request)

        client_ip = request.client.host if request.client else "unknown"
        key = f"rate_limit:{client_ip}"

        try:
            redis = get_redis()
            current = await redis.incr(key)
        except Exception:
            # Fail open on purpose: an unavailable limiter must not block all
            # traffic. Log it so the outage is visible instead of silently
            # disabling rate limiting.
            self._logger.warning(
                "Rate limiting skipped for %s: counter update failed",
                client_ip,
                exc_info=True,
            )
            return await call_next(request)

        # INCR and EXPIRE are separate commands. If the EXPIRE after the first
        # hit is lost (connection error, worker crash), the counter never
        # resets and the client stays blocked forever. Start the window on the
        # first hit, and re-arm it if an over-limit counter turns out to have
        # no TTL (Redis TTL returns -1 for a key that exists without expiry).
        try:
            if current == 1:
                await redis.expire(key, self.window_seconds)
            elif current > self.max_requests and await redis.ttl(key) == -1:
                await redis.expire(key, self.window_seconds)
        except Exception:
            # The count is already known, so a failed TTL fix-up must not turn
            # an over-limit request into an allowed one; just report it.
            self._logger.warning(
                "Could not set expiry on rate-limit key %s",
                key,
                exc_info=True,
            )

        if current > self.max_requests:
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests"},
            )

        return await call_next(request)
|