Some checks failed
Choice → Consequence → Learning: - ConsequenceEngine tracks every decision point with alternatives, risk/reward estimates, and actual outcomes - Consequences feed into AdaptiveEthics for experience-based learning - FusionAGILoop now wires ethics + consequences into task lifecycle Causal World Model: - CausalWorldModel learns state-transition patterns from execution history - Predicts outcomes based on observed action→effect patterns - Uncertainty estimates decrease as more evidence accumulates Metacognition: - assess_head_outputs() evaluates reasoning quality from head outputs - Detects knowledge gaps, measures head agreement, identifies uncertainty - Actively recommends whether to seek more information Interpretability: - ReasoningTracer captures full prompt→answer reasoning traces - Each step records stage, component, input/output, timing - explain() generates human-readable reasoning explanations Claim Verification: - ClaimVerifier cross-checks claims for evidence, consistency, grounding - Flags high-confidence claims lacking evidence support - Detects contradictions between claims from different heads 325 tests passing, 0 ruff errors, 0 mypy errors. Co-Authored-By: Nakamoto, S <defi@defi-oracle.io>
301 lines
10 KiB
Python
301 lines
10 KiB
Python
"""Causal world model: learns state-transition patterns from execution history.

Unlike ``SimpleWorldModel`` (which returns unchanged state), the causal
world model builds a library of observed action→effect patterns and uses
them to predict outcomes of planned actions before execution.

The model learns from every executed step:
- Records (state_before, action, action_args) → state_after transitions
- Groups patterns by action type for efficient lookup
- Predicts confidence based on how many similar transitions it has observed
- Maintains uncertainty estimates that decrease with more evidence
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Protocol
|
|
|
|
from fusionagi._logger import logger
|
|
from fusionagi.schemas.audit import AuditEventType
|
|
from fusionagi.schemas.world_model import StateTransition, UncertaintyInfo
|
|
|
|
|
|
class AuditLogLike(Protocol):
    """Structural type for the audit log this module writes to.

    Any object exposing a compatible ``append`` method satisfies the
    protocol; no inheritance is required.
    """

    def append(
        self,
        event_type: AuditEventType,
        actor: str,
        action: str = "",
        task_id: str | None = None,
        payload: dict[str, Any] | None = None,
        outcome: str = "",
    ) -> str:
        """Record one audit event.

        Args:
            event_type: Category of the event being recorded.
            actor: Name of the component emitting the event.
            action: Short label for what happened.
            task_id: Task the event belongs to, if any.
            payload: Extra structured details about the event.
            outcome: Short label for the result of the action.

        Returns:
            A string; presumably an identifier for the stored entry —
            confirm against the concrete audit log implementation.
        """
        ...
|
|
|
|
|
|
class TransitionPattern:
    """A learned state-transition pattern from execution history.

    Attributes:
        action: The action type that triggers this pattern.
        preconditions: Union of the state keys seen in ``from_state``
            across all observations.
        effects: Latest observed value for every key that an observation
            added or changed (key → new_value).
        observation_count: How many times this pattern has been observed.
        success_count: How many times the action succeeded.
        avg_confidence: Confidence score recomputed on every update from
            the success rate and the amount of evidence. It starts at 0.5
            and stays within [0.5, 1.0].
    """

    __slots__ = (
        "action",
        "preconditions",
        "effects",
        "observation_count",
        "success_count",
        "avg_confidence",
    )

    def __init__(self, action: str) -> None:
        self.action = action
        self.preconditions: set[str] = set()
        self.effects: dict[str, Any] = {}
        self.observation_count: int = 0
        self.success_count: int = 0
        self.avg_confidence: float = 0.5

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(action={self.action!r}, "
            f"observations={self.observation_count}, "
            f"successes={self.success_count}, "
            f"confidence={self.avg_confidence:.2f})"
        )

    def update(
        self,
        from_state: dict[str, Any],
        to_state: dict[str, Any],
        success: bool,
    ) -> None:
        """Update pattern with a new observation.

        Args:
            from_state: State before the action.
            to_state: State after the action.
            success: Whether the action succeeded.
        """
        self.observation_count += 1
        if success:
            self.success_count += 1

        self.preconditions.update(from_state.keys())

        # An effect is any key the action introduced or whose value changed.
        # Later observations overwrite earlier values for the same key.
        for key, value in to_state.items():
            if key not in from_state or from_state[key] != value:
                self.effects[key] = value

        success_rate = self.success_count / self.observation_count
        # Each observation adds 0.02 of confidence, capped at 0.4.
        evidence_boost = min(0.4, self.observation_count * 0.02)
        # NOTE(review): with a 100% success rate this saturates at 1.0 after
        # a single observation, so the evidence boost only matters for
        # patterns that have failed at least once — confirm this is intended.
        self.avg_confidence = min(1.0, 0.5 * success_rate + 0.5 + evidence_boost)
|
|
|
|
|
|
class CausalWorldModel:
    """World model that learns causal state-transition patterns.

    Records every executed transition and builds a library of
    action→effect patterns. When asked to predict, it finds matching
    patterns and applies learned effects to the current state.

    Patterns are keyed by the action name plus the names (not the values)
    of up to three of its arguments. The transition history is kept in
    memory and only ever grows.

    Args:
        audit_log: Optional audit log for recording learning events.
        max_patterns_per_action: Max patterns to retain per action type.
            When a new pattern would exceed this, the least-observed
            patterns of the same action are evicted first.
    """

    def __init__(
        self,
        audit_log: AuditLogLike | None = None,
        max_patterns_per_action: int = 100,
    ) -> None:
        self._patterns: dict[str, TransitionPattern] = {}
        self._history: list[StateTransition] = []
        self._audit = audit_log
        self._max_per_action = max_patterns_per_action

    @property
    def total_observations(self) -> int:
        """Total state transitions observed."""
        return len(self._history)

    @property
    def known_actions(self) -> list[str]:
        """Pattern keys the model has observed (action name plus arg names)."""
        return list(self._patterns)

    def observe(
        self,
        from_state: dict[str, Any],
        action: str,
        action_args: dict[str, Any],
        to_state: dict[str, Any],
        success: bool = True,
        task_id: str | None = None,
    ) -> None:
        """Record an observed state transition.

        Args:
            from_state: State before the action.
            action: Action name/type.
            action_args: Arguments passed to the action.
            to_state: State after the action.
            success: Whether the action succeeded.
            task_id: Associated task ID.
        """
        # Shallow copies keep later top-level mutation by the caller from
        # rewriting the recorded history.
        transition = StateTransition(
            from_state=dict(from_state),
            action=action,
            action_args=dict(action_args),
            to_state=dict(to_state),
            confidence=1.0 if success else 0.2,
        )
        self._history.append(transition)

        pattern_key = self._pattern_key(action, action_args)
        pattern = self._patterns.get(pattern_key)
        if pattern is None:
            # Make room before inserting so the per-action cap holds.
            self._evict_excess_patterns(action)
            pattern = TransitionPattern(action)
            self._patterns[pattern_key] = pattern
        pattern.update(from_state, to_state, success)

        logger.debug(
            "CausalWorldModel: transition observed",
            extra={
                "action": action,
                "success": success,
                "observations": pattern.observation_count,
            },
        )

        # Compare against None explicitly: an audit log that defines
        # ``__len__`` is falsy while empty and would otherwise never
        # receive its first event.
        if self._audit is not None:
            self._audit.append(
                AuditEventType.SELF_IMPROVEMENT,
                actor="world_model",
                action="transition_observed",
                task_id=task_id,
                payload={
                    "action_type": action,
                    "success": success,
                    "pattern_observations": pattern.observation_count,
                    "state_changes": len(pattern.effects),
                },
                outcome="learned",
            )

    def predict(
        self,
        state: dict[str, Any],
        action: str,
        action_args: dict[str, Any],
    ) -> StateTransition:
        """Predict the result of an action in the current state.

        Uses learned patterns to predict state changes. When no pattern
        matches the exact key, the most-observed pattern for the same
        action is used instead. When the action is entirely unknown, the
        state is returned unchanged with low confidence (0.3).

        Args:
            state: Current state.
            action: Action to predict.
            action_args: Arguments for the action.

        Returns:
            Predicted state transition with confidence.
        """
        pattern = self._patterns.get(self._pattern_key(action, action_args))
        if pattern is None:
            pattern = self._find_generic_pattern(action)

        if pattern is None:
            return StateTransition(
                from_state=dict(state),
                action=action,
                action_args=dict(action_args),
                to_state=dict(state),
                confidence=0.3,
            )

        # Overlay the learned effects on a copy of the current state.
        predicted_state = dict(state)
        predicted_state.update(pattern.effects)

        return StateTransition(
            from_state=dict(state),
            action=action,
            action_args=dict(action_args),
            to_state=predicted_state,
            confidence=pattern.avg_confidence,
        )

    def uncertainty(self, state: dict[str, Any], action: str) -> UncertaintyInfo:
        """Return uncertainty and risk assessment for an action.

        Args:
            state: Current state. Accepted for interface compatibility;
                the assessment does not currently consult it.
            action: Action to assess.

        Returns:
            Uncertainty info with confidence and risk level.
        """
        matching = self._patterns_for_action(action)

        if not matching:
            return UncertaintyInfo(
                confidence=0.3,
                risk_level="high",
                rationale=f"No prior observations for action '{action}'",
            )

        total_obs = sum(p.observation_count for p in matching)
        total_success = sum(p.success_count for p in matching)
        success_rate = total_success / total_obs if total_obs > 0 else 0.5
        avg_conf = sum(p.avg_confidence for p in matching) / len(matching)

        # Both confidence and success rate must clear a threshold for the
        # lower-risk tiers.
        if avg_conf >= 0.8 and success_rate >= 0.8:
            risk = "low"
        elif avg_conf >= 0.5 and success_rate >= 0.5:
            risk = "medium"
        else:
            risk = "high"

        return UncertaintyInfo(
            confidence=avg_conf,
            risk_level=risk,
            rationale=(
                f"Based on {total_obs} observations of '{action}': "
                f"{success_rate:.0%} success rate, {avg_conf:.2f} avg confidence"
            ),
        )

    def get_summary(self) -> dict[str, Any]:
        """Return a summary of the world model's learned knowledge."""
        by_action: dict[str, dict[str, Any]] = {}
        for key, pattern in self._patterns.items():
            by_action[key] = {
                "action": pattern.action,
                "observations": pattern.observation_count,
                "success_rate": (
                    pattern.success_count / pattern.observation_count
                    if pattern.observation_count > 0
                    else 0.0
                ),
                "confidence": pattern.avg_confidence,
                "known_effects": len(pattern.effects),
            }
        return {
            "total_observations": len(self._history),
            "known_patterns": len(self._patterns),
            "patterns": by_action,
        }

    def _pattern_key(self, action: str, action_args: dict[str, Any]) -> str:
        """Generate a pattern key from action and significant args.

        Only argument names are used, and only the first three in sorted
        order; argument values do not affect the key.
        """
        significant = sorted(action_args.keys())[:3]
        return f"{action}:{','.join(significant)}" if significant else action

    def _patterns_for_action(self, action: str) -> list[TransitionPattern]:
        """Return every stored pattern whose action name equals *action*."""
        return [p for p in self._patterns.values() if p.action == action]

    def _find_generic_pattern(self, action: str) -> TransitionPattern | None:
        """Find the best matching pattern by action name alone.

        Returns the most-observed pattern for *action*, or ``None`` when
        the action has never been seen.
        """
        matching = self._patterns_for_action(action)
        if not matching:
            return None
        return max(matching, key=lambda p: p.observation_count)

    def _evict_excess_patterns(self, action: str) -> None:
        """Drop least-observed patterns of *action* to make room for one more.

        Called just before a new pattern is inserted. Patterns belonging
        to other actions are never touched. Ties are broken by insertion
        order, oldest first.
        """
        same_action = [
            key for key, pattern in self._patterns.items() if pattern.action == action
        ]
        # +1 accounts for the pattern about to be inserted.
        excess = len(same_action) + 1 - self._max_per_action
        if excess <= 0:
            return
        same_action.sort(key=lambda key: self._patterns[key].observation_count)
        for key in same_action[:excess]:
            del self._patterns[key]
|