"""Tests for the causal world model."""

import math

from fusionagi.world_model import CausalWorldModel


class TestCausalWorldModel:
    """Test learned causal state-transition prediction."""

    def test_predict_unknown_action(self) -> None:
        """Predicting an action that was never observed uses the baseline.

        The expected baseline is a confidence of 0.3, with the input state
        echoed back unchanged as ``to_state``.
        """
        wm = CausalWorldModel()
        result = wm.predict({"x": 1}, "unknown", {})
        # Tolerant float comparison instead of exact ``==``.
        assert math.isclose(result.confidence, 0.3)
        assert result.to_state == {"x": 1}

    def test_observe_and_predict(self) -> None:
        """One successful observation raises confidence above the baseline.

        The prediction is made from a state the model never saw
        (``count == 5``) and must still mention the observed key.
        """
        wm = CausalWorldModel()
        wm.observe(
            from_state={"count": 0},
            action="increment",
            action_args={},
            to_state={"count": 1},
            success=True,
        )
        result = wm.predict({"count": 5}, "increment", {})
        assert result.confidence > 0.3
        assert "count" in result.to_state

    def test_multiple_observations_increase_confidence(self) -> None:
        """Ten consistent successful transitions push confidence past 0.7."""
        wm = CausalWorldModel()
        for i in range(10):
            wm.observe({"s": i}, "act", {}, {"s": i + 1}, success=True)
        result = wm.predict({"s": 100}, "act", {})
        assert result.confidence > 0.7

    def test_uncertainty_no_observations(self) -> None:
        """An unobserved action is reported as high risk at baseline confidence."""
        wm = CausalWorldModel()
        info = wm.uncertainty({}, "unknown_action")
        assert info.risk_level == "high"
        # Tolerant float comparison instead of exact ``==``.
        assert math.isclose(info.confidence, 0.3)

    def test_uncertainty_with_observations(self) -> None:
        """Repeated successes yield non-high risk and confidence above 0.5."""
        wm = CausalWorldModel()
        # The loop index is unused; only the repetition count matters.
        for _ in range(10):
            wm.observe({}, "safe_action", {}, {}, success=True)
        info = wm.uncertainty({}, "safe_action")
        assert info.risk_level in ("low", "medium")
        assert info.confidence > 0.5

    def test_failed_observations_lower_confidence(self) -> None:
        """An action observed only as failures is reported as high risk."""
        wm = CausalWorldModel()
        # The loop index is unused; only the repetition count matters.
        for _ in range(5):
            wm.observe({}, "risky", {}, {}, success=False)
        info = wm.uncertainty({}, "risky")
        assert info.risk_level == "high"

    def test_known_actions(self) -> None:
        """Every observed action name appears in ``known_actions``."""
        wm = CausalWorldModel()
        wm.observe({}, "act_a", {}, {}, success=True)
        wm.observe({}, "act_b", {}, {}, success=True)
        assert "act_a" in wm.known_actions
        assert "act_b" in wm.known_actions

    def test_get_summary(self) -> None:
        """The summary counts every observation and at least one pattern."""
        wm = CausalWorldModel()
        wm.observe({}, "x", {}, {"result": 1}, success=True)
        wm.observe({}, "x", {}, {"result": 2}, success=True)
        summary = wm.get_summary()
        assert summary["total_observations"] == 2
        assert summary["known_patterns"] >= 1