"""Unit tests for :mod:`fusionagi.gpu.tensor_similarity`.

An autouse fixture requests the NumPy backend before every test and resets
the backend selection afterwards, so each test starts from the same state.
"""

import pytest

from fusionagi.gpu.backend import get_backend, reset_backend
from fusionagi.gpu.tensor_similarity import (
    deduplicate_claims,
    nearest_neighbors,
    pairwise_text_similarity,
)


@pytest.fixture(autouse=True)
def _use_numpy():
    """Select the NumPy backend for the duration of a single test."""
    # Start from a clean backend state before requesting NumPy explicitly.
    reset_backend()
    get_backend(force="numpy")
    yield
    # Teardown: clear the selection so it cannot leak into the next test.
    reset_backend()


class TestPairwiseTextSimilarity:
    """Shape and score checks for ``pairwise_text_similarity``."""

    def test_basic(self):
        # Identical single-text inputs yield a 1x1 matrix with a high score.
        matrix = pairwise_text_similarity(["hello world"], ["hello world"])
        assert matrix.shape == (1, 1)
        assert matrix[0, 0] > 0.9

    def test_different_texts(self):
        left = ["hello world"]
        right = ["completely different text"]
        matrix = pairwise_text_similarity(left, right)
        assert matrix.shape == (1, 1)
        assert matrix[0, 0] < 1.0

    def test_multi(self):
        # Rows correspond to the first argument, columns to the second.
        rows = ["cat", "dog"]
        cols = ["car", "bike", "train"]
        matrix = pairwise_text_similarity(rows, cols)
        assert matrix.shape == (len(rows), len(cols))


class TestDeduplicateClaims:
    """Grouping behaviour of ``deduplicate_claims``."""

    def test_empty(self):
        assert deduplicate_claims([]) == []

    def test_single(self):
        assert deduplicate_claims(["one claim"]) == [[0]]

    def test_identical(self):
        duplicate = "the sky is blue"
        clusters = deduplicate_claims([duplicate, duplicate], threshold=0.9)
        assert len(clusters) == 1
        assert sorted(clusters[0]) == [0, 1]

    def test_different(self):
        unrelated = ["the sky is blue", "python is a programming language"]
        clusters = deduplicate_claims(unrelated, threshold=0.99)
        assert len(clusters) == 2

    def test_all_indices_covered(self):
        # Every input index must appear exactly once across all groups.
        claims = ["a", "b", "c", "d"]
        clusters = deduplicate_claims(claims, threshold=0.99)
        flattened = [member for cluster in clusters for member in cluster]
        assert sorted(flattened) == list(range(len(claims)))


class TestNearestNeighbors:
    """Result-structure checks for ``nearest_neighbors``."""

    def test_empty_query(self):
        assert nearest_neighbors([], ["corpus text"]) == []

    def test_empty_corpus(self):
        # A single query still gets its own (empty) hit list.
        assert nearest_neighbors(["query"], []) == [[]]

    def test_basic(self):
        corpus = ["hello world", "goodbye moon", "hello planet"]
        hits = nearest_neighbors(["hello world"], corpus, top_k=2)
        assert len(hits) == 1
        assert len(hits[0]) == 2
        # Each hit is expected to be an (index, score) tuple.
        best = hits[0][0]
        assert isinstance(best, tuple)
        assert isinstance(best[0], int)
        assert isinstance(best[1], float)

    def test_top_k_limit(self):
        corpus = [f"text {n}" for n in range(20)]
        hits = nearest_neighbors(["text 5"], corpus, top_k=3)
        assert len(hits[0]) == 3