feat: consequence engine, causal world model, metacognition, interpretability, claim verification
Some checks failed
Some checks failed
Choice → Consequence → Learning:
- ConsequenceEngine tracks every decision point with alternatives, risk/reward estimates, and actual outcomes
- Consequences feed into AdaptiveEthics for experience-based learning
- FusionAGILoop now wires ethics + consequences into task lifecycle

Causal World Model:
- CausalWorldModel learns state-transition patterns from execution history
- Predicts outcomes based on observed action→effect patterns
- Uncertainty estimates decrease as more evidence accumulates

Metacognition:
- assess_head_outputs() evaluates reasoning quality from head outputs
- Detects knowledge gaps, measures head agreement, identifies uncertainty
- Actively recommends whether to seek more information

Interpretability:
- ReasoningTracer captures full prompt→answer reasoning traces
- Each step records stage, component, input/output, timing
- explain() generates human-readable reasoning explanations

Claim Verification:
- ClaimVerifier cross-checks claims for evidence, consistency, grounding
- Flags high-confidence claims lacking evidence support
- Detects contradictions between claims from different heads

325 tests passing, 0 ruff errors, 0 mypy errors.

Co-Authored-By: Nakamoto, S <defi@defi-oracle.io>
This commit is contained in:
@@ -11,6 +11,12 @@ All governance components support two modes (``GovernanceMode``):
|
|||||||
from fusionagi.governance.access_control import AccessControl
|
from fusionagi.governance.access_control import AccessControl
|
||||||
from fusionagi.governance.adaptive_ethics import AdaptiveEthics, EthicalLesson
|
from fusionagi.governance.adaptive_ethics import AdaptiveEthics, EthicalLesson
|
||||||
from fusionagi.governance.audit_log import AuditLog
|
from fusionagi.governance.audit_log import AuditLog
|
||||||
|
from fusionagi.governance.consequence_engine import (
|
||||||
|
Alternative,
|
||||||
|
Choice,
|
||||||
|
Consequence,
|
||||||
|
ConsequenceEngine,
|
||||||
|
)
|
||||||
from fusionagi.governance.guardrails import Guardrails, PreCheckResult
|
from fusionagi.governance.guardrails import Guardrails, PreCheckResult
|
||||||
from fusionagi.governance.intent_alignment import IntentAlignment
|
from fusionagi.governance.intent_alignment import IntentAlignment
|
||||||
from fusionagi.governance.override import OverrideHooks
|
from fusionagi.governance.override import OverrideHooks
|
||||||
@@ -27,6 +33,10 @@ from fusionagi.schemas.audit import GovernanceMode
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AdaptiveEthics",
|
"AdaptiveEthics",
|
||||||
|
"Alternative",
|
||||||
|
"Choice",
|
||||||
|
"Consequence",
|
||||||
|
"ConsequenceEngine",
|
||||||
"EthicalLesson",
|
"EthicalLesson",
|
||||||
"GovernanceMode",
|
"GovernanceMode",
|
||||||
"Guardrails",
|
"Guardrails",
|
||||||
|
|||||||
366
fusionagi/governance/consequence_engine.py
Normal file
366
fusionagi/governance/consequence_engine.py
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
"""Consequence engine: choice → consequence → learning.
|
||||||
|
|
||||||
|
Every decision the system makes is a *choice*. Every choice has
|
||||||
|
*alternatives* that were not taken. Every choice leads to
|
||||||
|
*consequences* — outcomes that carry risk and reward.
|
||||||
|
|
||||||
|
The consequence engine:
|
||||||
|
1. Records decision points (what options existed, which was chosen, why)
|
||||||
|
2. Tracks consequences (what happened as a result)
|
||||||
|
3. Computes risk/reward from historical consequence data
|
||||||
|
4. Feeds consequence data into AdaptiveEthics for learning
|
||||||
|
|
||||||
|
Philosophy:
|
||||||
|
- Consequences are the true teacher. Not rules, not constraints.
|
||||||
|
- Risk is not to be avoided — it is to be *understood*.
|
||||||
|
- Reward without risk teaches nothing. Risk without consequence teaches less.
|
||||||
|
- The system earns trust by showing it understands what its choices cost.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from fusionagi._logger import logger
|
||||||
|
from fusionagi.schemas.audit import AuditEventType
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogLike(Protocol):
    """Structural interface for the audit sink used by the consequence engine.

    Any object exposing an ``append`` method with the signature below
    satisfies this protocol; no inheritance is required.
    """

    def append(
        self,
        event_type: AuditEventType,
        actor: str,
        action: str = "",
        task_id: str | None = None,
        payload: dict[str, Any] | None = None,
        outcome: str = "",
    ) -> str:
        """Append one audit event.

        Returns a string; presumably an identifier for the stored entry —
        confirm against the concrete audit log implementation.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Alternative:
    """A candidate action that was available at a decision point but not taken."""

    # What the passed-over action was.
    action: str = ""
    # Risk estimated at decision time; documented range 0.0–1.0 (not validated here).
    estimated_risk: float = 0.5
    # Reward estimated at decision time; documented range 0.0–1.0 (not validated here).
    estimated_reward: float = 0.5
    # Free-text explanation of why this option was not selected.
    reason_not_chosen: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Choice:
    """Snapshot of a decision point: what was picked, by whom, and what else was possible."""

    # Unique identifier for this choice.
    choice_id: str = ""
    # Task this choice belongs to, if any.
    task_id: str | None = None
    # Component that made the choice.
    actor: str = ""
    # The action that was selected.
    action_taken: str = ""
    # Other options that were available; each instance gets its own list.
    alternatives: list[Alternative] = field(default_factory=list)
    # Risk estimate at decision time.
    estimated_risk: float = 0.5
    # Reward estimate at decision time.
    estimated_reward: float = 0.5
    # Why this action was chosen.
    rationale: str = ""
    # Situation context at decision time; each instance gets its own dict.
    context: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Consequence:
    """What actually happened as a result of a recorded choice."""

    # Identifier of the choice this outcome belongs to.
    choice_id: str = ""
    # Whether the outcome was beneficial.
    outcome_positive: bool = True
    # How much risk materialized; documented range 0.0–1.0.
    actual_risk_realized: float = 0.0
    # How much reward was gained; documented range 0.0–1.0.
    actual_reward_gained: float = 0.5
    # Free-text account of what happened.
    description: str = ""
    # Any cost incurred (errors, retries, time); each instance gets its own dict.
    cost: dict[str, Any] = field(default_factory=dict)
    # Any benefit gained (task success, learning); each instance gets its own dict.
    benefit: dict[str, Any] = field(default_factory=dict)
    # How unexpected the outcome was: 0 = expected, 1 = total surprise.
    surprise_factor: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class ConsequenceEngine:
    """Tracks choices, consequences, and risk/reward patterns.

    The engine keeps every recorded choice and its latest consequence in
    memory, plus a bounded per-action series of realized risk and reward
    values that ``estimate_risk_reward`` averages over.

    Args:
        audit_log: Optional audit log for recording choices and consequences.
        risk_memory_window: How many past consequences per action to keep
            when estimating risk for new choices. Non-positive values keep
            no history, so estimates fall back to the uninformed defaults.
    """

    def __init__(
        self,
        audit_log: AuditLogLike | None = None,
        risk_memory_window: int = 200,
    ) -> None:
        self._choices: dict[str, Choice] = {}
        self._consequences: dict[str, Consequence] = {}
        self._risk_history: dict[str, list[float]] = {}
        self._reward_history: dict[str, list[float]] = {}
        self._audit = audit_log
        # Clamp so a zero/negative window can never be read as "keep everything".
        self._risk_window = max(0, risk_memory_window)

    @property
    def total_choices(self) -> int:
        """Total choices recorded."""
        return len(self._choices)

    @property
    def total_consequences(self) -> int:
        """Total consequences recorded."""
        return len(self._consequences)

    def _record_observation(self, action_type: str, risk: float, reward: float) -> None:
        """Append one realized risk/reward pair and trim both series to the window."""
        risks = self._risk_history.setdefault(action_type, [])
        rewards = self._reward_history.setdefault(action_type, [])
        risks.append(risk)
        rewards.append(reward)

        # Both series grow in lockstep, so one overflow count applies to both.
        # Trimming by count (not by a ``[-window:]`` slice) stays correct for window == 0.
        overflow = len(risks) - self._risk_window
        if overflow > 0:
            del risks[:overflow]
            del rewards[:overflow]

    def record_choice(
        self,
        choice_id: str,
        actor: str,
        action_taken: str,
        alternatives: list[Alternative] | None = None,
        estimated_risk: float = 0.5,
        estimated_reward: float = 0.5,
        rationale: str = "",
        task_id: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> Choice:
        """Record a decision point.

        A choice recorded under an already-used ``choice_id`` replaces the
        earlier one.

        Args:
            choice_id: Unique ID for this choice.
            actor: Component making the choice.
            action_taken: The selected action.
            alternatives: Other options considered.
            estimated_risk: Risk estimate at decision time.
            estimated_reward: Reward estimate at decision time.
            rationale: Why this was chosen.
            task_id: Associated task.
            context: Situation context.

        Returns:
            The recorded choice.
        """
        choice = Choice(
            choice_id=choice_id,
            task_id=task_id,
            actor=actor,
            action_taken=action_taken,
            # Copy the containers so later caller-side mutation cannot rewrite history.
            alternatives=list(alternatives) if alternatives else [],
            estimated_risk=estimated_risk,
            estimated_reward=estimated_reward,
            rationale=rationale,
            context=dict(context) if context else {},
        )
        self._choices[choice_id] = choice

        # Identity check: an audit log that is falsy while empty must still be written to.
        if self._audit is not None:
            self._audit.append(
                AuditEventType.CHOICE,
                actor=actor,
                action="choice_recorded",
                task_id=task_id,
                payload={
                    "choice_id": choice_id,
                    "action_taken": action_taken[:100],
                    "alternatives_count": len(choice.alternatives),
                    "estimated_risk": estimated_risk,
                    "estimated_reward": estimated_reward,
                    "rationale": rationale[:100],
                },
                outcome="recorded",
            )

        logger.info(
            "ConsequenceEngine: choice recorded",
            extra={
                "choice_id": choice_id,
                "action": action_taken[:50],
                "risk": estimated_risk,
                "reward": estimated_reward,
            },
        )
        return choice

    def record_consequence(
        self,
        choice_id: str,
        outcome_positive: bool,
        actual_risk_realized: float = 0.0,
        actual_reward_gained: float = 0.5,
        description: str = "",
        cost: dict[str, Any] | None = None,
        benefit: dict[str, Any] | None = None,
    ) -> Consequence | None:
        """Record the consequence of a previous choice.

        The surprise factor is the mean absolute gap between the estimated
        and realized risk and reward, capped at 1.0. The realized values are
        also added to the history of the choice's ``action_taken``.

        Args:
            choice_id: Which choice this is a consequence of.
            outcome_positive: Whether the outcome was beneficial.
            actual_risk_realized: How much risk materialized.
            actual_reward_gained: How much reward was gained.
            description: What happened.
            cost: Costs incurred.
            benefit: Benefits gained.

        Returns:
            The recorded consequence, or ``None`` if choice not found.
        """
        choice = self._choices.get(choice_id)
        if choice is None:
            logger.warning(
                "ConsequenceEngine: choice not found for consequence",
                extra={"choice_id": choice_id},
            )
            return None

        surprise = (
            abs(choice.estimated_risk - actual_risk_realized) * 0.5
            + abs(choice.estimated_reward - actual_reward_gained) * 0.5
        )

        consequence = Consequence(
            choice_id=choice_id,
            outcome_positive=outcome_positive,
            actual_risk_realized=actual_risk_realized,
            actual_reward_gained=actual_reward_gained,
            description=description,
            cost=dict(cost) if cost else {},
            benefit=dict(benefit) if benefit else {},
            surprise_factor=min(1.0, surprise),
        )
        self._consequences[choice_id] = consequence

        # History is keyed by the action text itself, not by a separate category.
        self._record_observation(
            choice.action_taken, actual_risk_realized, actual_reward_gained
        )

        if self._audit is not None:
            self._audit.append(
                AuditEventType.CONSEQUENCE,
                actor=choice.actor,
                action="consequence_recorded",
                task_id=choice.task_id,
                payload={
                    "choice_id": choice_id,
                    "outcome_positive": outcome_positive,
                    "risk_realized": actual_risk_realized,
                    "reward_gained": actual_reward_gained,
                    "surprise_factor": consequence.surprise_factor,
                    "description": description[:100],
                },
                outcome="positive" if outcome_positive else "negative",
            )

        logger.info(
            "ConsequenceEngine: consequence recorded",
            extra={
                "choice_id": choice_id,
                "positive": outcome_positive,
                "surprise": consequence.surprise_factor,
            },
        )
        return consequence

    def estimate_risk_reward(self, action_type: str) -> dict[str, float]:
        """Estimate risk and reward for an action type based on history.

        With no observations the result is an uninformed prior (0.5 risk,
        0.5 reward, 0.1 confidence). Otherwise means and population
        variances of the retained observations are returned, with confidence
        growing linearly in the observation count and capped at 1.0.

        Args:
            action_type: The type of action being considered.

        Returns:
            Dict with ``expected_risk``, ``expected_reward``, ``confidence``,
            ``risk_variance``, ``reward_variance``, ``observations``.
        """
        risks = self._risk_history.get(action_type, [])
        rewards = self._reward_history.get(action_type, [])

        if not risks:
            return {
                "expected_risk": 0.5,
                "expected_reward": 0.5,
                "confidence": 0.1,
                "risk_variance": 0.0,
                "reward_variance": 0.0,
                "observations": 0,
            }

        n = len(risks)
        avg_risk = sum(risks) / n
        avg_reward = sum(rewards) / n
        risk_var = sum((r - avg_risk) ** 2 for r in risks) / n if n > 1 else 0.0
        reward_var = sum((r - avg_reward) ** 2 for r in rewards) / n if n > 1 else 0.0

        confidence = min(1.0, 0.2 + n * 0.04)

        return {
            "expected_risk": avg_risk,
            "expected_reward": avg_reward,
            "confidence": confidence,
            "risk_variance": risk_var,
            "reward_variance": reward_var,
            "observations": n,
        }

    def get_choice(self, choice_id: str) -> Choice | None:
        """Retrieve a recorded choice."""
        return self._choices.get(choice_id)

    def get_consequence(self, choice_id: str) -> Consequence | None:
        """Retrieve the consequence of a choice."""
        return self._consequences.get(choice_id)

    def get_summary(self) -> dict[str, Any]:
        """Return a summary of all choices and consequences."""
        recorded = len(self._consequences)
        # Avoid dividing by zero before any consequence exists.
        denominator = max(recorded, 1)
        total_positive = sum(1 for c in self._consequences.values() if c.outcome_positive)
        total_negative = recorded - total_positive
        avg_surprise = (
            sum(c.surprise_factor for c in self._consequences.values()) / denominator
        )

        action_stats: dict[str, dict[str, Any]] = {
            action_type: self.estimate_risk_reward(action_type)
            for action_type in self._risk_history
        }

        return {
            "total_choices": len(self._choices),
            "total_consequences": recorded,
            "positive_outcomes": total_positive,
            "negative_outcomes": total_negative,
            "positive_rate": total_positive / denominator,
            "avg_surprise": avg_surprise,
            "action_stats": action_stats,
        }
|
||||||
@@ -10,11 +10,21 @@ from fusionagi.reasoning.gpu_scoring import (
|
|||||||
generate_and_score_gpu,
|
generate_and_score_gpu,
|
||||||
score_claims_gpu,
|
score_claims_gpu,
|
||||||
)
|
)
|
||||||
|
from fusionagi.reasoning.interpretability import (
|
||||||
|
ReasoningTrace,
|
||||||
|
ReasoningTracer,
|
||||||
|
TraceStep,
|
||||||
|
)
|
||||||
from fusionagi.reasoning.meta_reasoning import (
|
from fusionagi.reasoning.meta_reasoning import (
|
||||||
challenge_assumptions,
|
challenge_assumptions,
|
||||||
detect_contradictions,
|
detect_contradictions,
|
||||||
revisit_node,
|
revisit_node,
|
||||||
)
|
)
|
||||||
|
from fusionagi.reasoning.metacognition import (
|
||||||
|
KnowledgeGap,
|
||||||
|
MetacognitiveAssessment,
|
||||||
|
assess_head_outputs,
|
||||||
|
)
|
||||||
from fusionagi.reasoning.multi_path import generate_and_score_parallel
|
from fusionagi.reasoning.multi_path import generate_and_score_parallel
|
||||||
from fusionagi.reasoning.native import (
|
from fusionagi.reasoning.native import (
|
||||||
NativeReasoningProvider,
|
NativeReasoningProvider,
|
||||||
@@ -61,4 +71,10 @@ __all__ = [
|
|||||||
"generate_and_score_gpu",
|
"generate_and_score_gpu",
|
||||||
"score_claims_gpu",
|
"score_claims_gpu",
|
||||||
"deduplicate_claims_gpu",
|
"deduplicate_claims_gpu",
|
||||||
|
"MetacognitiveAssessment",
|
||||||
|
"KnowledgeGap",
|
||||||
|
"assess_head_outputs",
|
||||||
|
"ReasoningTrace",
|
||||||
|
"ReasoningTracer",
|
||||||
|
"TraceStep",
|
||||||
]
|
]
|
||||||
|
|||||||
247
fusionagi/reasoning/interpretability.py
Normal file
247
fusionagi/reasoning/interpretability.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""Interpretability: full reasoning trace from prompt to final answer.
|
||||||
|
|
||||||
|
Every step of the reasoning pipeline can be traced and explained:
|
||||||
|
- Prompt decomposition decisions
|
||||||
|
- Head selection and dispatch
|
||||||
|
- Per-head claim generation with evidence chains
|
||||||
|
- Consensus process (agreements, disputes)
|
||||||
|
- Metacognitive assessment
|
||||||
|
- Verification results
|
||||||
|
- Final synthesis rationale
|
||||||
|
|
||||||
|
The ReasoningTrace captures all of this in a structured, queryable format
|
||||||
|
that can be serialized for debugging, auditing, or user explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def _utc_now() -> datetime:
    """Give the current instant as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TraceStep:
    """One recorded step of a reasoning trace."""

    # Unique identifier for this step.
    step_id: str = ""
    # Pipeline stage, e.g. ``decomposition`` or ``head_dispatch``.
    stage: str = ""
    # Component that executed this step.
    component: str = ""
    # Brief summary of the step's input.
    input_summary: str = ""
    # Brief summary of the step's output.
    output_summary: str = ""
    # Execution time in milliseconds; ``None`` when not measured.
    duration_ms: float | None = None
    # Additional structured data; each instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
    # When this step was recorded; defaults to the current UTC time at creation.
    timestamp: datetime = field(default_factory=_utc_now)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ReasoningTrace:
    """Full record of one prompt→response cycle."""

    # Unique identifier for this trace.
    trace_id: str = ""
    # Associated task ID.
    task_id: str = ""
    # Original user prompt.
    prompt: str = ""
    # Trace steps in the order they were added; each instance gets its own list.
    steps: list[TraceStep] = field(default_factory=list)
    # The produced answer.
    final_answer: str = ""
    # Final confidence score.
    overall_confidence: float = 0.0
    # Summary of the metacognitive assessment.
    metacognitive_summary: dict[str, Any] = field(default_factory=dict)
    # Summary of claim verification.
    verification_summary: dict[str, Any] = field(default_factory=dict)
    # When the trace was started; defaults to the current UTC time at creation.
    created_at: datetime = field(default_factory=_utc_now)
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningTracer:
    """Records interpretable reasoning traces for the pipeline.

    Traces are held in memory, keyed by trace ID. Once ``max_traces`` traces
    are stored, starting a new one evicts the oldest (FIFO).

    Args:
        max_traces: Maximum traces to retain in memory (FIFO).
    """

    def __init__(self, max_traces: int = 1000) -> None:
        self._traces: dict[str, ReasoningTrace] = {}
        # Trace IDs from oldest to newest; drives eviction and "recent" queries.
        self._trace_order: list[str] = []
        self._max_traces = max_traces
        # Shared across all traces, so step IDs are unique per tracer, not per trace.
        self._step_counter = 0

    def start_trace(self, trace_id: str, task_id: str, prompt: str) -> ReasoningTrace:
        """Begin a new reasoning trace.

        Reusing an existing ``trace_id`` replaces that trace and makes it the
        most recent one; nothing else is evicted in that case.

        Args:
            trace_id: Unique ID for this trace.
            task_id: Associated task ID.
            prompt: The user's prompt.

        Returns:
            The newly created trace.
        """
        if trace_id in self._traces:
            # Drop the stale position so the ID appears only once in the order
            # list; duplicates would show up in recent-trace queries and make a
            # later eviction delete the live trace.
            self._trace_order = [tid for tid in self._trace_order if tid != trace_id]
        elif len(self._traces) >= self._max_traces and self._trace_order:
            oldest = self._trace_order.pop(0)
            self._traces.pop(oldest, None)

        trace = ReasoningTrace(
            trace_id=trace_id,
            task_id=task_id,
            prompt=prompt,
        )
        self._traces[trace_id] = trace
        self._trace_order.append(trace_id)
        return trace

    def add_step(
        self,
        trace_id: str,
        stage: str,
        component: str,
        input_summary: str = "",
        output_summary: str = "",
        duration_ms: float | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> TraceStep | None:
        """Add a step to an existing trace.

        Input and output summaries are truncated to 200 characters.

        Args:
            trace_id: The trace to add the step to.
            stage: Pipeline stage name.
            component: Component that executed this step.
            input_summary: Brief input description.
            output_summary: Brief output description.
            duration_ms: Execution time.
            metadata: Additional data.

        Returns:
            The added step, or ``None`` if trace not found.
        """
        trace = self._traces.get(trace_id)
        if trace is None:
            return None

        self._step_counter += 1
        step = TraceStep(
            step_id=f"step_{self._step_counter}",
            stage=stage,
            component=component,
            input_summary=input_summary[:200],
            output_summary=output_summary[:200],
            duration_ms=duration_ms,
            metadata=metadata or {},
        )
        trace.steps.append(step)
        return step

    def finalize_trace(
        self,
        trace_id: str,
        final_answer: str,
        confidence: float,
        metacognitive_summary: dict[str, Any] | None = None,
        verification_summary: dict[str, Any] | None = None,
    ) -> ReasoningTrace | None:
        """Finalize a trace with the final answer and assessments.

        Empty or missing summaries leave the trace's existing summaries as
        they are.

        Args:
            trace_id: The trace to finalize.
            final_answer: The produced answer.
            confidence: Overall confidence score.
            metacognitive_summary: Metacognition assessment summary.
            verification_summary: Claim verification summary.

        Returns:
            The finalized trace, or ``None`` if not found.
        """
        trace = self._traces.get(trace_id)
        if trace is None:
            return None

        trace.final_answer = final_answer
        trace.overall_confidence = confidence
        if metacognitive_summary:
            trace.metacognitive_summary = metacognitive_summary
        if verification_summary:
            trace.verification_summary = verification_summary
        return trace

    def get_trace(self, trace_id: str) -> ReasoningTrace | None:
        """Retrieve a trace by ID."""
        return self._traces.get(trace_id)

    def get_recent_traces(self, limit: int = 10) -> list[ReasoningTrace]:
        """Retrieve the most recent traces, oldest first.

        Args:
            limit: Maximum number of traces to return; non-positive values
                yield an empty list.
        """
        # ``[-0:]`` is the whole list and a negative limit slices from the
        # front, so non-positive limits must be handled before slicing.
        if limit <= 0:
            return []
        recent_ids = self._trace_order[-limit:]
        return [self._traces[tid] for tid in recent_ids if tid in self._traces]

    def explain(self, trace_id: str) -> str:
        """Generate a human-readable explanation of a reasoning trace.

        Args:
            trace_id: The trace to explain.

        Returns:
            Formatted explanation string.
        """
        trace = self._traces.get(trace_id)
        if trace is None:
            return f"Trace '{trace_id}' not found."

        lines: list[str] = [
            f"Reasoning Trace: {trace.trace_id}",
            f"Task: {trace.task_id}",
            f"Prompt: {trace.prompt[:100]}",
            f"Steps: {len(trace.steps)}",
            "",
        ]

        for i, step in enumerate(trace.steps, 1):
            lines.append(f" {i}. [{step.stage}] {step.component}")
            if step.input_summary:
                lines.append(f" Input: {step.input_summary}")
            if step.output_summary:
                lines.append(f" Output: {step.output_summary}")
            if step.duration_ms is not None:
                lines.append(f" Time: {step.duration_ms:.1f}ms")

        lines.append("")
        lines.append(f"Final Answer: {trace.final_answer[:200]}")
        lines.append(f"Confidence: {trace.overall_confidence:.2f}")

        if trace.metacognitive_summary:
            lines.append(f"Metacognition: {trace.metacognitive_summary}")
        if trace.verification_summary:
            lines.append(f"Verification: {trace.verification_summary}")

        return "\n".join(lines)

    @property
    def total_traces(self) -> int:
        """Number of traces stored."""
        return len(self._traces)
|
||||||
262
fusionagi/reasoning/metacognition.py
Normal file
262
fusionagi/reasoning/metacognition.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
"""Metacognition: self-awareness of knowledge boundaries and reasoning quality.
|
||||||
|
|
||||||
|
The metacognition engine monitors the system's own reasoning processes
|
||||||
|
and produces self-assessments:
|
||||||
|
- Does the system have enough evidence to answer confidently?
|
||||||
|
- Which knowledge gaps exist?
|
||||||
|
- Where are the reasoning weak points?
|
||||||
|
- Should the system seek more information before answering?
|
||||||
|
|
||||||
|
This is distinct from meta_reasoning.py (which challenges assumptions
|
||||||
|
and detects contradictions in content). Metacognition operates on
|
||||||
|
the *process* level — it reasons about the quality of reasoning itself.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from fusionagi._logger import logger
|
||||||
|
from fusionagi.schemas.head import HeadOutput
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class KnowledgeGap:
    """Something the system has recognised it does not know."""

    # Knowledge domain, e.g. ``legal`` or ``medical``.
    domain: str
    # What the system doesn't know.
    description: str
    # Impact on answer quality: ``low``, ``medium`` or ``high`` (not validated here).
    severity: str = "medium"
    # Whether additional tool calls could fill this gap.
    resolvable: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MetacognitiveAssessment:
    """The system's own verdict on how well it reasoned about a task."""

    # System's confidence in its answer (0.0–1.0).
    overall_confidence: float = 0.5
    # How sufficient the evidence is (0.0–1.0).
    evidence_sufficiency: float = 0.5
    # Identified gaps in knowledge; each instance gets its own list.
    knowledge_gaps: list[KnowledgeGap] = field(default_factory=list)
    # Assessment of the reasoning chain quality.
    reasoning_quality: float = 0.5
    # Whether the system should seek more information.
    should_seek_more: bool = False
    # Fraction of heads that agree (0.0–1.0).
    head_agreement: float = 0.5
    # Where uncertainty comes from.
    uncertainty_sources: list[str] = field(default_factory=list)
    # What the system should do next.
    recommendations: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def assess_head_outputs(
    outputs: list[HeadOutput],
    user_prompt: str = "",
) -> MetacognitiveAssessment:
    """Assess reasoning quality from head outputs.

    Analyzes the collection of head outputs for agreement patterns,
    confidence distribution, evidence coverage, and knowledge gaps, then
    folds those signals into weighted scores.

    Args:
        outputs: Outputs from Dvādaśa content heads.
        user_prompt: Original user prompt for context (used only for
            knowledge-gap detection).

    Returns:
        Metacognitive assessment of reasoning quality. When ``outputs`` is
        empty, a zero-confidence assessment that asks for the head pipeline
        to be executed first.
    """
    if not outputs:
        return MetacognitiveAssessment(
            overall_confidence=0.0,
            evidence_sufficiency=0.0,
            should_seek_more=True,
            uncertainty_sources=["No head outputs available"],
            recommendations=["Execute head pipeline before assessment"],
        )

    # One confidence sample per claim; a head without claims contributes 0.0
    # so silent heads pull the average down instead of being ignored.
    confidences: list[float] = []
    for out in outputs:
        if out.claims:
            confidences.extend(c.confidence for c in out.claims)
        else:
            confidences.append(0.0)
    avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0

    # Evidence coverage: an average of three evidence items per claim counts
    # as fully sufficient (1.0).
    evidence_counts = [len(c.evidence) for out in outputs for c in out.claims]
    avg_evidence = sum(evidence_counts) / max(len(evidence_counts), 1)
    evidence_sufficiency = min(1.0, avg_evidence / 3.0)

    head_agreement = _compute_head_agreement(outputs)
    gaps = _detect_knowledge_gaps(outputs, user_prompt)

    uncertainty_sources: list[str] = []
    if avg_confidence < 0.5:
        uncertainty_sources.append(f"Low average head confidence: {avg_confidence:.2f}")
    if head_agreement < 0.4:
        uncertainty_sources.append(f"Low head agreement: {head_agreement:.2f}")
    if evidence_sufficiency < 0.3:
        uncertainty_sources.append(f"Insufficient evidence: avg {avg_evidence:.1f} per claim")
    if gaps:
        uncertainty_sources.append(f"{len(gaps)} knowledge gap(s) detected")

    conf_variance = _variance(confidences) if len(confidences) > 1 else 0.0
    if conf_variance > 0.1:
        uncertainty_sources.append(
            f"High confidence variance across heads: {conf_variance:.3f}"
        )

    # Weighted blend; each detected gap removes 0.2 from the gap term, which
    # bottoms out at zero once five or more gaps are present.
    reasoning_quality = (
        0.4 * avg_confidence
        + 0.3 * head_agreement
        + 0.2 * evidence_sufficiency
        + 0.1 * (1.0 - min(1.0, len(gaps) * 0.2))
    )

    should_seek_more = (
        reasoning_quality < 0.4
        or evidence_sufficiency < 0.3
        or any(g.severity == "high" and g.resolvable for g in gaps)
    )

    recommendations: list[str] = []
    if should_seek_more:
        recommendations.append("Seek additional evidence before finalizing answer")
    if head_agreement < 0.4:
        recommendations.append("Run second-pass with disputed heads for clarification")
    for gap in gaps:
        if gap.resolvable:
            recommendations.append(f"Fill knowledge gap: {gap.description}")

    overall = min(1.0, 0.5 * reasoning_quality + 0.3 * head_agreement + 0.2 * evidence_sufficiency)

    assessment = MetacognitiveAssessment(
        overall_confidence=overall,
        evidence_sufficiency=evidence_sufficiency,
        knowledge_gaps=gaps,
        reasoning_quality=reasoning_quality,
        should_seek_more=should_seek_more,
        head_agreement=head_agreement,
        uncertainty_sources=uncertainty_sources,
        recommendations=recommendations,
    )

    logger.info(
        "Metacognition: assessment complete",
        extra={
            "overall_confidence": overall,
            "reasoning_quality": reasoning_quality,
            "head_agreement": head_agreement,
            "gaps": len(gaps),
            "should_seek_more": should_seek_more,
        },
    )
    return assessment
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_head_agreement(outputs: list[HeadOutput]) -> float:
    """Estimate how strongly the heads agree with one another.

    Each head is reduced to the set of lower-cased, whitespace-separated
    words appearing in its claims. Agreement is the mean Jaccard similarity
    (shared words / all words) over every pair of heads that both have at
    least one word. Fewer than two heads yields 1.0; no comparable pair
    yields 0.0.
    """
    if len(outputs) < 2:
        return 1.0

    # One vocabulary per head, in head order.
    vocabularies: list[set[str]] = []
    for head in outputs:
        vocabulary: set[str] = set()
        for claim in head.claims:
            vocabulary.update(claim.claim_text.lower().split())
        vocabularies.append(vocabulary)

    similarity_total: float = 0.0
    pair_count = 0
    for position, left in enumerate(vocabularies):
        for right in vocabularies[position + 1:]:
            # Heads without any claim words cannot be compared.
            if not left or not right:
                continue
            shared = len(left & right)
            combined = len(left | right)
            if combined > 0:
                similarity_total += shared / combined
                pair_count += 1

    return similarity_total / max(pair_count, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_knowledge_gaps(
    outputs: list[HeadOutput],
    user_prompt: str,
) -> list[KnowledgeGap]:
    """Collect knowledge gaps implied by the head outputs and the prompt.

    Gaps are appended in three passes, in this order:

    1. Heads whose mean claim confidence is below 0.3 (a head with no
       claims counts as 0.0); severity is ``high`` below 0.15.
    2. Heads that produced no claims at all.
    3. Domains whose keywords occur in the prompt (substring match on the
       lower-cased prompt) but whose name is not the id of any head.

    Note that a head with no claims is reported by both pass 1 and pass 2.
    """
    found: list[KnowledgeGap] = []

    # Pass 1: low-confidence heads.
    for head in outputs:
        claims = head.claims
        mean_conf = sum(item.confidence for item in claims) / len(claims) if claims else 0.0
        if mean_conf < 0.3:
            if mean_conf < 0.15:
                level = "high"
            else:
                level = "medium"
            found.append(KnowledgeGap(
                domain=head.head_id.value,
                description=f"Head '{head.head_id.value}' has very low confidence ({mean_conf:.2f})",
                severity=level,
                resolvable=True,
            ))

    # Pass 2: heads that stayed silent.
    found.extend(
        KnowledgeGap(
            domain=head.head_id.value,
            description=f"Head '{head.head_id.value}' produced no claims",
            severity="medium",
            resolvable=True,
        )
        for head in outputs
        if not head.claims
    )

    # Pass 3: specialised domains hinted at by the prompt but not covered.
    prompt_text = user_prompt.lower()
    domain_indicators = {
        "legal": ["law", "legal", "court", "statute", "regulation", "compliance"],
        "medical": ["medical", "health", "disease", "treatment", "clinical", "patient"],
        "financial": ["financial", "stock", "market", "investment", "trading", "portfolio"],
        "scientific": ["experiment", "hypothesis", "data", "study", "research", "evidence"],
    }
    for domain, keywords in domain_indicators.items():
        mentioned = any(keyword in prompt_text for keyword in keywords)
        if not mentioned:
            continue
        covered = {head.head_id.value for head in outputs}
        if domain in covered:
            continue
        found.append(KnowledgeGap(
            domain=domain,
            description=f"Prompt references '{domain}' domain but no specialized head covers it",
            severity="medium",
            resolvable=False,
        ))

    return found
|
||||||
|
|
||||||
|
|
||||||
|
def _variance(values: list[float]) -> float:
    """Return the population variance of ``values``.

    Fewer than two samples yields 0.0. The divisor is the sample count
    (not count - 1).
    """
    count = len(values)
    if count < 2:
        return 0.0
    center = sum(values) / count
    squared_deviations = [(sample - center) ** 2 for sample in values]
    return sum(squared_deviations) / count
|
||||||
@@ -38,6 +38,8 @@ class AuditEventType(str, Enum):
|
|||||||
ADVISORY = "advisory"
|
ADVISORY = "advisory"
|
||||||
SELF_IMPROVEMENT = "self_improvement"
|
SELF_IMPROVEMENT = "self_improvement"
|
||||||
ETHICAL_LEARNING = "ethical_learning"
|
ETHICAL_LEARNING = "ethical_learning"
|
||||||
|
CHOICE = "choice"
|
||||||
|
CONSEQUENCE = "consequence"
|
||||||
OTHER = "other"
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,20 @@
|
|||||||
"""AGI loop: wires self-correction, auto-recommend, and auto-training to events."""
|
"""AGI loop: wires self-correction, auto-training, adaptive ethics, and
|
||||||
|
consequence tracking to the event bus.
|
||||||
|
|
||||||
|
Choice → Consequence → Learning:
|
||||||
|
- Every task failure/success is recorded as a consequence of the choices made.
|
||||||
|
- Consequences feed into AdaptiveEthics for learned ethical growth.
|
||||||
|
- The ConsequenceEngine tracks risk/reward patterns across all actions.
|
||||||
|
- Trust is earned through demonstrable learning from outcomes.
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
from fusionagi._logger import logger
|
from fusionagi._logger import logger
|
||||||
from fusionagi.core.event_bus import EventBus
|
from fusionagi.core.event_bus import EventBus
|
||||||
|
from fusionagi.governance.adaptive_ethics import AdaptiveEthics
|
||||||
|
from fusionagi.governance.audit_log import AuditLog
|
||||||
|
from fusionagi.governance.consequence_engine import ConsequenceEngine
|
||||||
from fusionagi.schemas.recommendation import Recommendation, TrainingSuggestion
|
from fusionagi.schemas.recommendation import Recommendation, TrainingSuggestion
|
||||||
from fusionagi.schemas.task import TaskState
|
from fusionagi.schemas.task import TaskState
|
||||||
from fusionagi.self_improvement.correction import (
|
from fusionagi.self_improvement.correction import (
|
||||||
@@ -17,10 +28,24 @@ from fusionagi.self_improvement.training import AutoTrainer, ReflectiveMemoryLik
|
|||||||
|
|
||||||
|
|
||||||
class FusionAGILoop:
|
class FusionAGILoop:
|
||||||
"""
|
"""High-level AGI loop with consequence-driven learning.
|
||||||
High-level AGI loop: subscribes to task_state_changed and reflection_done,
|
|
||||||
runs self-correction on failures, and runs auto-recommend + auto-training
|
Subscribes to task_state_changed and reflection_done events.
|
||||||
after reflection. Composes the world's most advanced agentic AGI self-improvement pipeline.
|
Runs self-correction on failures, auto-recommend + auto-training
|
||||||
|
after reflection, and feeds all outcomes into the adaptive ethics
|
||||||
|
and consequence engines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_bus: Event bus for task and reflection events.
|
||||||
|
state_manager: State manager for task state and traces.
|
||||||
|
orchestrator: Orchestrator for plan and state transitions.
|
||||||
|
critic_agent: Critic agent for evaluation.
|
||||||
|
reflective_memory: Optional reflective memory for lessons/heuristics.
|
||||||
|
audit_log: Optional audit log for full transparency.
|
||||||
|
auto_retry_on_failure: Auto-retry failed tasks.
|
||||||
|
max_retries_per_task: Max retries per task (``None`` = unlimited).
|
||||||
|
on_recommendations: Callback for recommendations.
|
||||||
|
on_training_suggestions: Callback for training suggestions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -30,26 +55,13 @@ class FusionAGILoop:
|
|||||||
orchestrator: OrchestratorLike,
|
orchestrator: OrchestratorLike,
|
||||||
critic_agent: CriticLike,
|
critic_agent: CriticLike,
|
||||||
reflective_memory: ReflectiveMemoryLike | None = None,
|
reflective_memory: ReflectiveMemoryLike | None = None,
|
||||||
|
audit_log: AuditLog | None = None,
|
||||||
*,
|
*,
|
||||||
auto_retry_on_failure: bool = False,
|
auto_retry_on_failure: bool = False,
|
||||||
max_retries_per_task: int = 2,
|
max_retries_per_task: int | None = None,
|
||||||
on_recommendations: Callable[[list[Recommendation]], None] | None = None,
|
on_recommendations: Callable[[list[Recommendation]], None] | None = None,
|
||||||
on_training_suggestions: Callable[[list[TrainingSuggestion]], None] | None = None,
|
on_training_suggestions: Callable[[list[TrainingSuggestion]], None] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Initialize the FusionAGI loop.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_bus: Event bus to subscribe to task_state_changed and reflection_done.
|
|
||||||
state_manager: State manager for task state and traces.
|
|
||||||
orchestrator: Orchestrator for plan and state transitions.
|
|
||||||
critic_agent: Critic agent for evaluate_request -> evaluation_ready.
|
|
||||||
reflective_memory: Optional reflective memory for lessons/heuristics.
|
|
||||||
auto_retry_on_failure: If True, on FAILED transition prepare_retry automatically.
|
|
||||||
max_retries_per_task: Max retries per task when auto_retry_on_failure is True.
|
|
||||||
on_recommendations: Optional callback to receive recommendations (e.g. log or UI).
|
|
||||||
on_training_suggestions: Optional callback to receive training suggestions.
|
|
||||||
"""
|
|
||||||
self._event_bus = event_bus
|
self._event_bus = event_bus
|
||||||
self._state = state_manager
|
self._state = state_manager
|
||||||
self._orchestrator = orchestrator
|
self._orchestrator = orchestrator
|
||||||
@@ -59,6 +71,10 @@ class FusionAGILoop:
|
|||||||
self._on_recs = on_recommendations
|
self._on_recs = on_recommendations
|
||||||
self._on_training = on_training_suggestions
|
self._on_training = on_training_suggestions
|
||||||
|
|
||||||
|
self._audit = audit_log or AuditLog()
|
||||||
|
self._ethics = AdaptiveEthics(audit_log=self._audit)
|
||||||
|
self._consequences = ConsequenceEngine(audit_log=self._audit)
|
||||||
|
|
||||||
self._correction = SelfCorrectionLoop(
|
self._correction = SelfCorrectionLoop(
|
||||||
state_manager=state_manager,
|
state_manager=state_manager,
|
||||||
orchestrator=orchestrator,
|
orchestrator=orchestrator,
|
||||||
@@ -66,27 +82,85 @@ class FusionAGILoop:
|
|||||||
max_retries_per_task=max_retries_per_task,
|
max_retries_per_task=max_retries_per_task,
|
||||||
)
|
)
|
||||||
self._recommender = AutoRecommender(reflective_memory=reflective_memory)
|
self._recommender = AutoRecommender(reflective_memory=reflective_memory)
|
||||||
self._trainer = AutoTrainer(reflective_memory=reflective_memory)
|
self._trainer = AutoTrainer(
|
||||||
|
reflective_memory=reflective_memory,
|
||||||
|
audit_log=self._audit,
|
||||||
|
)
|
||||||
|
|
||||||
self._event_bus.subscribe("task_state_changed", self._on_task_state_changed)
|
self._event_bus.subscribe("task_state_changed", self._on_task_state_changed)
|
||||||
self._event_bus.subscribe("reflection_done", self._on_reflection_done)
|
self._event_bus.subscribe("reflection_done", self._on_reflection_done)
|
||||||
logger.info("FusionAGILoop: subscribed to task_state_changed and reflection_done")
|
logger.info("FusionAGILoop: subscribed (with consequence + ethics engines)")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ethics(self) -> AdaptiveEthics:
|
||||||
|
"""Access the adaptive ethics engine."""
|
||||||
|
return self._ethics
|
||||||
|
|
||||||
|
@property
|
||||||
|
def consequences(self) -> ConsequenceEngine:
|
||||||
|
"""Access the consequence engine."""
|
||||||
|
return self._consequences
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audit_log(self) -> AuditLog:
|
||||||
|
"""Access the audit log."""
|
||||||
|
return self._audit
|
||||||
|
|
||||||
def _on_task_state_changed(self, event_type: str, payload: dict[str, Any]) -> None:
|
def _on_task_state_changed(self, event_type: str, payload: dict[str, Any]) -> None:
|
||||||
"""On FAILED, optionally run self-correction and prepare retry."""
|
"""On state change, record consequences and optionally retry."""
|
||||||
try:
|
try:
|
||||||
to_state = payload.get("to_state")
|
to_state = payload.get("to_state")
|
||||||
task_id = payload.get("task_id", "")
|
task_id = payload.get("task_id", "")
|
||||||
if to_state != TaskState.FAILED.value or not task_id:
|
if not task_id:
|
||||||
return
|
return
|
||||||
if self._auto_retry:
|
|
||||||
ok, _ = self._correction.suggest_retry(task_id)
|
if to_state == TaskState.FAILED.value:
|
||||||
if ok:
|
self._consequences.record_consequence(
|
||||||
self._correction.prepare_retry(task_id)
|
choice_id=f"task_{task_id}",
|
||||||
else:
|
outcome_positive=False,
|
||||||
recs = self._correction.correction_recommendations(task_id)
|
actual_risk_realized=0.8,
|
||||||
if recs and self._on_recs:
|
actual_reward_gained=0.1,
|
||||||
self._on_recs(recs)
|
description=f"Task {task_id} failed",
|
||||||
|
cost={"retries_needed": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._ethics.record_experience(
|
||||||
|
action_type="task_execution",
|
||||||
|
context_summary=f"Task {task_id} execution",
|
||||||
|
advisory_reason="",
|
||||||
|
proceeded=True,
|
||||||
|
outcome_positive=False,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._auto_retry:
|
||||||
|
ok, _ = self._correction.suggest_retry(task_id)
|
||||||
|
if ok:
|
||||||
|
self._correction.prepare_retry(task_id)
|
||||||
|
else:
|
||||||
|
recs = self._correction.correction_recommendations(task_id)
|
||||||
|
if recs and self._on_recs:
|
||||||
|
self._on_recs(recs)
|
||||||
|
|
||||||
|
elif to_state == TaskState.COMPLETED.value:
|
||||||
|
self._consequences.record_consequence(
|
||||||
|
choice_id=f"task_{task_id}",
|
||||||
|
outcome_positive=True,
|
||||||
|
actual_risk_realized=0.1,
|
||||||
|
actual_reward_gained=0.8,
|
||||||
|
description=f"Task {task_id} completed successfully",
|
||||||
|
benefit={"task_completed": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._ethics.record_experience(
|
||||||
|
action_type="task_execution",
|
||||||
|
context_summary=f"Task {task_id} execution",
|
||||||
|
advisory_reason="",
|
||||||
|
proceeded=True,
|
||||||
|
outcome_positive=True,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"FusionAGILoop: _on_task_state_changed failed (best-effort)",
|
"FusionAGILoop: _on_task_state_changed failed (best-effort)",
|
||||||
@@ -94,10 +168,22 @@ class FusionAGILoop:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _on_reflection_done(self, event_type: str, payload: dict[str, Any]) -> None:
|
def _on_reflection_done(self, event_type: str, payload: dict[str, Any]) -> None:
|
||||||
"""After reflection, run auto-recommend and auto-training."""
|
"""After reflection, run auto-recommend, auto-training, and update ethics."""
|
||||||
try:
|
try:
|
||||||
task_id = payload.get("task_id") or ""
|
task_id = payload.get("task_id") or ""
|
||||||
evaluation = payload.get("evaluation") or {}
|
evaluation = payload.get("evaluation") or {}
|
||||||
|
|
||||||
|
success = evaluation.get("success", False)
|
||||||
|
|
||||||
|
self._ethics.record_experience(
|
||||||
|
action_type="reflection_outcome",
|
||||||
|
context_summary=f"Reflection on task {task_id}",
|
||||||
|
advisory_reason="",
|
||||||
|
proceeded=True,
|
||||||
|
outcome_positive=success,
|
||||||
|
task_id=task_id or None,
|
||||||
|
)
|
||||||
|
|
||||||
recs = self._recommender.recommend(
|
recs = self._recommender.recommend(
|
||||||
task_id=task_id or None,
|
task_id=task_id or None,
|
||||||
evaluation=evaluation,
|
evaluation=evaluation,
|
||||||
@@ -129,10 +215,27 @@ class FusionAGILoop:
|
|||||||
task_id: str,
|
task_id: str,
|
||||||
evaluation: dict[str, Any],
|
evaluation: dict[str, Any],
|
||||||
) -> tuple[list[Recommendation], list[TrainingSuggestion]]:
|
) -> tuple[list[Recommendation], list[TrainingSuggestion]]:
|
||||||
|
"""Run auto-recommend and auto-training after a reflection.
|
||||||
|
|
||||||
|
Also records the reflection outcome for ethical learning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Task that was reflected on.
|
||||||
|
evaluation: Critic evaluation dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (recommendations, training_suggestions).
|
||||||
"""
|
"""
|
||||||
Run auto-recommend and auto-training after a reflection (e.g. when
|
success = evaluation.get("success", False)
|
||||||
not using reflection_done event). Returns (recommendations, training_suggestions).
|
self._ethics.record_experience(
|
||||||
"""
|
action_type="reflection_outcome",
|
||||||
|
context_summary=f"Manual reflection on {task_id}",
|
||||||
|
advisory_reason="",
|
||||||
|
proceeded=True,
|
||||||
|
outcome_positive=success,
|
||||||
|
task_id=task_id,
|
||||||
|
)
|
||||||
|
|
||||||
recs = self._recommender.recommend(
|
recs = self._recommender.recommend(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
evaluation=evaluation,
|
evaluation=evaluation,
|
||||||
|
|||||||
@@ -1,5 +1,17 @@
|
|||||||
|
from fusionagi.verification.claim_verifier import (
|
||||||
|
ClaimVerifier,
|
||||||
|
VerificationReport,
|
||||||
|
VerificationResult,
|
||||||
|
)
|
||||||
from fusionagi.verification.contradiction import ContradictionDetector
|
from fusionagi.verification.contradiction import ContradictionDetector
|
||||||
from fusionagi.verification.outcome import OutcomeVerifier
|
from fusionagi.verification.outcome import OutcomeVerifier
|
||||||
from fusionagi.verification.validators import FormalValidators
|
from fusionagi.verification.validators import FormalValidators
|
||||||
|
|
||||||
__all__ = ["OutcomeVerifier", "ContradictionDetector", "FormalValidators"]
|
__all__ = [
|
||||||
|
"ClaimVerifier",
|
||||||
|
"ContradictionDetector",
|
||||||
|
"FormalValidators",
|
||||||
|
"OutcomeVerifier",
|
||||||
|
"VerificationReport",
|
||||||
|
"VerificationResult",
|
||||||
|
]
|
||||||
|
|||||||
273
fusionagi/verification/claim_verifier.py
Normal file
273
fusionagi/verification/claim_verifier.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
"""Claim verification: cross-check claims against known facts and evidence.
|
||||||
|
|
||||||
|
Provides formal verification of claims produced by the reasoning pipeline
|
||||||
|
before they reach the final output. Each claim is checked for:
|
||||||
|
- Internal consistency (does it contradict other claims in the same response?)
|
||||||
|
- Evidence support (how well-supported is this claim by cited evidence?)
|
||||||
|
- Confidence calibration (is the claimed confidence appropriate?)
|
||||||
|
- Factual grounding (can the claim be grounded in the semantic graph?)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from fusionagi._logger import logger
|
||||||
|
from fusionagi.schemas.head import HeadClaim, HeadOutput
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticGraphLike(Protocol):
    """Structural type for the semantic graph memory used for grounding.

    Any object exposing a compatible ``query_units`` method satisfies it.
    """

    def query_units(
        self,
        unit_ids: list[str] | None = None,
        content_contains: str | None = None,
        limit: int = 50,
    ) -> list[Any]:
        """Return stored units matching the given filters.

        The filter semantics are defined by the implementation; this
        protocol only fixes the call signature.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VerificationResult:
    """Outcome of verifying one claim.

    Attributes:
        claim_text: The claim that was verified.
        verified: Whether the claim passed verification.
        confidence_calibrated: Whether confidence seems well-calibrated.
        evidence_score: Evidence support strength (0.0–1.0).
        consistency_score: Internal consistency with other claims (0.0–1.0).
        grounding_score: Grounding in known facts (0.0–1.0).
        issues: Descriptions of the problems found, if any.
        overall_score: Composite verification score (0.0–1.0).
    """

    claim_text: str = ""
    # Pass/fail flags default to the optimistic value.
    verified: bool = True
    confidence_calibrated: bool = True
    # Component scores.
    evidence_score: float = 0.5
    consistency_score: float = 1.0
    grounding_score: float = 0.5
    # Fresh list per instance.
    issues: list[str] = field(default_factory=list)
    overall_score: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VerificationReport:
    """Aggregate verification outcome for every claim in a response.

    Attributes:
        results: Per-claim verification results.
        overall_integrity: Overall response integrity (0.0–1.0).
        total_claims: Total number of claims checked.
        verified_count: Number of claims that passed verification.
        flagged_count: Number of claims flagged with issues.
        recommendations: Suggested actions based on verification.
    """

    # Fresh list per instance.
    results: list[VerificationResult] = field(default_factory=list)
    overall_integrity: float = 0.5
    # Counters.
    total_claims: int = 0
    verified_count: int = 0
    flagged_count: int = 0
    recommendations: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ClaimVerifier:
|
||||||
|
"""Verifies claims from head outputs against evidence and known facts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
semantic_graph: Optional semantic graph for fact grounding.
|
||||||
|
min_evidence_for_high_conf: Minimum evidence items expected for
|
||||||
|
high-confidence claims (>=0.8).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
semantic_graph: SemanticGraphLike | None = None,
|
||||||
|
min_evidence_for_high_conf: int = 2,
|
||||||
|
) -> None:
|
||||||
|
self._graph = semantic_graph
|
||||||
|
self._min_evidence_high = min_evidence_for_high_conf
|
||||||
|
|
||||||
|
def verify_outputs(self, outputs: list[HeadOutput]) -> VerificationReport:
|
||||||
|
"""Verify all claims across all head outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs: Head outputs to verify.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Comprehensive verification report.
|
||||||
|
"""
|
||||||
|
all_claims: list[tuple[HeadClaim, str]] = []
|
||||||
|
for out in outputs:
|
||||||
|
for claim in out.claims:
|
||||||
|
all_claims.append((claim, out.head_id.value))
|
||||||
|
|
||||||
|
results: list[VerificationResult] = []
|
||||||
|
for claim, head_id in all_claims:
|
||||||
|
result = self._verify_claim(claim, head_id, all_claims)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
verified = sum(1 for r in results if r.verified)
|
||||||
|
flagged = sum(1 for r in results if r.issues)
|
||||||
|
overall = (
|
||||||
|
sum(r.overall_score for r in results) / max(len(results), 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
recommendations: list[str] = []
|
||||||
|
if flagged > len(results) * 0.3:
|
||||||
|
recommendations.append(
|
||||||
|
f"{flagged}/{len(results)} claims flagged — consider second-pass verification"
|
||||||
|
)
|
||||||
|
uncalibrated = [r for r in results if not r.confidence_calibrated]
|
||||||
|
if uncalibrated:
|
||||||
|
recommendations.append(
|
||||||
|
f"{len(uncalibrated)} claims with miscalibrated confidence"
|
||||||
|
)
|
||||||
|
low_evidence = [r for r in results if r.evidence_score < 0.3]
|
||||||
|
if low_evidence:
|
||||||
|
recommendations.append(
|
||||||
|
f"{len(low_evidence)} claims lack evidence support"
|
||||||
|
)
|
||||||
|
|
||||||
|
report = VerificationReport(
|
||||||
|
results=results,
|
||||||
|
overall_integrity=overall,
|
||||||
|
total_claims=len(results),
|
||||||
|
verified_count=verified,
|
||||||
|
flagged_count=flagged,
|
||||||
|
recommendations=recommendations,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"ClaimVerifier: verification complete",
|
||||||
|
extra={
|
||||||
|
"total": report.total_claims,
|
||||||
|
"verified": report.verified_count,
|
||||||
|
"flagged": report.flagged_count,
|
||||||
|
"integrity": report.overall_integrity,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return report
|
||||||
|
|
||||||
|
def _verify_claim(
|
||||||
|
self,
|
||||||
|
claim: HeadClaim,
|
||||||
|
head_id: str,
|
||||||
|
all_claims: list[tuple[HeadClaim, str]],
|
||||||
|
) -> VerificationResult:
|
||||||
|
"""Verify a single claim."""
|
||||||
|
issues: list[str] = []
|
||||||
|
|
||||||
|
evidence_score = self._check_evidence(claim, issues)
|
||||||
|
|
||||||
|
calibrated = self._check_calibration(claim, evidence_score, issues)
|
||||||
|
|
||||||
|
consistency_score = self._check_consistency(claim, head_id, all_claims, issues)
|
||||||
|
|
||||||
|
grounding_score = self._check_grounding(claim, issues)
|
||||||
|
|
||||||
|
overall = (
|
||||||
|
0.35 * evidence_score
|
||||||
|
+ 0.25 * consistency_score
|
||||||
|
+ 0.25 * grounding_score
|
||||||
|
+ 0.15 * (1.0 if calibrated else 0.5)
|
||||||
|
)
|
||||||
|
|
||||||
|
return VerificationResult(
|
||||||
|
claim_text=claim.claim_text,
|
||||||
|
verified=len(issues) == 0,
|
||||||
|
confidence_calibrated=calibrated,
|
||||||
|
evidence_score=evidence_score,
|
||||||
|
consistency_score=consistency_score,
|
||||||
|
grounding_score=grounding_score,
|
||||||
|
issues=issues,
|
||||||
|
overall_score=overall,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_evidence(self, claim: HeadClaim, issues: list[str]) -> float:
|
||||||
|
"""Check how well a claim is supported by evidence."""
|
||||||
|
if not claim.evidence:
|
||||||
|
issues.append("No evidence cited")
|
||||||
|
return 0.1
|
||||||
|
|
||||||
|
score = min(1.0, len(claim.evidence) / 3.0)
|
||||||
|
|
||||||
|
if claim.confidence >= 0.8 and len(claim.evidence) < self._min_evidence_high:
|
||||||
|
issues.append(
|
||||||
|
f"High confidence ({claim.confidence:.2f}) with only "
|
||||||
|
f"{len(claim.evidence)} evidence item(s)"
|
||||||
|
)
|
||||||
|
score *= 0.7
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
def _check_calibration(
|
||||||
|
self,
|
||||||
|
claim: HeadClaim,
|
||||||
|
evidence_score: float,
|
||||||
|
issues: list[str],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if confidence is well-calibrated relative to evidence."""
|
||||||
|
if claim.confidence >= 0.9 and evidence_score < 0.3:
|
||||||
|
issues.append(
|
||||||
|
f"Confidence {claim.confidence:.2f} not supported by evidence "
|
||||||
|
f"(evidence score: {evidence_score:.2f})"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if claim.confidence >= 0.8 and evidence_score < 0.2:
|
||||||
|
issues.append("Very high confidence with minimal evidence")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _check_consistency(
|
||||||
|
self,
|
||||||
|
claim: HeadClaim,
|
||||||
|
head_id: str,
|
||||||
|
all_claims: list[tuple[HeadClaim, str]],
|
||||||
|
issues: list[str],
|
||||||
|
) -> float:
|
||||||
|
"""Check if this claim is consistent with other claims."""
|
||||||
|
claim_words = set(claim.claim_text.lower().split())
|
||||||
|
neg_words = {"not", "no", "never", "none", "cannot", "shouldn't", "won't"}
|
||||||
|
claim_has_neg = bool(claim_words & neg_words)
|
||||||
|
|
||||||
|
contradictions = 0
|
||||||
|
comparisons = 0
|
||||||
|
for other_claim, other_head in all_claims:
|
||||||
|
if other_claim is claim:
|
||||||
|
continue
|
||||||
|
other_words = set(other_claim.claim_text.lower().split())
|
||||||
|
overlap = len(claim_words & other_words) / max(len(claim_words), 1)
|
||||||
|
if overlap < 0.2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
comparisons += 1
|
||||||
|
other_has_neg = bool(other_words & neg_words)
|
||||||
|
if claim_has_neg != other_has_neg and overlap > 0.3:
|
||||||
|
contradictions += 1
|
||||||
|
issues.append(
|
||||||
|
f"Potential contradiction with claim from '{other_head}': "
|
||||||
|
f"'{other_claim.claim_text[:60]}...'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if comparisons == 0:
|
||||||
|
return 0.7
|
||||||
|
return max(0.0, 1.0 - contradictions / max(comparisons, 1))
|
||||||
|
|
||||||
|
def _check_grounding(self, claim: HeadClaim, issues: list[str]) -> float:
    """Score how well the claim is grounded in the semantic graph.

    The first 80 characters of the claim text are passed to
    ``self._graph.query_units`` as a ``content_contains`` filter (at most
    5 results requested).

    Args:
        claim: The claim whose text is looked up.
        issues: Accepted for signature parity with the other checks; this
            check never appends to it.

    Returns:
        ``0.5`` when no graph is configured or the lookup fails,
        ``0.3`` when the graph returns nothing, otherwise
        ``0.3 + 0.15`` per returned unit, capped at ``1.0``.
    """
    if self._graph is None:
        return 0.5

    try:
        # NOTE(review): this is a raw text prefix, not extracted keywords;
        # how query_units matches it is not visible here.
        claim_prefix = claim.claim_text[:80]
        units = self._graph.query_units(content_contains=claim_prefix, limit=5)
        if not units:
            return 0.3
        return min(1.0, 0.3 + len(units) * 0.15)
    except Exception:
        # Grounding is best-effort, so a failure must not abort verification.
        # Keep the traceback on the debug record so the cause stays visible.
        logger.debug(
            "ClaimVerifier: grounding check failed (non-fatal)",
            exc_info=True,
        )
        return 0.5
|
||||||
@@ -1,6 +1,11 @@
|
|||||||
"""World model and simulation for AGI."""
|
"""World model and simulation for AGI.
|
||||||
|
|
||||||
|
Provides causal state-transition prediction from learned execution history,
|
||||||
|
rollout simulation, and uncertainty estimation.
|
||||||
|
"""
|
||||||
|
|
||||||
from fusionagi.world_model.base import SimpleWorldModel, WorldModel
|
from fusionagi.world_model.base import SimpleWorldModel, WorldModel
|
||||||
|
from fusionagi.world_model.causal import CausalWorldModel
|
||||||
from fusionagi.world_model.rollout import run_rollout
|
from fusionagi.world_model.rollout import run_rollout
|
||||||
|
|
||||||
__all__ = ["WorldModel", "SimpleWorldModel", "run_rollout"]
|
__all__ = ["CausalWorldModel", "SimpleWorldModel", "WorldModel", "run_rollout"]
|
||||||
|
|||||||
300
fusionagi/world_model/causal.py
Normal file
300
fusionagi/world_model/causal.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""Causal world model: learns state-transition patterns from execution history.
|
||||||
|
|
||||||
|
Unlike ``SimpleWorldModel`` (which returns unchanged state), the causal
|
||||||
|
world model builds a library of observed action→effect patterns and uses
|
||||||
|
them to predict outcomes of planned actions before execution.
|
||||||
|
|
||||||
|
The model learns from every executed step:
|
||||||
|
- Records (state_before, action, action_args) → state_after transitions
|
||||||
|
- Groups patterns by action type for efficient lookup
|
||||||
|
- Predicts confidence based on how many similar transitions it has observed
|
||||||
|
- Maintains uncertainty estimates that decrease with more evidence
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
from fusionagi._logger import logger
|
||||||
|
from fusionagi.schemas.audit import AuditEventType
|
||||||
|
from fusionagi.schemas.world_model import StateTransition, UncertaintyInfo
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogLike(Protocol):
    """Structural type for the audit sink the world model writes to.

    Any object exposing a compatible ``append`` method satisfies this
    protocol; no inheritance is required.
    """

    def append(
        self,
        event_type: AuditEventType,
        actor: str,
        action: str = "",
        task_id: str | None = None,
        payload: dict[str, Any] | None = None,
        outcome: str = "",
    ) -> str:
        """Record one audit event.

        Args:
            event_type: Category of the event.
            actor: Name of the component emitting the event.
            action: Short label for what happened.
            task_id: Task the event belongs to, if any.
            payload: Extra structured details.
            outcome: Short label for the result.

        Returns:
            A string; presumably an identifier of the stored event
            (confirm against the concrete audit log implementation).
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class TransitionPattern:
    """Accumulated knowledge about one action's observed state transitions.

    Attributes:
        action: The action type this pattern describes.
        preconditions: Union of all state keys seen in ``from_state``.
        effects: Latest observed value for every key that an observation
            added or changed (key -> new value). Entries are never removed.
        observation_count: Number of calls to :meth:`update`.
        success_count: Number of those calls flagged as successful.
        avg_confidence: Score recomputed on every update as
            ``min(1.0, 0.5 * success_rate + 0.5 + evidence_boost)`` where
            ``evidence_boost = min(0.4, 0.02 * observation_count)``.
            Starts at 0.5 before any observation.
    """

    __slots__ = (
        "action",
        "preconditions",
        "effects",
        "observation_count",
        "success_count",
        "avg_confidence",
    )

    def __init__(self, action: str) -> None:
        self.action = action
        self.preconditions: set[str] = set()
        self.effects: dict[str, Any] = {}
        self.observation_count: int = 0
        self.success_count: int = 0
        self.avg_confidence: float = 0.5

    def update(
        self,
        from_state: dict[str, Any],
        to_state: dict[str, Any],
        success: bool,
    ) -> None:
        """Fold one observed transition into the pattern.

        Args:
            from_state: State before the action; its keys become preconditions.
            to_state: State after the action; new or changed keys become effects.
            success: Whether the action succeeded.
        """
        self.observation_count += 1
        self.success_count += 1 if success else 0

        self.preconditions.update(from_state)

        # An effect is any key that is new in to_state or whose value differs.
        self.effects.update(
            {
                name: after
                for name, after in to_state.items()
                if name not in from_state or from_state[name] != after
            }
        )

        # NOTE(review): because of the constant 0.5 term, this never drops
        # below 0.52 once observed and reaches 1.0 after a single success.
        rate = self.success_count / self.observation_count
        boost = min(0.4, self.observation_count * 0.02)
        self.avg_confidence = min(1.0, 0.5 * rate + 0.5 + boost)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalWorldModel:
    """World model that learns causal state-transition patterns.

    Every observed transition is appended to an in-memory history and folded
    into a ``TransitionPattern`` keyed by the action name plus (up to three)
    argument names. Predictions overlay the matching pattern's learned
    effects on the supplied state.

    Args:
        audit_log: Optional audit log; when given, every observation is
            recorded as a ``SELF_IMPROVEMENT`` event.
        max_patterns_per_action: Intended cap on patterns kept per action.
            NOTE: the value is stored but not currently enforced; neither
            the pattern table nor the history is pruned.
    """

    def __init__(
        self,
        audit_log: AuditLogLike | None = None,
        max_patterns_per_action: int = 100,
    ) -> None:
        self._patterns: dict[str, TransitionPattern] = {}
        self._history: list[StateTransition] = []
        self._audit = audit_log
        self._max_per_action = max_patterns_per_action

    @property
    def total_observations(self) -> int:
        """Total state transitions observed."""
        return len(self._history)

    @property
    def known_actions(self) -> list[str]:
        """Distinct action names the model has observed, in first-seen order.

        Pattern keys may embed argument names (``"write:path"``); this
        property reports the underlying action names so the result can be
        compared with the ``action`` values passed to ``observe``,
        ``predict`` and ``uncertainty``.
        """
        return list(dict.fromkeys(p.action for p in self._patterns.values()))

    def observe(
        self,
        from_state: dict[str, Any],
        action: str,
        action_args: dict[str, Any],
        to_state: dict[str, Any],
        success: bool = True,
        task_id: str | None = None,
    ) -> None:
        """Record an observed state transition.

        Args:
            from_state: State before the action.
            action: Action name/type.
            action_args: Arguments passed to the action.
            to_state: State after the action.
            success: Whether the action succeeded.
            task_id: Associated task ID (forwarded to the audit log only).
        """
        # Shallow copies keep later mutation of the caller's dicts from
        # rewriting recorded history.
        self._history.append(
            StateTransition(
                from_state=dict(from_state),
                action=action,
                action_args=dict(action_args),
                to_state=dict(to_state),
                confidence=1.0 if success else 0.2,
            )
        )

        pattern_key = self._pattern_key(action, action_args)
        pattern = self._patterns.get(pattern_key)
        if pattern is None:
            pattern = TransitionPattern(action)
            self._patterns[pattern_key] = pattern
        pattern.update(from_state, to_state, success)

        logger.debug(
            "CausalWorldModel: transition observed",
            extra={
                "action": action,
                "success": success,
                "observations": pattern.observation_count,
            },
        )

        # Compare against None explicitly: an audit log that is falsy
        # (e.g. defines __len__ and is empty) must still receive events.
        if self._audit is not None:
            self._audit.append(
                AuditEventType.SELF_IMPROVEMENT,
                actor="world_model",
                action="transition_observed",
                task_id=task_id,
                payload={
                    "action_type": action,
                    "success": success,
                    "pattern_observations": pattern.observation_count,
                    "state_changes": len(pattern.effects),
                },
                outcome="learned",
            )

    def predict(
        self,
        state: dict[str, Any],
        action: str,
        action_args: dict[str, Any],
    ) -> StateTransition:
        """Predict the result of an action in the current state.

        Looks up the pattern for this action/argument-name combination and
        falls back to the most-observed pattern for the bare action name.
        With no pattern at all the state is returned unchanged with
        confidence 0.3.

        Args:
            state: Current state.
            action: Action to predict.
            action_args: Arguments for the action.

        Returns:
            Predicted state transition; its confidence is the matched
            pattern's ``avg_confidence``.
        """
        pattern = self._patterns.get(self._pattern_key(action, action_args))
        if pattern is None:
            pattern = self._find_generic_pattern(action)

        if pattern is None:
            return StateTransition(
                from_state=dict(state),
                action=action,
                action_args=dict(action_args),
                to_state=dict(state),
                confidence=0.3,
            )

        # Learned effects are absolute values that overwrite matching keys.
        predicted_state = dict(state)
        predicted_state.update(pattern.effects)

        return StateTransition(
            from_state=dict(state),
            action=action,
            action_args=dict(action_args),
            to_state=predicted_state,
            confidence=pattern.avg_confidence,
        )

    def uncertainty(self, state: dict[str, Any], action: str) -> UncertaintyInfo:
        """Return uncertainty and risk assessment for an action.

        Aggregates every pattern recorded for ``action`` regardless of its
        arguments.

        Args:
            state: Current state (not consulted at present).
            action: Action to assess.

        Returns:
            Uncertainty info with confidence and risk level.
        """
        matching = [p for p in self._patterns.values() if p.action == action]

        if not matching:
            return UncertaintyInfo(
                confidence=0.3,
                risk_level="high",
                rationale=f"No prior observations for action '{action}'",
            )

        total_obs = sum(p.observation_count for p in matching)
        total_success = sum(p.success_count for p in matching)
        success_rate = total_success / total_obs if total_obs > 0 else 0.5
        avg_conf = sum(p.avg_confidence for p in matching) / len(matching)

        if avg_conf >= 0.8 and success_rate >= 0.8:
            risk = "low"
        elif avg_conf >= 0.5 and success_rate >= 0.5:
            risk = "medium"
        else:
            risk = "high"

        return UncertaintyInfo(
            confidence=avg_conf,
            risk_level=risk,
            rationale=(
                f"Based on {total_obs} observations of '{action}': "
                f"{success_rate:.0%} success rate, {avg_conf:.2f} avg confidence"
            ),
        )

    def get_summary(self) -> dict[str, Any]:
        """Return a summary of the world model's learned knowledge.

        The ``"patterns"`` entry is keyed by internal pattern key, not by
        bare action name.
        """
        by_action: dict[str, dict[str, Any]] = {
            key: {
                "action": pattern.action,
                "observations": pattern.observation_count,
                "success_rate": (
                    pattern.success_count / pattern.observation_count
                    if pattern.observation_count > 0
                    else 0.0
                ),
                "confidence": pattern.avg_confidence,
                "known_effects": len(pattern.effects),
            }
            for key, pattern in self._patterns.items()
        }
        return {
            "total_observations": len(self._history),
            "known_patterns": len(self._patterns),
            "patterns": by_action,
        }

    def _pattern_key(self, action: str, action_args: dict[str, Any]) -> str:
        """Build the pattern key: action plus up to three sorted arg names.

        Argument *values* are deliberately ignored, so calls that differ
        only in values share one pattern.
        """
        significant = sorted(action_args.keys())[:3]
        return f"{action}:{','.join(significant)}" if significant else action

    def _find_generic_pattern(self, action: str) -> TransitionPattern | None:
        """Return the most-observed pattern for ``action``, or ``None``."""
        matching = [p for p in self._patterns.values() if p.action == action]
        if not matching:
            return None
        return max(matching, key=lambda p: p.observation_count)
|
||||||
118
tests/test_consequence_engine.py
Normal file
118
tests/test_consequence_engine.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""Tests for the consequence engine and choice→consequence→learning loop."""
|
||||||
|
|
||||||
|
from fusionagi.governance import Alternative, ConsequenceEngine
|
||||||
|
from fusionagi.governance.audit_log import AuditLog
|
||||||
|
from fusionagi.schemas.audit import AuditEventType
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsequenceEngine:
    """Exercise choice recording, outcome tracking and risk/reward estimates."""

    def test_record_choice(self) -> None:
        """A recorded choice echoes its inputs and is counted."""
        engine = ConsequenceEngine()
        recorded = engine.record_choice(
            choice_id="c1",
            actor="planner",
            action_taken="use_tool_x",
            estimated_risk=0.3,
            estimated_reward=0.7,
            rationale="Tool X is the best fit",
        )
        assert recorded.choice_id == "c1"
        assert recorded.estimated_risk == 0.3
        assert engine.total_choices == 1

    def test_record_consequence(self) -> None:
        """A consequence attached to a known choice is stored and counted."""
        engine = ConsequenceEngine()
        engine.record_choice(choice_id="c1", actor="planner", action_taken="act")
        outcome = engine.record_consequence(
            choice_id="c1",
            outcome_positive=True,
            actual_risk_realized=0.1,
            actual_reward_gained=0.9,
            description="Action succeeded",
        )
        assert outcome is not None
        assert outcome.outcome_positive is True
        assert engine.total_consequences == 1

    def test_consequence_not_found(self) -> None:
        """Recording against an unknown choice id yields ``None``."""
        engine = ConsequenceEngine()
        missing = engine.record_consequence(
            choice_id="nonexistent", outcome_positive=True
        )
        assert missing is None

    def test_surprise_factor(self) -> None:
        """An outcome far from the estimate produces a high surprise factor."""
        engine = ConsequenceEngine()
        engine.record_choice(
            choice_id="c1",
            actor="exec",
            action_taken="risky_op",
            estimated_risk=0.1,
            estimated_reward=0.9,
        )
        outcome = engine.record_consequence(
            choice_id="c1",
            outcome_positive=False,
            actual_risk_realized=0.9,
            actual_reward_gained=0.1,
        )
        assert outcome is not None
        assert outcome.surprise_factor > 0.5

    def test_estimate_risk_reward_no_history(self) -> None:
        """An unseen action reports zero observations and 0.1 confidence."""
        estimate = ConsequenceEngine().estimate_risk_reward("unknown_action")
        assert estimate["observations"] == 0
        assert estimate["confidence"] == 0.1

    def test_estimate_risk_reward_with_history(self) -> None:
        """Estimates converge on the values seen in identical past outcomes."""
        engine = ConsequenceEngine()
        for choice_id in (f"c{i}" for i in range(5)):
            engine.record_choice(choice_id, "exec", "tool_call")
            engine.record_consequence(
                choice_id,
                outcome_positive=True,
                actual_risk_realized=0.2,
                actual_reward_gained=0.8,
            )
        estimate = engine.estimate_risk_reward("tool_call")
        assert estimate["observations"] == 5
        assert abs(estimate["expected_risk"] - 0.2) < 0.01
        assert abs(estimate["expected_reward"] - 0.8) < 0.01

    def test_alternatives_recorded(self) -> None:
        """Rejected alternatives are kept on the choice in the given order."""
        engine = ConsequenceEngine()
        rejected = [
            Alternative(action="alt_a", estimated_risk=0.6, reason_not_chosen="Too risky"),
            Alternative(action="alt_b", estimated_risk=0.2, reason_not_chosen="Lower reward"),
        ]
        recorded = engine.record_choice(
            choice_id="c1",
            actor="planner",
            action_taken="chosen_action",
            alternatives=rejected,
        )
        assert len(recorded.alternatives) == 2
        assert recorded.alternatives[0].reason_not_chosen == "Too risky"

    def test_get_summary(self) -> None:
        """The summary counts choices, consequences and outcome polarity."""
        engine = ConsequenceEngine()
        engine.record_choice("c1", "exec", "action_a")
        engine.record_consequence("c1", True, 0.1, 0.9)
        engine.record_choice("c2", "exec", "action_a")
        engine.record_consequence("c2", False, 0.8, 0.1)
        summary = engine.get_summary()
        expected_counts = {
            "total_choices": 2,
            "total_consequences": 2,
            "positive_outcomes": 1,
            "negative_outcomes": 1,
        }
        for field, count in expected_counts.items():
            assert summary[field] == count

    def test_audit_log_integration(self) -> None:
        """Each choice and each consequence emits exactly one audit event."""
        audit = AuditLog()
        engine = ConsequenceEngine(audit_log=audit)
        engine.record_choice("c1", "exec", "action")
        engine.record_consequence("c1", True)
        assert len(audit.get_by_type(AuditEventType.CHOICE)) == 1
        assert len(audit.get_by_type(AuditEventType.CONSEQUENCE)) == 1
|
||||||
139
tests/test_metacognition.py
Normal file
139
tests/test_metacognition.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Tests for metacognition and reasoning interpretability."""
|
||||||
|
|
||||||
|
from fusionagi.reasoning.interpretability import ReasoningTracer
|
||||||
|
from fusionagi.reasoning.metacognition import (
|
||||||
|
assess_head_outputs,
|
||||||
|
)
|
||||||
|
from fusionagi.schemas.grounding import Citation
|
||||||
|
from fusionagi.schemas.head import HeadClaim, HeadId, HeadOutput
|
||||||
|
from fusionagi.verification import ClaimVerifier
|
||||||
|
|
||||||
|
_SAMPLE_CITATION = Citation(source_id="src_1", excerpt="supporting evidence")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_head_output(
    head_id: HeadId,
    claims: list[tuple[str, float]] | None = None,
) -> HeadOutput:
    """Build a ``HeadOutput`` fixture from ``(text, confidence)`` pairs.

    When ``claims`` is omitted or empty, a single ``("Test claim", 0.7)``
    claim is used. Claims with confidence above 0.5 carry the shared sample
    citation as evidence; the rest carry no evidence.
    """
    specs = claims or [("Test claim", 0.7)]
    built_claims = [
        HeadClaim(
            claim_text=text,
            confidence=conf,
            evidence=[_SAMPLE_CITATION] if conf > 0.5 else [],
        )
        for text, conf in specs
    ]
    return HeadOutput(
        head_id=head_id,
        summary=f"Output from {head_id.value}",
        claims=built_claims,
        risks=[],
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetacognition:
    """Checks for ``assess_head_outputs`` self-assessment."""

    def test_empty_outputs(self) -> None:
        """No head outputs means zero confidence and a request for more info."""
        result = assess_head_outputs([])
        assert result.overall_confidence == 0.0
        assert result.should_seek_more is True

    def test_high_confidence_outputs(self) -> None:
        """Confident, evidenced claims yield a non-trivial overall confidence."""
        logic = _make_head_output(HeadId.LOGIC, [("Logic is sound", 0.9)])
        research = _make_head_output(HeadId.RESEARCH, [("Data supports this", 0.85)])
        result = assess_head_outputs([logic, research])
        assert result.overall_confidence > 0.3
        assert isinstance(result.knowledge_gaps, list)

    def test_low_confidence_triggers_seek_more(self) -> None:
        """A low-confidence claim is reported as a source of uncertainty."""
        weak = _make_head_output(HeadId.LOGIC, [("Uncertain claim", 0.1)])
        result = assess_head_outputs([weak])
        assert len(result.uncertainty_sources) > 0

    def test_knowledge_gap_detection(self) -> None:
        """A low-confidence head shows up as a gap in its own domain."""
        weak = _make_head_output(HeadId.LOGIC, [("Low conf claim", 0.1)])
        result = assess_head_outputs([weak])
        domains = {gap.domain for gap in result.knowledge_gaps}
        assert "logic" in domains

    def test_domain_gap_detection(self) -> None:
        """A prompt mentioning legal matters flags a missing legal domain."""
        result = assess_head_outputs(
            [_make_head_output(HeadId.LOGIC)],
            user_prompt="legal compliance required",
        )
        domains = {gap.domain for gap in result.knowledge_gaps}
        assert "legal" in domains
|
||||||
|
|
||||||
|
|
||||||
|
class TestReasoningTracer:
    """Checks for ``ReasoningTracer`` trace capture and explanation."""

    def test_trace_lifecycle(self) -> None:
        """Start, two steps and finalize are all reflected in the stored trace."""
        tracer = ReasoningTracer()
        tracer.start_trace("t1", "task1", "What is 2+2?")
        recorded_steps = (
            ("decomposition", "decomposer", "prompt", "2 units"),
            ("head_dispatch", "orchestrator", "5 heads", "5 outputs"),
        )
        for step in recorded_steps:
            tracer.add_step("t1", *step)
        tracer.finalize_trace("t1", "4", 0.95)

        trace = tracer.get_trace("t1")
        assert trace is not None
        assert len(trace.steps) == 2
        assert trace.final_answer == "4"
        assert trace.overall_confidence == 0.95

    def test_explain(self) -> None:
        """The explanation text mentions the stage name and the final answer."""
        tracer = ReasoningTracer()
        tracer.start_trace("t1", "task1", "question")
        tracer.add_step("t1", "stage1", "comp1", "in", "out")
        tracer.finalize_trace("t1", "answer", 0.8)
        text = tracer.explain("t1")
        for fragment in ("stage1", "answer"):
            assert fragment in text

    def test_trace_not_found(self) -> None:
        """Unknown trace ids give ``None`` and a "not found" explanation."""
        tracer = ReasoningTracer()
        assert tracer.get_trace("nonexistent") is None
        assert "not found" in tracer.explain("nonexistent")

    def test_recent_traces(self) -> None:
        """``get_recent_traces`` honours ``limit``; ``total_traces`` counts all."""
        tracer = ReasoningTracer()
        for index in range(5):
            tracer.start_trace(f"t{index}", f"task{index}", f"prompt{index}")
        recent = tracer.get_recent_traces(limit=3)
        assert len(recent) == 3
        assert tracer.total_traces == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestClaimVerifier:
    """Checks for ``ClaimVerifier`` cross-claim verification."""

    def test_verify_no_outputs(self) -> None:
        """Verifying nothing produces a report with zero claims."""
        report = ClaimVerifier().verify_outputs([])
        assert report.total_claims == 0

    def test_verify_well_supported_claims(self) -> None:
        """Two evidenced claims are both counted and integrity is positive."""
        head_outputs = [
            _make_head_output(head, [(text, 0.7)])
            for head, text in (
                (HeadId.LOGIC, "Well supported"),
                (HeadId.RESEARCH, "Also supported"),
            )
        ]
        report = ClaimVerifier().verify_outputs(head_outputs)
        assert report.total_claims == 2
        assert report.overall_integrity > 0.0

    def test_high_conf_no_evidence_flagged(self) -> None:
        """A 0.95-confidence claim without evidence is flagged as such."""
        bold_claim = HeadClaim(claim_text="Bold claim", confidence=0.95, evidence=[])
        bold_output = HeadOutput(
            head_id=HeadId.LOGIC,
            summary="Bold output",
            claims=[bold_claim],
            risks=[],
        )
        report = ClaimVerifier().verify_outputs([bold_output])
        assert report.flagged_count >= 1
        all_issues = [issue for result in report.results for issue in result.issues]
        assert any("evidence" in issue.lower() for issue in all_issues)
|
||||||
69
tests/test_world_model_causal.py
Normal file
69
tests/test_world_model_causal.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Tests for the causal world model."""
|
||||||
|
|
||||||
|
from fusionagi.world_model import CausalWorldModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestCausalWorldModel:
    """Checks for ``CausalWorldModel`` learning, prediction and uncertainty."""

    def test_predict_unknown_action(self) -> None:
        """With nothing learned the state is echoed back at confidence 0.3."""
        model = CausalWorldModel()
        prediction = model.predict({"x": 1}, "unknown", {})
        assert prediction.confidence == 0.3
        assert prediction.to_state == {"x": 1}

    def test_observe_and_predict(self) -> None:
        """One observation is enough to lift confidence above the 0.3 floor."""
        model = CausalWorldModel()
        model.observe(
            from_state={"count": 0},
            action="increment",
            action_args={},
            to_state={"count": 1},
            success=True,
        )
        prediction = model.predict({"count": 5}, "increment", {})
        assert prediction.confidence > 0.3
        assert "count" in prediction.to_state

    def test_multiple_observations_increase_confidence(self) -> None:
        """Ten successful observations give a confident prediction."""
        model = CausalWorldModel()
        for step in range(10):
            model.observe({"s": step}, "act", {}, {"s": step + 1}, success=True)
        prediction = model.predict({"s": 100}, "act", {})
        assert prediction.confidence > 0.7

    def test_uncertainty_no_observations(self) -> None:
        """An unseen action is high risk with confidence 0.3."""
        info = CausalWorldModel().uncertainty({}, "unknown_action")
        assert info.risk_level == "high"
        assert info.confidence == 0.3

    def test_uncertainty_with_observations(self) -> None:
        """Repeated successes lower the risk level and raise confidence."""
        model = CausalWorldModel()
        for _ in range(10):
            model.observe({}, "safe_action", {}, {}, success=True)
        info = model.uncertainty({}, "safe_action")
        assert info.risk_level in ("low", "medium")
        assert info.confidence > 0.5

    def test_failed_observations_lower_confidence(self) -> None:
        """An action that always fails is rated high risk."""
        model = CausalWorldModel()
        for _ in range(5):
            model.observe({}, "risky", {}, {}, success=False)
        info = model.uncertainty({}, "risky")
        assert info.risk_level == "high"

    def test_known_actions(self) -> None:
        """Every observed action name is listed in ``known_actions``."""
        model = CausalWorldModel()
        for name in ("act_a", "act_b"):
            model.observe({}, name, {}, {}, success=True)
        known = model.known_actions
        assert "act_a" in known
        assert "act_b" in known

    def test_get_summary(self) -> None:
        """The summary counts every observation and at least one pattern."""
        model = CausalWorldModel()
        for value in (1, 2):
            model.observe({}, "x", {}, {"result": value}, success=True)
        summary = model.get_summary()
        assert summary["total_observations"] == 2
        assert summary["known_patterns"] >= 1
|
||||||
Reference in New Issue
Block a user