"""Interpretability: full reasoning trace from prompt to final answer.

Every step of the reasoning pipeline can be traced and explained:

- Prompt decomposition decisions
- Head selection and dispatch
- Per-head claim generation with evidence chains
- Consensus process (agreements, disputes)
- Metacognitive assessment
- Verification results
- Final synthesis rationale

The ReasoningTrace captures all of this in a structured, queryable format
that can be serialized for debugging, auditing, or user explanation.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any


def _utc_now() -> datetime:
    """Return current UTC time (timezone-aware)."""
    return datetime.now(timezone.utc)


@dataclass
class TraceStep:
    """A single step in the reasoning trace.

    Attributes:
        step_id: Unique identifier for this step.
        stage: Pipeline stage (e.g. ``decomposition``, ``head_dispatch``).
        component: Component that executed this step.
        input_summary: Brief summary of the step's input.
        output_summary: Brief summary of the step's output.
        duration_ms: Execution time in milliseconds (if measured).
        metadata: Additional structured data.
        timestamp: When this step was recorded (timezone-aware UTC).
    """

    step_id: str = ""
    stage: str = ""
    component: str = ""
    input_summary: str = ""
    output_summary: str = ""
    duration_ms: float | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=_utc_now)


@dataclass
class ReasoningTrace:
    """Complete reasoning trace for a single prompt→response cycle.

    Attributes:
        trace_id: Unique identifier for this trace.
        task_id: Associated task ID.
        prompt: Original user prompt.
        steps: Ordered list of trace steps.
        final_answer: The produced answer.
        overall_confidence: Final confidence score.
        metacognitive_summary: Summary of metacognitive assessment.
        verification_summary: Summary of claim verification.
        created_at: When the trace was started (timezone-aware UTC).
    """

    trace_id: str = ""
    task_id: str = ""
    prompt: str = ""
    steps: list[TraceStep] = field(default_factory=list)
    final_answer: str = ""
    overall_confidence: float = 0.0
    metacognitive_summary: dict[str, Any] = field(default_factory=dict)
    verification_summary: dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=_utc_now)


class ReasoningTracer:
    """Records interpretable reasoning traces for the pipeline.

    Attach to the reasoning pipeline to capture every decision point.
    Each trace can be serialized, stored, and queried for debugging
    or explanation.

    Args:
        max_traces: Maximum traces to retain in memory. When the limit is
            reached, the oldest-started trace is evicted first (FIFO).
            A value of ``0`` or less still retains the most recent trace.
    """

    def __init__(self, max_traces: int = 1000) -> None:
        self._traces: dict[str, ReasoningTrace] = {}
        # Trace ids in start order (oldest first). Invariant: holds exactly
        # the keys of ``self._traces``, each once.
        self._trace_order: list[str] = []
        self._max_traces = max_traces
        # Shared across all traces, so step ids are unique per tracer.
        self._step_counter = 0

    def start_trace(self, trace_id: str, task_id: str, prompt: str) -> ReasoningTrace:
        """Begin a new reasoning trace.

        Starting a trace with an id that is already stored replaces the
        existing trace; the new one counts as the most recently started.

        Args:
            trace_id: Unique ID for this trace.
            task_id: Associated task ID.
            prompt: The user's prompt.

        Returns:
            The newly created trace.
        """
        if trace_id in self._traces:
            # Drop the old entry (and its position) so the order list never
            # holds duplicates and the replacement doesn't evict another trace.
            del self._traces[trace_id]
            self._trace_order.remove(trace_id)

        # ``while`` (not ``if``) guarantees the bound holds after insertion.
        while len(self._traces) >= self._max_traces and self._trace_order:
            oldest = self._trace_order.pop(0)
            self._traces.pop(oldest, None)

        trace = ReasoningTrace(
            trace_id=trace_id,
            task_id=task_id,
            prompt=prompt,
        )
        self._traces[trace_id] = trace
        self._trace_order.append(trace_id)
        return trace

    def add_step(
        self,
        trace_id: str,
        stage: str,
        component: str,
        input_summary: str = "",
        output_summary: str = "",
        duration_ms: float | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TraceStep | None:
        """Add a step to an existing trace.

        Args:
            trace_id: The trace to add the step to.
            stage: Pipeline stage name.
            component: Component that executed this step.
            input_summary: Brief input description (truncated to 200 chars).
            output_summary: Brief output description (truncated to 200 chars).
            duration_ms: Execution time.
            metadata: Additional data.

        Returns:
            The added step, or ``None`` if trace not found.
        """
        trace = self._traces.get(trace_id)
        if trace is None:
            return None

        self._step_counter += 1
        step = TraceStep(
            step_id=f"step_{self._step_counter}",
            stage=stage,
            component=component,
            # Summaries are capped to keep traces compact.
            input_summary=input_summary[:200],
            output_summary=output_summary[:200],
            duration_ms=duration_ms,
            metadata=metadata or {},
        )
        trace.steps.append(step)
        return step

    def finalize_trace(
        self,
        trace_id: str,
        final_answer: str,
        confidence: float,
        metacognitive_summary: dict[str, Any] | None = None,
        verification_summary: dict[str, Any] | None = None,
    ) -> ReasoningTrace | None:
        """Finalize a trace with the final answer and assessments.

        Args:
            trace_id: The trace to finalize.
            final_answer: The produced answer.
            confidence: Overall confidence score.
            metacognitive_summary: Metacognition assessment summary; an
                empty or ``None`` value leaves the existing summary untouched.
            verification_summary: Claim verification summary; an empty or
                ``None`` value leaves the existing summary untouched.

        Returns:
            The finalized trace, or ``None`` if not found.
        """
        trace = self._traces.get(trace_id)
        if trace is None:
            return None

        trace.final_answer = final_answer
        trace.overall_confidence = confidence
        if metacognitive_summary:
            trace.metacognitive_summary = metacognitive_summary
        if verification_summary:
            trace.verification_summary = verification_summary
        return trace

    def get_trace(self, trace_id: str) -> ReasoningTrace | None:
        """Retrieve a trace by ID."""
        return self._traces.get(trace_id)

    def get_recent_traces(self, limit: int = 10) -> list[ReasoningTrace]:
        """Retrieve up to ``limit`` most recently started traces.

        Args:
            limit: Maximum number of traces to return; ``0`` or less
                returns an empty list.

        Returns:
            Traces ordered oldest to newest among the selected ones.
        """
        if limit <= 0:
            # ``lst[-0:]`` is the whole list, so guard explicitly.
            return []
        recent_ids = self._trace_order[-limit:]
        return [self._traces[tid] for tid in recent_ids if tid in self._traces]

    def explain(self, trace_id: str) -> str:
        """Generate a human-readable explanation of a reasoning trace.

        Args:
            trace_id: The trace to explain.

        Returns:
            Formatted explanation string, or a "not found" message.
        """
        trace = self._traces.get(trace_id)
        if trace is None:
            return f"Trace '{trace_id}' not found."

        lines: list[str] = [
            f"Reasoning Trace: {trace.trace_id}",
            f"Task: {trace.task_id}",
            f"Prompt: {trace.prompt[:100]}",
            f"Steps: {len(trace.steps)}",
            "",
        ]

        for i, step in enumerate(trace.steps, 1):
            lines.append(f" {i}. [{step.stage}] {step.component}")
            if step.input_summary:
                lines.append(f" Input: {step.input_summary}")
            if step.output_summary:
                lines.append(f" Output: {step.output_summary}")
            if step.duration_ms is not None:
                lines.append(f" Time: {step.duration_ms:.1f}ms")

        lines.append("")
        lines.append(f"Final Answer: {trace.final_answer[:200]}")
        lines.append(f"Confidence: {trace.overall_confidence:.2f}")

        if trace.metacognitive_summary:
            lines.append(f"Metacognition: {trace.metacognitive_summary}")
        if trace.verification_summary:
            lines.append(f"Verification: {trace.verification_summary}")

        return "\n".join(lines)

    @property
    def total_traces(self) -> int:
        """Number of traces stored."""
        return len(self._traces)