"""Tests for fusionagi.gpu.tensor_attention."""

import pytest

from fusionagi.gpu.backend import get_backend, reset_backend
from fusionagi.gpu.tensor_attention import (
    attention_consensus,
    cross_claim_attention,
)


@pytest.fixture(autouse=True)
def _use_numpy():
    """Run every test in this module with the numpy backend forced.

    ``reset_backend()`` is called before forcing numpy and again after the
    test body (teardown after ``yield``) — presumably so that a backend
    selected by one test does not leak into another.
    """
    reset_backend()
    get_backend(force="numpy")
    yield
    reset_backend()


class TestAttentionConsensus:
    """Tests for ``attention_consensus(head_claims, query, head_weights=...)``."""

    def test_empty(self):
        # No heads at all: no per-head scores and a zero consensus.
        result = attention_consensus([], "query")
        assert result["head_scores"] == []
        assert result["consensus_score"] == 0.0

    def test_single_head(self):
        result = attention_consensus(
            [["the sky is blue"]],
            "what color is the sky",
        )
        assert len(result["head_scores"]) == 1
        assert isinstance(result["consensus_score"], float)

    def test_multiple_heads(self):
        # One score per head; the consensus score is expected in [0, 1].
        result = attention_consensus(
            [
                ["the sky is blue", "water is wet"],
                ["security is important"],
                ["cost should be minimized"],
            ],
            "what should we do about the project",
        )
        assert len(result["head_scores"]) == 3
        assert 0.0 <= result["consensus_score"] <= 1.0

    def test_with_weights(self):
        # NOTE(review): only the number of head scores is checked; the
        # effect of ``head_weights`` on the scores is not asserted here.
        result = attention_consensus(
            [["claim a"], ["claim b"]],
            "query",
            head_weights=[2.0, 0.5],
        )
        assert len(result["head_scores"]) == 2

    def test_empty_claims(self):
        # Heads with no claims still get a score slot, each equal to 0.0.
        result = attention_consensus(
            [[], []],
            "query",
        )
        assert len(result["head_scores"]) == 2
        assert result["head_scores"] == [0.0, 0.0]


class TestCrossClaimAttention:
    """Tests for ``cross_claim_attention(claims)``."""

    def test_empty(self):
        result = cross_claim_attention([])
        assert result["similarity_matrix"] == []
        assert result["conflict_pairs"] == []

    def test_single(self):
        # A single claim yields an empty similarity matrix, not a 1x1 one.
        result = cross_claim_attention(["only one claim"])
        assert result["similarity_matrix"] == []

    def test_two_claims(self):
        # Two claims produce a square 2x2 similarity matrix.
        result = cross_claim_attention(["claim one", "claim two"])
        assert len(result["similarity_matrix"]) == 2
        assert len(result["similarity_matrix"][0]) == 2

    def test_self_similarity_high(self):
        # Diagonal entries (each claim against itself) should be near 1.
        result = cross_claim_attention(["same text", "same text"])
        sim = result["similarity_matrix"]
        assert sim[0][0] > 0.9
        assert sim[1][1] > 0.9

    def test_conflict_detection(self):
        # NOTE(review): only the return type is checked; whether these two
        # dissimilar claims are reported as a conflict pair is not asserted.
        result = cross_claim_attention([
            "the project is very safe and reliable",
            "completely unrelated topic about food and cooking",
        ])
        assert isinstance(result["conflict_pairs"], list)