"""Request tracing middleware for structured logging with correlation IDs.""" from __future__ import annotations import contextvars import uuid from typing import Any trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("trace_id", default="") def get_trace_id() -> str: """Get current trace ID from context.""" return trace_id_var.get() or "" def set_trace_id(trace_id: str) -> None: """Set trace ID in current context.""" trace_id_var.set(trace_id) def generate_trace_id() -> str: """Generate a new trace ID.""" return str(uuid.uuid4())[:8] class TracingMiddleware: """ASGI middleware that sets/propagates request trace IDs. Extracts trace ID from X-Request-ID header or generates a new one. Injects trace ID into response headers and logging context. """ def __init__(self, app: Any, header_name: str = "X-Request-ID") -> None: self.app = app self.header_name = header_name.lower() async def __call__(self, scope: dict[str, Any], receive: Any, send: Any) -> None: """ASGI entrypoint.""" if scope["type"] not in ("http", "websocket"): await self.app(scope, receive, send) return headers = dict(scope.get("headers", [])) trace_id = "" for k, v in headers.items(): if isinstance(k, bytes) and k.decode("latin-1").lower() == self.header_name: trace_id = v.decode("latin-1") if isinstance(v, bytes) else str(v) break if not trace_id: trace_id = generate_trace_id() set_trace_id(trace_id) async def send_with_trace(message: dict[str, Any]) -> None: if message["type"] == "http.response.start": headers_list = list(message.get("headers", [])) headers_list.append((b"x-request-id", trace_id.encode())) headers_list.append((b"x-trace-id", trace_id.encode())) message["headers"] = headers_list await send(message) await self.app(scope, receive, send_with_trace)