# PR notes (from the introducing commit):
# - New fusionagi/gpu/ module with TensorBackend protocol abstraction
# - TensorFlowBackend: GPU-accelerated ops with TensorCore mixed-precision
# - NumPyBackend: CPU fallback (always available, no extra deps)
# - Auto-selects best available backend at runtime
# - GPU-accelerated operations:
#   - Cosine similarity matrix (batched, XLA-compiled)
#   - Multi-head attention for consensus scoring
#   - Batch hypothesis scoring on GPU
#   - Semantic similarity search (pairwise, nearest-neighbor, deduplication)
# - New TensorFlowAdapter (fusionagi/adapters/):
#   - LLMAdapter for local TF/Keras model inference
#   - TensorCore mixed-precision support
#   - GPU-accelerated embedding synthesis fallback
# - Reasoning pipeline integration:
#   - gpu_scoring.py: drop-in GPU replacement for multi_path scoring
#   - Super Big Brain: use_gpu config flag, GPU scoring when available
# - Memory integration:
#   - gpu_search.py: GPU-accelerated semantic search for SemanticGraphMemory
# - Self-improvement integration:
#   - gpu_training.py: gradient-based heuristic weight optimization
#   - Reflective memory training loop with loss tracking
# - Dependencies: gpu extra (tensorflow>=2.16, numpy>=1.26)
# - 64 new tests (276 total), all passing
# - Architecture spec: docs/gpu_tensorcore_integration.md
# Co-Authored-By: Nakamoto, S <defi@defi-oracle.io>
"""GPU-accelerated attention mechanisms for multi-head consensus.
|
|
|
|
Provides attention-based consensus scoring for the Dvādaśa pipeline:
|
|
- Head output attention: weight head contributions by relevance
|
|
- Claim-level attention: cross-attend between claims for conflict detection
|
|
- Weighted consensus: attention-based aggregation of head outputs
|
|
"""
|
|
|
|
from __future__ import annotations

from typing import Any

import numpy as np

from fusionagi._logger import logger
from fusionagi.gpu.backend import TensorBackend, get_backend


def attention_consensus(
    head_embeddings: list[list[str]],
    query_text: str,
    head_weights: list[float] | None = None,
    num_heads: int = 4,
    backend: TensorBackend | None = None,
) -> dict[str, Any]:
    """Score head contributions using multi-head attention against the query.

    Each head's claims are embedded, then cross-attended against the query
    to produce relevance-weighted scores.

    Args:
        head_embeddings: List of claim-text lists, one per head.
        query_text: The user's original query.
        head_weights: Optional per-head reliability weights. If fewer weights
            than heads are supplied, the missing entries default to a neutral
            1.0; extra entries are ignored.
        num_heads: Number of attention heads.
        backend: TensorBackend to use; defaults to the best available one.

    Returns:
        Dict with 'head_scores' (per-head scores, min-max normalised to
        [0, 1]), 'attention_weights' (flat per-claim relevance list, in
        claim-flattening order), and 'consensus_score' (mean of the
        normalised head scores).
    """
    be = backend or get_backend()
    # Local import keeps numpy out of the module-import path.
    import numpy as np

    if not head_embeddings:
        return {"head_scores": [], "attention_weights": [], "consensus_score": 0.0}

    # Flatten all claims and remember which head each one came from.
    all_claims: list[str] = []
    head_indices: list[int] = []
    for i, claims in enumerate(head_embeddings):
        for claim in claims:
            all_claims.append(claim)
            head_indices.append(i)

    if not all_claims:
        return {
            "head_scores": [0.0] * len(head_embeddings),
            "attention_weights": [],
            "consensus_score": 0.0,
        }

    query_emb = be.embed_texts([query_text])
    claim_emb = be.embed_texts(all_claims)

    query_np = be.to_numpy(query_emb)
    claims_np = be.to_numpy(claim_emb)

    # Repeat the query row once per claim so each claim cross-attends
    # against the same query.
    query_expanded = np.tile(query_np, (len(all_claims), 1))
    attn_output = be.to_numpy(
        be.multi_head_attention(
            be.from_numpy(query_expanded),
            be.from_numpy(claims_np),
            be.from_numpy(claims_np),
            num_heads=num_heads,
        )
    )

    # Per-claim relevance: dot product of attended query with claim embedding.
    relevance = np.sum(attn_output * claims_np, axis=1)

    # Sum relevance per head, then average over each head's claim count.
    num_heads_count = len(head_embeddings)
    head_scores = np.zeros(num_heads_count, dtype=np.float32)
    head_claim_counts = np.zeros(num_heads_count, dtype=np.float32)
    for idx, head_idx in enumerate(head_indices):
        head_scores[head_idx] += relevance[idx]
        head_claim_counts[head_idx] += 1.0

    # Heads with zero claims divide by 1.0 and keep their 0.0 score.
    head_scores = head_scores / np.maximum(head_claim_counts, 1.0)

    if head_weights is not None:
        # Fix: a weight list shorter than the head count previously raised a
        # broadcast ValueError. Pad missing weights with a neutral 1.0;
        # extra weights are still ignored.
        w = np.ones(num_heads_count, dtype=np.float32)
        given = np.asarray(head_weights[:num_heads_count], dtype=np.float32)
        w[: len(given)] = given
        head_scores = head_scores * w

    # Min-max normalise to [0, 1]; identical scores collapse to neutral 0.5.
    # head_scores is guaranteed non-empty here (head_embeddings is non-empty).
    score_min = float(head_scores.min())
    score_max = float(head_scores.max())
    score_range = score_max - score_min
    if score_range > 0:
        head_scores_norm = (head_scores - score_min) / score_range
    else:
        head_scores_norm = np.full_like(head_scores, 0.5)

    consensus_score = float(np.mean(head_scores_norm))

    logger.debug(
        "Attention consensus computed",
        extra={
            "num_heads": num_heads_count,
            "total_claims": len(all_claims),
            "consensus_score": consensus_score,
        },
    )

    return {
        "head_scores": head_scores_norm.tolist(),
        "attention_weights": relevance.tolist(),
        "consensus_score": consensus_score,
    }


def cross_claim_attention(
    claims: list[str],
    num_heads: int = 4,
    backend: TensorBackend | None = None,
    conflict_threshold: float = 0.3,
) -> dict[str, Any]:
    """Cross-attend between claims to detect agreement and conflict.

    Claims are embedded and self-attended, then pairwise cosine similarity
    is computed over the attended representations. Pairs whose similarity
    falls below ``conflict_threshold`` are reported as potential conflicts.

    Args:
        claims: List of claim texts.
        num_heads: Number of attention heads.
        backend: TensorBackend to use; defaults to the best available one.
        conflict_threshold: Similarity below which a pair is flagged as
            conflicting. Defaults to 0.3, the previously hard-coded cut-off.

    Returns:
        Dict with 'similarity_matrix' (nested list, n x n) and
        'conflict_pairs' (list of (i, j) index tuples with i < j, in
        row-major order).
    """
    be = backend or get_backend()
    # Local import keeps numpy out of the module-import path.
    import numpy as np

    # Fewer than two claims cannot conflict with each other.
    if len(claims) < 2:
        return {"similarity_matrix": [], "conflict_pairs": []}

    embeddings = be.embed_texts(claims)
    emb_np = be.to_numpy(embeddings)

    # Self-attention: every claim attends over all claims.
    attn_out = be.to_numpy(
        be.multi_head_attention(
            be.from_numpy(emb_np),
            be.from_numpy(emb_np),
            be.from_numpy(emb_np),
            num_heads=num_heads,
        )
    )

    sim = be.to_numpy(
        be.cosine_similarity_matrix(be.from_numpy(attn_out), be.from_numpy(attn_out))
    )

    # Vectorized upper-triangle scan replaces the O(n^2) Python double loop;
    # triu_indices(k=1) yields the same row-major (i, j) order as the loops.
    iu, ju = np.triu_indices(len(claims), k=1)
    below = sim[iu, ju] < conflict_threshold
    conflict_pairs: list[tuple[int, int]] = [
        (int(i), int(j)) for i, j in zip(iu[below], ju[below])
    ]

    return {
        "similarity_matrix": sim.tolist(),
        "conflict_pairs": conflict_pairs,
    }