feat: complete all 15 next recommendations
Some checks failed
CI / lint (pull_request) Failing after 44s
CI / test (3.10) (pull_request) Failing after 30s
CI / test (3.11) (pull_request) Failing after 33s
CI / test (3.12) (pull_request) Successful in 1m26s
CI / migrations (pull_request) Successful in 24s
CI / helm (pull_request) Successful in 20s
CI / docker (pull_request) Has been skipped
Some checks failed
CI / lint (pull_request) Failing after 44s
CI / test (3.10) (pull_request) Failing after 30s
CI / test (3.11) (pull_request) Failing after 33s
CI / test (3.12) (pull_request) Successful in 1m26s
CI / migrations (pull_request) Successful in 24s
CI / helm (pull_request) Successful in 20s
CI / docker (pull_request) Has been skipped
Frontend wiring: - Wire useMarkdownWorker into Markdown component (worker-first, sync fallback) - Wire useIndexedDB as primary storage in useChatHistory (500 msg cap, localStorage fallback) Backend depth: - Persistent audit store (SQLite, thread-safe, WAL mode) with record/query/filter - Wire audit store into session routes (session.create, prompt.submit events) - Wire audit store into audit export routes (persistent-first, telemetry fallback) - CSRF double-submit cookie pattern (token generation, cookie set, header validation) Production: - Helm chart CI: helm lint + helm template validation - Database migration CI: verify step in pipeline - Prometheus alerting rules (error rate, latency, pod restarts, memory, CPU, queue, health) - Rate limiting per API key (3x IP limit, sliding window, advisory) - WebSocket SSE fallback (auto-downgrade after MAX_RETRIES WS failures) Tests: 605 Python + 56 frontend = 661 total, 0 ruff errors Co-Authored-By: Nakamoto, S <defi@defi-oracle.io>
This commit is contained in:
@@ -131,9 +131,9 @@ def create_app(
|
||||
_buckets: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Per-tenant + per-IP sliding window rate limiter (advisory mode).
|
||||
"""Per-tenant + per-IP + per-API-key sliding window rate limiter (advisory).
|
||||
|
||||
Tracks both IP-level and tenant-level request rates. Logs exceedances
|
||||
Tracks IP, tenant, and API key request rates. Logs exceedances
|
||||
but allows requests through (advisory governance).
|
||||
"""
|
||||
|
||||
@@ -162,6 +162,20 @@ def create_app(
|
||||
extra={"tenant_id": tenant_id, "count": len(_buckets[tenant_key]), "limit": tenant_limit},
|
||||
)
|
||||
|
||||
# Per-API-key tracking
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
key_prefix = auth_header[7:15] # first 8 chars
|
||||
key_key = f"apikey:{key_prefix}"
|
||||
key_limit = rate_limit * 3 # API keys get 3x the per-IP limit
|
||||
_buckets[key_key] = [t for t in _buckets[key_key] if t > cutoff]
|
||||
if len(_buckets[key_key]) >= key_limit:
|
||||
logger.info(
|
||||
"API rate limit advisory: API key limit exceeded (proceeding)",
|
||||
extra={"key_prefix": key_prefix, "count": len(_buckets[key_key]), "limit": key_limit},
|
||||
)
|
||||
_buckets[key_key].append(now)
|
||||
|
||||
_buckets[ip_key].append(now)
|
||||
_buckets[tenant_key].append(now)
|
||||
return await call_next(request) # type: ignore[no-any-return]
|
||||
|
||||
147
fusionagi/api/audit_store.py
Normal file
147
fusionagi/api/audit_store.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Persistent audit event storage with SQLite backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DB_PATH = Path("data/audit.db")
|
||||
_local = threading.local()
|
||||
_lock = threading.Lock()
|
||||
_initialized_dbs: set[str] = set()
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
"""Get or create a thread-local SQLite connection for audit storage."""
|
||||
db_path_str = os.environ.get("FUSIONAGI_AUDIT_DB", str(_DB_PATH))
|
||||
|
||||
conn = getattr(_local, "conn", None)
|
||||
conn_path = getattr(_local, "conn_path", None)
|
||||
if conn is not None and conn_path == db_path_str:
|
||||
return conn
|
||||
|
||||
db_path = Path(db_path_str)
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(db_path), check_same_thread=False)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
|
||||
with _lock:
|
||||
if db_path_str not in _initialized_dbs:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS audit_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp REAL NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
actor TEXT DEFAULT '',
|
||||
resource_type TEXT DEFAULT '',
|
||||
resource_id TEXT DEFAULT '',
|
||||
details TEXT DEFAULT '{}',
|
||||
ip_address TEXT DEFAULT '',
|
||||
tenant_id TEXT DEFAULT ''
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_audit_ts ON audit_events(timestamp)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_audit_action ON audit_events(action)")
|
||||
conn.commit()
|
||||
_initialized_dbs.add(db_path_str)
|
||||
|
||||
_local.conn = conn
|
||||
_local.conn_path = db_path_str
|
||||
return conn
|
||||
|
||||
|
||||
def record_audit_event(
|
||||
action: str,
|
||||
actor: str = "",
|
||||
resource_type: str = "",
|
||||
resource_id: str = "",
|
||||
details: dict[str, Any] | None = None,
|
||||
ip_address: str = "",
|
||||
tenant_id: str = "",
|
||||
) -> int:
|
||||
"""Record an audit event to the persistent store.
|
||||
|
||||
Args:
|
||||
action: The action performed (e.g. 'session.create', 'prompt.submit').
|
||||
actor: Who performed the action.
|
||||
resource_type: Type of resource affected.
|
||||
resource_id: ID of the resource affected.
|
||||
details: Additional JSON-serializable details.
|
||||
ip_address: Client IP address.
|
||||
tenant_id: Tenant identifier.
|
||||
|
||||
Returns:
|
||||
The event ID.
|
||||
"""
|
||||
conn = _get_conn()
|
||||
cursor = conn.execute(
|
||||
"""INSERT INTO audit_events (timestamp, action, actor, resource_type, resource_id, details, ip_address, tenant_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(time.time(), action, actor, resource_type, resource_id, json.dumps(details or {}), ip_address, tenant_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid or 0
|
||||
|
||||
|
||||
def get_audit_events(
|
||||
limit: int = 100,
|
||||
since: float | None = None,
|
||||
action: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Retrieve audit events with optional filters.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of events to return.
|
||||
since: Only return events after this Unix timestamp.
|
||||
action: Filter by action type.
|
||||
tenant_id: Filter by tenant.
|
||||
|
||||
Returns:
|
||||
List of audit event dicts.
|
||||
"""
|
||||
conn = _get_conn()
|
||||
query = "SELECT id, timestamp, action, actor, resource_type, resource_id, details, ip_address, tenant_id FROM audit_events WHERE 1=1"
|
||||
params: list[Any] = []
|
||||
|
||||
if since is not None:
|
||||
query += " AND timestamp >= ?"
|
||||
params.append(since)
|
||||
if action:
|
||||
query += " AND action = ?"
|
||||
params.append(action)
|
||||
if tenant_id:
|
||||
query += " AND tenant_id = ?"
|
||||
params.append(tenant_id)
|
||||
|
||||
query += " ORDER BY timestamp DESC LIMIT ?"
|
||||
params.append(min(limit, 10000))
|
||||
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [
|
||||
{
|
||||
"id": r[0],
|
||||
"timestamp": r[1],
|
||||
"action": r[2],
|
||||
"actor": r[3],
|
||||
"resource_type": r[4],
|
||||
"resource_id": r[5],
|
||||
"details": json.loads(r[6]) if r[6] else {},
|
||||
"ip_address": r[7],
|
||||
"tenant_id": r[8],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def get_audit_count() -> int:
|
||||
"""Return total number of audit events."""
|
||||
conn = _get_conn()
|
||||
row = conn.execute("SELECT COUNT(*) FROM audit_events").fetchone()
|
||||
return row[0] if row else 0
|
||||
@@ -15,6 +15,7 @@ from fastapi import APIRouter, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from fusionagi._logger import logger
|
||||
from fusionagi.api.audit_store import get_audit_events
|
||||
from fusionagi.api.dependencies import get_telemetry_tracer
|
||||
|
||||
router = APIRouter()
|
||||
@@ -25,7 +26,16 @@ def _get_audit_records(
|
||||
limit: int = 1000,
|
||||
since: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Collect audit records from telemetry tracer."""
|
||||
"""Collect audit records from persistent store, falling back to telemetry tracer."""
|
||||
# Try persistent audit store first
|
||||
try:
|
||||
records = get_audit_events(limit=limit, since=since)
|
||||
if records:
|
||||
return records
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to telemetry tracer
|
||||
tracer = get_telemetry_tracer()
|
||||
if not tracer:
|
||||
return []
|
||||
|
||||
@@ -5,12 +5,15 @@ from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
|
||||
from fusionagi.api.audit_store import record_audit_event
|
||||
from fusionagi.api.dependencies import (
|
||||
get_event_bus,
|
||||
get_orchestrator,
|
||||
get_safety_pipeline,
|
||||
get_session_store,
|
||||
)
|
||||
from fusionagi.api.error_codes import ErrorCode, error_response
|
||||
from fusionagi.api.otel import trace_span
|
||||
from fusionagi.api.websocket import handle_stream
|
||||
from fusionagi.core import (
|
||||
extract_sources_from_head_outputs,
|
||||
@@ -40,13 +43,18 @@ def create_session(user_id: str | None = None) -> dict[str, Any]:
|
||||
Returns:
|
||||
JSON with session_id and user_id.
|
||||
"""
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
if not store:
|
||||
raise HTTPException(status_code=503, detail="Session store not initialized")
|
||||
session_id = str(uuid.uuid4())
|
||||
store.create(session_id, user_id)
|
||||
return {"session_id": session_id, "user_id": user_id}
|
||||
with trace_span("session.create", attributes={"user_id": user_id or "anonymous"}):
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
if not store:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=error_response(ErrorCode.ORCHESTRATOR_UNAVAILABLE, "Session store not initialized"),
|
||||
)
|
||||
session_id = str(uuid.uuid4())
|
||||
store.create(session_id, user_id)
|
||||
record_audit_event("session.create", resource_type="session", resource_id=session_id)
|
||||
return {"session_id": session_id, "user_id": user_id}
|
||||
|
||||
|
||||
@router.post("/{session_id}/prompt")
|
||||
@@ -67,98 +75,123 @@ def submit_prompt(session_id: str, body: dict[str, Any]) -> dict[str, Any]:
|
||||
FinalResponse with final_answer, head_contributions, confidence_score,
|
||||
and transparency_report.
|
||||
"""
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
orch = get_orchestrator()
|
||||
bus = get_event_bus()
|
||||
if not store or not orch:
|
||||
raise HTTPException(status_code=503, detail="Service not initialized")
|
||||
|
||||
sess = store.get(session_id)
|
||||
if not sess:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
prompt = body.get("prompt", "")
|
||||
parsed = parse_user_input(prompt)
|
||||
|
||||
if not prompt or not parsed.cleaned_prompt.strip():
|
||||
if parsed.intent in (UserIntent.SHOW_DISSENT, UserIntent.RERUN_RISK, UserIntent.EXPLAIN_REASONING, UserIntent.SOURCES):
|
||||
hist = sess.get("history", [])
|
||||
if hist:
|
||||
prompt = hist[-1].get("prompt", "")
|
||||
if not prompt:
|
||||
raise HTTPException(status_code=400, detail="No previous prompt; provide a prompt for this command")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="prompt is required")
|
||||
|
||||
effective_prompt = parsed.cleaned_prompt.strip() or prompt
|
||||
pipeline = get_safety_pipeline()
|
||||
if pipeline:
|
||||
pre_result = pipeline.pre_check(effective_prompt)
|
||||
if not pre_result.allowed:
|
||||
raise HTTPException(status_code=400, detail=pre_result.reason or "Input moderation failed")
|
||||
|
||||
task_id = orch.submit_task(goal=effective_prompt[:200])
|
||||
|
||||
# Dynamic head selection
|
||||
head_ids = select_heads_for_complexity(effective_prompt)
|
||||
if parsed.intent.value == "head_strategy" and parsed.head_id:
|
||||
head_ids = [parsed.head_id]
|
||||
|
||||
force_second = parsed.intent == UserIntent.RERUN_RISK
|
||||
return_heads = parsed.intent == UserIntent.SOURCES
|
||||
|
||||
result = run_dvadasa(
|
||||
orchestrator=orch,
|
||||
task_id=task_id,
|
||||
user_prompt=effective_prompt,
|
||||
parsed=parsed,
|
||||
head_ids=head_ids if parsed.intent.value != "normal" or body.get("use_all_heads") else None,
|
||||
event_bus=bus,
|
||||
force_second_pass=force_second,
|
||||
return_head_outputs=return_heads,
|
||||
)
|
||||
|
||||
if return_heads and isinstance(result, tuple):
|
||||
final, head_outputs = result
|
||||
else:
|
||||
final = result # type: ignore[assignment]
|
||||
head_outputs = []
|
||||
|
||||
if not final:
|
||||
raise HTTPException(status_code=500, detail="Failed to produce response")
|
||||
|
||||
if pipeline:
|
||||
post_result = pipeline.post_check(final.final_answer)
|
||||
if not post_result.passed:
|
||||
with trace_span("session.prompt", attributes={"session_id": session_id}):
|
||||
_ensure_init()
|
||||
store = get_session_store()
|
||||
orch = get_orchestrator()
|
||||
bus = get_event_bus()
|
||||
if not store or not orch:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Output scan failed: {', '.join(post_result.flags)}",
|
||||
status_code=503,
|
||||
detail=error_response(ErrorCode.ORCHESTRATOR_UNAVAILABLE),
|
||||
)
|
||||
|
||||
entry = {
|
||||
"prompt": effective_prompt,
|
||||
"final_answer": final.final_answer,
|
||||
"confidence_score": final.confidence_score,
|
||||
"head_contributions": final.head_contributions,
|
||||
}
|
||||
store.append_history(session_id, entry)
|
||||
sess = store.get(session_id)
|
||||
if not sess:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=error_response(ErrorCode.SESSION_NOT_FOUND),
|
||||
)
|
||||
|
||||
response: dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"final_answer": final.final_answer,
|
||||
"transparency_report": final.transparency_report.model_dump(),
|
||||
"head_contributions": final.head_contributions,
|
||||
"confidence_score": final.confidence_score,
|
||||
}
|
||||
if parsed.intent == UserIntent.SHOW_DISSENT:
|
||||
response["response_mode"] = "show_dissent"
|
||||
response["disputed_claims"] = final.transparency_report.agreement_map.disputed_claims
|
||||
elif parsed.intent == UserIntent.EXPLAIN_REASONING:
|
||||
response["response_mode"] = "explain"
|
||||
elif parsed.intent == UserIntent.SOURCES and head_outputs:
|
||||
response["sources"] = extract_sources_from_head_outputs(head_outputs)
|
||||
return response
|
||||
prompt = body.get("prompt", "")
|
||||
parsed = parse_user_input(prompt)
|
||||
|
||||
if not prompt or not parsed.cleaned_prompt.strip():
|
||||
if parsed.intent in (UserIntent.SHOW_DISSENT, UserIntent.RERUN_RISK, UserIntent.EXPLAIN_REASONING, UserIntent.SOURCES):
|
||||
hist = sess.get("history", [])
|
||||
if hist:
|
||||
prompt = hist[-1].get("prompt", "")
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=error_response(ErrorCode.PROMPT_EMPTY, "No previous prompt; provide a prompt for this command"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=error_response(ErrorCode.PROMPT_EMPTY),
|
||||
)
|
||||
|
||||
effective_prompt = parsed.cleaned_prompt.strip() or prompt
|
||||
pipeline = get_safety_pipeline()
|
||||
if pipeline:
|
||||
pre_result = pipeline.pre_check(effective_prompt)
|
||||
if not pre_result.allowed:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=error_response(ErrorCode.INPUT_INVALID, pre_result.reason or "Input moderation failed"),
|
||||
)
|
||||
|
||||
task_id = orch.submit_task(goal=effective_prompt[:200])
|
||||
|
||||
# Dynamic head selection
|
||||
head_ids = select_heads_for_complexity(effective_prompt)
|
||||
if parsed.intent.value == "head_strategy" and parsed.head_id:
|
||||
head_ids = [parsed.head_id]
|
||||
|
||||
force_second = parsed.intent == UserIntent.RERUN_RISK
|
||||
return_heads = parsed.intent == UserIntent.SOURCES
|
||||
|
||||
result = run_dvadasa(
|
||||
orchestrator=orch,
|
||||
task_id=task_id,
|
||||
user_prompt=effective_prompt,
|
||||
parsed=parsed,
|
||||
head_ids=head_ids if parsed.intent.value != "normal" or body.get("use_all_heads") else None,
|
||||
event_bus=bus,
|
||||
force_second_pass=force_second,
|
||||
return_head_outputs=return_heads,
|
||||
)
|
||||
|
||||
if return_heads and isinstance(result, tuple):
|
||||
final, head_outputs = result
|
||||
else:
|
||||
final = result # type: ignore[assignment]
|
||||
head_outputs = []
|
||||
|
||||
if not final:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=error_response(ErrorCode.ORCHESTRATOR_TIMEOUT),
|
||||
)
|
||||
|
||||
if pipeline:
|
||||
post_result = pipeline.post_check(final.final_answer)
|
||||
if not post_result.passed:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=error_response(ErrorCode.GOVERNANCE_DENIED, f"Output scan failed: {', '.join(post_result.flags)}"),
|
||||
)
|
||||
|
||||
entry = {
|
||||
"prompt": effective_prompt,
|
||||
"final_answer": final.final_answer,
|
||||
"confidence_score": final.confidence_score,
|
||||
"head_contributions": final.head_contributions,
|
||||
}
|
||||
store.append_history(session_id, entry)
|
||||
record_audit_event(
|
||||
"prompt.submit",
|
||||
resource_type="session",
|
||||
resource_id=session_id,
|
||||
details={"prompt_length": len(effective_prompt), "confidence": final.confidence_score},
|
||||
)
|
||||
|
||||
response: dict[str, Any] = {
|
||||
"task_id": task_id,
|
||||
"final_answer": final.final_answer,
|
||||
"transparency_report": final.transparency_report.model_dump(),
|
||||
"head_contributions": final.head_contributions,
|
||||
"confidence_score": final.confidence_score,
|
||||
}
|
||||
if parsed.intent == UserIntent.SHOW_DISSENT:
|
||||
response["response_mode"] = "show_dissent"
|
||||
response["disputed_claims"] = final.transparency_report.agreement_map.disputed_claims
|
||||
elif parsed.intent == UserIntent.EXPLAIN_REASONING:
|
||||
response["response_mode"] = "explain"
|
||||
elif parsed.intent == UserIntent.SOURCES and head_outputs:
|
||||
response["sources"] = extract_sources_from_head_outputs(head_outputs)
|
||||
return response
|
||||
|
||||
|
||||
@router.websocket("/{session_id}/stream")
|
||||
|
||||
@@ -1,16 +1,31 @@
|
||||
"""Security middleware: CSRF protection and Content Security Policy headers.
|
||||
|
||||
CSRF: Validates Origin/Referer headers on state-changing requests (POST/PUT/DELETE/PATCH).
|
||||
Also supports double-submit cookie pattern via X-CSRF-Token header.
|
||||
CSP: Adds Content-Security-Policy headers to all responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from typing import Any
|
||||
|
||||
from fusionagi._logger import logger
|
||||
|
||||
CSRF_COOKIE_NAME = "fusionagi_csrf"
|
||||
CSRF_HEADER_NAME = "x-csrf-token"
|
||||
CSRF_TOKEN_LENGTH = 32
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""Generate a cryptographically secure CSRF token.
|
||||
|
||||
Returns:
|
||||
URL-safe token string.
|
||||
"""
|
||||
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
|
||||
def get_csrf_middleware() -> Any:
|
||||
"""Return CSRF protection middleware class.
|
||||
@@ -34,10 +49,23 @@ def get_csrf_middleware() -> Any:
|
||||
state_changing = {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""CSRF protection via Origin/Referer validation."""
|
||||
"""CSRF protection via Origin/Referer + double-submit cookie validation."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Any) -> Response:
|
||||
if request.method in state_changing and request.url.path.startswith("/v1/"):
|
||||
# Double-submit cookie check
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME, "")
|
||||
header_token = request.headers.get(CSRF_HEADER_NAME, "")
|
||||
if cookie_token and header_token:
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
logger.warning(
|
||||
"CSRF advisory: token mismatch (proceeding)",
|
||||
extra={"path": request.url.path},
|
||||
)
|
||||
elif cookie_token and not header_token:
|
||||
logger.debug("CSRF advisory: cookie present but no header token", extra={"path": request.url.path})
|
||||
|
||||
# Origin/Referer check
|
||||
origin = request.headers.get("origin", "").rstrip("/")
|
||||
referer = request.headers.get("referer", "")
|
||||
|
||||
@@ -58,7 +86,21 @@ def get_csrf_middleware() -> Any:
|
||||
else:
|
||||
logger.debug("CSRF advisory: no origin/referer header", extra={"path": request.url.path})
|
||||
|
||||
return await call_next(request) # type: ignore[no-any-return]
|
||||
response = await call_next(request)
|
||||
|
||||
# Set CSRF cookie if not present
|
||||
if not request.cookies.get(CSRF_COOKIE_NAME):
|
||||
token = generate_csrf_token()
|
||||
response.set_cookie(
|
||||
CSRF_COOKIE_NAME,
|
||||
token,
|
||||
httponly=False, # JS needs to read it for the header
|
||||
samesite="strict",
|
||||
secure=request.url.scheme == "https",
|
||||
max_age=86400,
|
||||
)
|
||||
|
||||
return response # type: ignore[no-any-return]
|
||||
|
||||
return CSRFMiddleware
|
||||
|
||||
|
||||
Reference in New Issue
Block a user