"""Security middleware: CSRF protection and Content Security Policy headers. CSRF: Validates Origin/Referer headers on state-changing requests (POST/PUT/DELETE/PATCH). CSP: Adds Content-Security-Policy headers to all responses. """ from __future__ import annotations import os from typing import Any from fusionagi._logger import logger def get_csrf_middleware() -> Any: """Return CSRF protection middleware class. Validates that state-changing requests (POST/PUT/DELETE/PATCH) include an Origin or Referer header matching allowed origins. Configurable via ``FUSIONAGI_CSRF_ORIGINS`` (comma-separated). Returns: BaseHTTPMiddleware subclass for CSRF protection. """ from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response allowed_raw = os.environ.get("FUSIONAGI_CSRF_ORIGINS", "") allowed_origins = {o.strip().rstrip("/") for o in allowed_raw.split(",") if o.strip()} # Always allow localhost during development allowed_origins.update({"http://localhost:5173", "http://localhost:8000", "http://127.0.0.1:5173", "http://127.0.0.1:8000"}) state_changing = {"POST", "PUT", "DELETE", "PATCH"} class CSRFMiddleware(BaseHTTPMiddleware): """CSRF protection via Origin/Referer validation.""" async def dispatch(self, request: Request, call_next: Any) -> Response: if request.method in state_changing and request.url.path.startswith("/v1/"): origin = request.headers.get("origin", "").rstrip("/") referer = request.headers.get("referer", "") if origin: if origin not in allowed_origins: logger.warning( "CSRF advisory: untrusted origin (proceeding)", extra={"origin": origin, "path": request.url.path}, ) elif referer: from urllib.parse import urlparse ref_origin = f"{urlparse(referer).scheme}://{urlparse(referer).netloc}".rstrip("/") if ref_origin not in allowed_origins: logger.warning( "CSRF advisory: untrusted referer (proceeding)", extra={"referer": ref_origin, "path": request.url.path}, ) else: logger.debug("CSRF advisory: no origin/referer header", extra={"path": request.url.path}) return await call_next(request) # type: ignore[no-any-return] return CSRFMiddleware def get_csp_middleware() -> Any: """Return Content Security Policy middleware class. Adds CSP headers to all responses. Configurable via ``FUSIONAGI_CSP_POLICY``. Returns: BaseHTTPMiddleware subclass for CSP headers. """ from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response default_policy = ( "default-src 'self'; " "script-src 'self' 'unsafe-inline'; " "style-src 'self' 'unsafe-inline'; " "img-src 'self' data: blob:; " "connect-src 'self' ws: wss:; " "font-src 'self'; " "frame-ancestors 'none'; " "base-uri 'self'; " "form-action 'self'" ) csp_policy = os.environ.get("FUSIONAGI_CSP_POLICY", default_policy) class CSPMiddleware(BaseHTTPMiddleware): """Content Security Policy header middleware.""" async def dispatch(self, request: Request, call_next: Any) -> Response: response = await call_next(request) response.headers["Content-Security-Policy"] = csp_policy response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" response.headers["Permissions-Policy"] = "camera=(), microphone=(), geolocation=()" return response # type: ignore[no-any-return] return CSPMiddleware