"""GPU-accelerated attention mechanisms for multi-head consensus.

Provides attention-based consensus scoring for the Dvādaśa pipeline:

- Head output attention: weight head contributions by relevance
- Claim-level attention: cross-attend between claims for conflict detection
- Weighted consensus: attention-based aggregation of head outputs
"""

from __future__ import annotations

from typing import Any

from fusionagi._logger import logger
from fusionagi.gpu.backend import TensorBackend, get_backend


def attention_consensus(
    head_embeddings: list[list[str]],
    query_text: str,
    head_weights: list[float] | None = None,
    num_heads: int = 4,
    backend: TensorBackend | None = None,
) -> dict[str, Any]:
    """Score head contributions using multi-head attention against the query.

    Every claim from every head is embedded; the query embedding (repeated
    once per claim) attends over the claim embeddings, and each claim's
    relevance is the dot product of its attention-output row with its own
    embedding. A head's raw score is the mean relevance of its claims,
    optionally scaled by ``head_weights``; raw scores are then min-max
    normalised across heads.

    Args:
        head_embeddings: List of claim-text lists, one per head. Despite the
            name these are raw texts; they are embedded here.
        query_text: The user's original query.
        head_weights: Optional per-head reliability weights, matched to heads
            by position. Extra entries are ignored; heads without a weight get
            a neutral weight of 1.0.
        num_heads: Number of attention heads passed to the backend.
        backend: TensorBackend to use; defaults to ``get_backend()``.

    Returns:
        Dict with:

        - ``'head_scores'``: one float per head in ``[0, 1]``. A head with no
          claims has a raw score of 0.0 before normalisation. If every head
          has the same raw score (or the range is NaN), all heads get 0.5.
        - ``'attention_weights'``: flat list of per-claim relevance scores,
          ordered by head and then by claim (empty when there are no claims).
        - ``'consensus_score'``: mean of ``'head_scores'``; 0.0 when there
          are no heads or no claims.
    """
    be = backend or get_backend()
    import numpy as np

    if not head_embeddings:
        return {"head_scores": [], "attention_weights": [], "consensus_score": 0.0}

    num_groups = len(head_embeddings)
    all_claims: list[str] = [claim for claims in head_embeddings for claim in claims]
    if not all_claims:
        return {
            "head_scores": [0.0] * num_groups,
            "attention_weights": [],
            "consensus_score": 0.0,
        }

    # claim_owner[k] is the index of the head that contributed all_claims[k].
    claim_owner = np.repeat(
        np.arange(num_groups), [len(claims) for claims in head_embeddings]
    )

    query_np = be.to_numpy(be.embed_texts([query_text]))
    claims_np = be.to_numpy(be.embed_texts(all_claims))

    # The query is repeated so query, key and value each have one row per claim.
    # NOTE(review): all query rows are identical, so all attention-output rows are
    # too; presumably the backend needs matching row counts — confirm before
    # shrinking this to a single query row.
    query_expanded = np.tile(query_np, (len(all_claims), 1))
    attn_output = be.to_numpy(
        be.multi_head_attention(
            be.from_numpy(query_expanded),
            be.from_numpy(claims_np),
            be.from_numpy(claims_np),
            num_heads=num_heads,
        )
    )
    # Row-wise dot product of attention output and claim embedding.
    relevance = np.sum(attn_output * claims_np, axis=1)

    # Mean relevance per head; max(count, 1) keeps claim-less heads at 0.0.
    sums = np.bincount(claim_owner, weights=relevance, minlength=num_groups)
    counts = np.bincount(claim_owner, minlength=num_groups)
    head_scores = (sums / np.maximum(counts, 1)).astype(np.float32)

    if head_weights is not None:
        # Pad with a neutral 1.0 so a short weight list cannot break (or silently
        # empty) the element-wise product; surplus weights are dropped.
        weights = np.ones(num_groups, dtype=np.float32)
        given = np.asarray(head_weights[:num_groups], dtype=np.float32)
        weights[: given.size] = given
        head_scores = head_scores * weights

    score_min = head_scores.min()
    score_range = head_scores.max() - score_min
    if score_range > 0:
        head_scores_norm = (head_scores - score_min) / score_range
    else:
        # All heads tied (or NaN scores): report a neutral midpoint.
        head_scores_norm = np.full_like(head_scores, 0.5)

    consensus_score = float(np.mean(head_scores_norm))

    logger.debug(
        "Attention consensus computed",
        extra={
            "num_heads": num_groups,
            "total_claims": len(all_claims),
            "consensus_score": consensus_score,
        },
    )

    return {
        "head_scores": head_scores_norm.tolist(),
        "attention_weights": relevance.tolist(),
        "consensus_score": consensus_score,
    }


def cross_claim_attention(
    claims: list[str],
    num_heads: int = 4,
    backend: TensorBackend | None = None,
    *,
    conflict_threshold: float = 0.3,
) -> dict[str, Any]:
    """Cross-attend between claims to detect agreement and conflict.

    Claims are embedded and passed through self-attention (the embeddings
    serve as query, key and value). The attention outputs are then compared
    pairwise with the backend's cosine-similarity matrix.

    Args:
        claims: List of claim texts.
        num_heads: Number of attention heads passed to the backend.
        backend: TensorBackend to use; defaults to ``get_backend()``.
        conflict_threshold: Pairs whose similarity is strictly below this
            value are reported as conflicts. Defaults to 0.3.

    Returns:
        Dict with 'similarity_matrix' (nested lists from the similarity
        matrix) and 'conflict_pairs' (list of ``(i, j)`` index tuples with
        ``i < j``, ordered by ``i`` then ``j``). Both are empty when fewer
        than two claims are given.
    """
    be = backend or get_backend()
    import numpy as np

    if len(claims) < 2:
        return {"similarity_matrix": [], "conflict_pairs": []}

    emb_np = be.to_numpy(be.embed_texts(claims))
    attn_out = be.to_numpy(
        be.multi_head_attention(
            be.from_numpy(emb_np),
            be.from_numpy(emb_np),
            be.from_numpy(emb_np),
            num_heads=num_heads,
        )
    )
    sim = be.to_numpy(
        be.cosine_similarity_matrix(be.from_numpy(attn_out), be.from_numpy(attn_out))
    )

    # Strict upper triangle (i < j), in row-major order, so each pair is checked once.
    rows, cols = np.triu_indices(len(claims), k=1)
    below = sim[rows, cols] < conflict_threshold
    conflict_pairs: list[tuple[int, int]] = [
        (int(i), int(j)) for i, j in zip(rows[below], cols[below])
    ]

    return {
        "similarity_matrix": sim.tolist(),
        "conflict_pairs": conflict_pairs,
    }