"""Causal world model: learns state-transition patterns from execution history. Unlike ``SimpleWorldModel`` (which returns unchanged state), the causal world model builds a library of observed action→effect patterns and uses them to predict outcomes of planned actions before execution. The model learns from every executed step: - Records (state_before, action, action_args) → state_after transitions - Groups patterns by action type for efficient lookup - Predicts confidence based on how many similar transitions it has observed - Maintains uncertainty estimates that decrease with more evidence """ from __future__ import annotations from typing import Any, Protocol from fusionagi._logger import logger from fusionagi.schemas.audit import AuditEventType from fusionagi.schemas.world_model import StateTransition, UncertaintyInfo class AuditLogLike(Protocol): """Protocol for audit log.""" def append( self, event_type: AuditEventType, actor: str, action: str = "", task_id: str | None = None, payload: dict[str, Any] | None = None, outcome: str = "", ) -> str: ... class TransitionPattern: """A learned state-transition pattern from execution history. Attributes: action: The action type that triggers this pattern. preconditions: State keys that must be present for this pattern. effects: Observed state changes (key → new_value). observation_count: How many times this pattern has been observed. success_count: How many times the action succeeded. avg_confidence: Running average confidence across observations. """ __slots__ = ( "action", "preconditions", "effects", "observation_count", "success_count", "avg_confidence", ) def __init__(self, action: str) -> None: self.action = action self.preconditions: set[str] = set() self.effects: dict[str, Any] = {} self.observation_count: int = 0 self.success_count: int = 0 self.avg_confidence: float = 0.5 def update( self, from_state: dict[str, Any], to_state: dict[str, Any], success: bool, ) -> None: """Update pattern with a new observation.""" self.observation_count += 1 if success: self.success_count += 1 self.preconditions.update(from_state.keys()) for key, value in to_state.items(): if key not in from_state or from_state[key] != value: self.effects[key] = value success_rate = self.success_count / self.observation_count evidence_boost = min(0.4, self.observation_count * 0.02) self.avg_confidence = min(1.0, 0.5 * success_rate + 0.5 + evidence_boost) class CausalWorldModel: """World model that learns causal state-transition patterns. Records every executed transition and builds a library of action→effect patterns. When asked to predict, it finds matching patterns and applies learned effects to the current state. Args: audit_log: Optional audit log for recording learning events. max_patterns_per_action: Max patterns to retain per action type. """ def __init__( self, audit_log: AuditLogLike | None = None, max_patterns_per_action: int = 100, ) -> None: self._patterns: dict[str, TransitionPattern] = {} self._history: list[StateTransition] = [] self._audit = audit_log self._max_per_action = max_patterns_per_action @property def total_observations(self) -> int: """Total state transitions observed.""" return len(self._history) @property def known_actions(self) -> list[str]: """Actions the model has observed.""" return list(self._patterns.keys()) def observe( self, from_state: dict[str, Any], action: str, action_args: dict[str, Any], to_state: dict[str, Any], success: bool = True, task_id: str | None = None, ) -> None: """Record an observed state transition. Args: from_state: State before the action. action: Action name/type. action_args: Arguments passed to the action. to_state: State after the action. success: Whether the action succeeded. task_id: Associated task ID. """ transition = StateTransition( from_state=dict(from_state), action=action, action_args=dict(action_args), to_state=dict(to_state), confidence=1.0 if success else 0.2, ) self._history.append(transition) pattern_key = self._pattern_key(action, action_args) if pattern_key not in self._patterns: self._patterns[pattern_key] = TransitionPattern(action) self._patterns[pattern_key].update(from_state, to_state, success) logger.debug( "CausalWorldModel: transition observed", extra={ "action": action, "success": success, "observations": self._patterns[pattern_key].observation_count, }, ) if self._audit: self._audit.append( AuditEventType.SELF_IMPROVEMENT, actor="world_model", action="transition_observed", task_id=task_id, payload={ "action_type": action, "success": success, "pattern_observations": self._patterns[pattern_key].observation_count, "state_changes": len(self._patterns[pattern_key].effects), }, outcome="learned", ) def predict( self, state: dict[str, Any], action: str, action_args: dict[str, Any], ) -> StateTransition: """Predict the result of an action in the current state. Uses learned patterns to predict state changes. When no matching pattern exists, returns the state unchanged with low confidence. Args: state: Current state. action: Action to predict. action_args: Arguments for the action. Returns: Predicted state transition with confidence. """ pattern_key = self._pattern_key(action, action_args) pattern = self._patterns.get(pattern_key) if pattern is None: generic_pattern = self._find_generic_pattern(action) if generic_pattern is None: return StateTransition( from_state=dict(state), action=action, action_args=dict(action_args), to_state=dict(state), confidence=0.3, ) pattern = generic_pattern predicted_state = dict(state) for key, value in pattern.effects.items(): predicted_state[key] = value return StateTransition( from_state=dict(state), action=action, action_args=dict(action_args), to_state=predicted_state, confidence=pattern.avg_confidence, ) def uncertainty(self, state: dict[str, Any], action: str) -> UncertaintyInfo: """Return uncertainty and risk assessment for an action. Args: state: Current state. action: Action to assess. Returns: Uncertainty info with confidence and risk level. """ matching = [ p for key, p in self._patterns.items() if p.action == action ] if not matching: return UncertaintyInfo( confidence=0.3, risk_level="high", rationale=f"No prior observations for action '{action}'", ) total_obs = sum(p.observation_count for p in matching) total_success = sum(p.success_count for p in matching) success_rate = total_success / total_obs if total_obs > 0 else 0.5 avg_conf = sum(p.avg_confidence for p in matching) / len(matching) if avg_conf >= 0.8 and success_rate >= 0.8: risk = "low" elif avg_conf >= 0.5 and success_rate >= 0.5: risk = "medium" else: risk = "high" return UncertaintyInfo( confidence=avg_conf, risk_level=risk, rationale=( f"Based on {total_obs} observations of '{action}': " f"{success_rate:.0%} success rate, {avg_conf:.2f} avg confidence" ), ) def predict_self_modification( self, action: str, action_args: dict[str, Any], ) -> dict[str, Any]: """Predict how a self-improvement action changes the system's own capabilities. Tracks capability evolution over time by observing how internal actions (training, parameter updates, strategy changes) affect subsequent performance. Args: action: The self-modification action type. action_args: Parameters for the action. Returns: Dict with predicted capability changes and confidence. """ self_mod_actions = [ h for h in self._history if h.action == action and any( k in h.action_args for k in ("capability", "domain", "heuristic") ) ] if not self_mod_actions: return { "predicted_change": "unknown", "confidence": 0.2, "prior_self_modifications": 0, "rationale": f"No prior self-modification observations for '{action}'", } improvements = sum( 1 for t in self_mod_actions if t.confidence > 0.6 ) total = len(self_mod_actions) improvement_rate = improvements / total if total > 0 else 0.0 return { "predicted_change": "improvement" if improvement_rate > 0.5 else "uncertain", "confidence": min(0.9, 0.3 + total * 0.05), "improvement_rate": improvement_rate, "prior_self_modifications": total, "rationale": ( f"Based on {total} prior self-modifications: " f"{improvement_rate:.0%} led to improvements" ), } def get_summary(self) -> dict[str, Any]: """Return a summary of the world model's learned knowledge.""" by_action: dict[str, dict[str, Any]] = {} for key, pattern in self._patterns.items(): by_action[key] = { "action": pattern.action, "observations": pattern.observation_count, "success_rate": ( pattern.success_count / pattern.observation_count if pattern.observation_count > 0 else 0.0 ), "confidence": pattern.avg_confidence, "known_effects": len(pattern.effects), } return { "total_observations": len(self._history), "known_patterns": len(self._patterns), "patterns": by_action, } def _pattern_key(self, action: str, action_args: dict[str, Any]) -> str: """Generate a pattern key from action and significant args.""" significant = sorted(action_args.keys())[:3] return f"{action}:{','.join(significant)}" if significant else action def _find_generic_pattern(self, action: str) -> TransitionPattern | None: """Find the best matching pattern by action name alone.""" matching = [ p for p in self._patterns.values() if p.action == action ] if not matching: return None return max(matching, key=lambda p: p.observation_count)