"""Tests for metacognition and reasoning interpretability.""" from fusionagi.reasoning.interpretability import ReasoningTracer from fusionagi.reasoning.metacognition import ( assess_head_outputs, ) from fusionagi.schemas.grounding import Citation from fusionagi.schemas.head import HeadClaim, HeadId, HeadOutput from fusionagi.verification import ClaimVerifier _SAMPLE_CITATION = Citation(source_id="src_1", excerpt="supporting evidence") def _make_head_output( head_id: HeadId, claims: list[tuple[str, float]] | None = None, ) -> HeadOutput: """Helper to create a head output with claims.""" head_claims = [] for text, conf in (claims or [("Test claim", 0.7)]): head_claims.append(HeadClaim( claim_text=text, confidence=conf, evidence=[_SAMPLE_CITATION] if conf > 0.5 else [], )) return HeadOutput( head_id=head_id, summary=f"Output from {head_id.value}", claims=head_claims, risks=[], ) class TestMetacognition: """Test metacognitive self-assessment.""" def test_empty_outputs(self) -> None: assessment = assess_head_outputs([]) assert assessment.overall_confidence == 0.0 assert assessment.should_seek_more is True def test_high_confidence_outputs(self) -> None: outputs = [ _make_head_output(HeadId.LOGIC, [("Logic is sound", 0.9)]), _make_head_output(HeadId.RESEARCH, [("Data supports this", 0.85)]), ] assessment = assess_head_outputs(outputs) assert assessment.overall_confidence > 0.3 assert isinstance(assessment.knowledge_gaps, list) def test_low_confidence_triggers_seek_more(self) -> None: outputs = [ _make_head_output(HeadId.LOGIC, [("Uncertain claim", 0.1)]), ] assessment = assess_head_outputs(outputs) assert len(assessment.uncertainty_sources) > 0 def test_knowledge_gap_detection(self) -> None: outputs = [ _make_head_output(HeadId.LOGIC, [("Low conf claim", 0.1)]), ] assessment = assess_head_outputs(outputs) gap_domains = [g.domain for g in assessment.knowledge_gaps] assert "logic" in gap_domains def test_domain_gap_detection(self) -> None: outputs = [_make_head_output(HeadId.LOGIC)] assessment = assess_head_outputs(outputs, user_prompt="legal compliance required") gap_domains = [g.domain for g in assessment.knowledge_gaps] assert "legal" in gap_domains class TestReasoningTracer: """Test interpretability tracing.""" def test_trace_lifecycle(self) -> None: tracer = ReasoningTracer() tracer.start_trace("t1", "task1", "What is 2+2?") tracer.add_step("t1", "decomposition", "decomposer", "prompt", "2 units") tracer.add_step("t1", "head_dispatch", "orchestrator", "5 heads", "5 outputs") tracer.finalize_trace("t1", "4", 0.95) result = tracer.get_trace("t1") assert result is not None assert len(result.steps) == 2 assert result.final_answer == "4" assert result.overall_confidence == 0.95 def test_explain(self) -> None: tracer = ReasoningTracer() tracer.start_trace("t1", "task1", "question") tracer.add_step("t1", "stage1", "comp1", "in", "out") tracer.finalize_trace("t1", "answer", 0.8) explanation = tracer.explain("t1") assert "stage1" in explanation assert "answer" in explanation def test_trace_not_found(self) -> None: tracer = ReasoningTracer() assert tracer.get_trace("nonexistent") is None assert "not found" in tracer.explain("nonexistent") def test_recent_traces(self) -> None: tracer = ReasoningTracer() for i in range(5): tracer.start_trace(f"t{i}", f"task{i}", f"prompt{i}") assert len(tracer.get_recent_traces(limit=3)) == 3 assert tracer.total_traces == 5 class TestClaimVerifier: """Test formal claim verification.""" def test_verify_no_outputs(self) -> None: verifier = ClaimVerifier() report = verifier.verify_outputs([]) assert report.total_claims == 0 def test_verify_well_supported_claims(self) -> None: outputs = [ _make_head_output(HeadId.LOGIC, [("Well supported", 0.7)]), _make_head_output(HeadId.RESEARCH, [("Also supported", 0.7)]), ] verifier = ClaimVerifier() report = verifier.verify_outputs(outputs) assert report.total_claims == 2 assert report.overall_integrity > 0.0 def test_high_conf_no_evidence_flagged(self) -> None: claim = HeadClaim(claim_text="Bold claim", confidence=0.95, evidence=[]) output = HeadOutput( head_id=HeadId.LOGIC, summary="Bold output", claims=[claim], risks=[], ) verifier = ClaimVerifier() report = verifier.verify_outputs([output]) assert report.flagged_count >= 1 assert any("evidence" in issue.lower() for r in report.results for issue in r.issues)