"""Semantic memory graph: nodes = AtomicSemanticUnit, edges = SemanticRelation.""" from __future__ import annotations from collections import defaultdict from fusionagi._logger import logger from fusionagi.schemas.atomic import ( AtomicSemanticUnit, AtomicUnitType, SemanticRelation, ) class SemanticGraphMemory: """ Graph-backed semantic memory: nodes = atomic units, edges = relations. Supports add_unit, add_relation, query_units, query_neighbors, query_by_type. In-memory implementation with dict + adjacency list. """ def __init__(self, max_units: int = 50000) -> None: self._units: dict[str, AtomicSemanticUnit] = {} self._by_type: dict[AtomicUnitType, list[str]] = defaultdict(list) self._outgoing: dict[str, list[SemanticRelation]] = defaultdict(list) self._incoming: dict[str, list[SemanticRelation]] = defaultdict(list) self._max_units = max_units def add_unit(self, unit: AtomicSemanticUnit) -> None: """Add an atomic semantic unit.""" if len(self._units) >= self._max_units and unit.unit_id not in self._units: self._evict_one() self._units[unit.unit_id] = unit self._by_type[unit.type].append(unit.unit_id) logger.debug("Semantic graph: unit added", extra={"unit_id": unit.unit_id, "type": unit.type.value}) def add_relation(self, relation: SemanticRelation) -> None: """Add a relation between units.""" if relation.from_id in self._units and relation.to_id in self._units: self._outgoing[relation.from_id].append(relation) self._incoming[relation.to_id].append(relation) def get_unit(self, unit_id: str) -> AtomicSemanticUnit | None: """Get unit by ID.""" return self._units.get(unit_id) def query_units( self, unit_ids: list[str] | None = None, unit_type: AtomicUnitType | None = None, limit: int = 100, ) -> list[AtomicSemanticUnit]: """Query units by IDs or type.""" if unit_ids: return [self._units[uid] for uid in unit_ids if uid in self._units][:limit] if unit_type: ids = self._by_type.get(unit_type, [])[-limit:] return [self._units[uid] for uid in ids if uid in self._units] return list(self._units.values())[-limit:] def query_neighbors( self, unit_id: str, direction: str = "outgoing", relation_type: str | None = None, ) -> list[tuple[AtomicSemanticUnit, SemanticRelation]]: """Get neighboring units and relations.""" edges = self._outgoing[unit_id] if direction == "outgoing" else self._incoming[unit_id] results: list[tuple[AtomicSemanticUnit, SemanticRelation]] = [] for rel in edges: if relation_type and rel.relation_type.value != relation_type: continue other_id = rel.to_id if direction == "outgoing" else rel.from_id other = self._units.get(other_id) if other: results.append((other, rel)) return results def query_by_type(self, unit_type: AtomicUnitType, limit: int = 100) -> list[AtomicSemanticUnit]: """Query units by type.""" return self.query_units(unit_type=unit_type, limit=limit) def ingest_decomposition( self, units: list[AtomicSemanticUnit], relations: list[SemanticRelation], ) -> None: """Ingest a DecompositionResult into the graph.""" for u in units: self.add_unit(u) for r in relations: self.add_relation(r) def semantic_search( self, query: str, top_k: int = 10, ) -> list[tuple[AtomicSemanticUnit, float]]: """Search stored units by semantic similarity using GPU when available. Args: query: Query text to search for. top_k: Number of top results to return. Returns: List of (unit, similarity_score) tuples sorted by score descending. """ try: from fusionagi.memory.gpu_search import semantic_search all_units = list(self._units.values()) return semantic_search(query, all_units, top_k=top_k) except ImportError: return self._cpu_search(query, top_k) def _cpu_search( self, query: str, top_k: int, ) -> list[tuple[AtomicSemanticUnit, float]]: """CPU fallback: word-overlap similarity.""" query_words = set(query.lower().split()) scored: list[tuple[AtomicSemanticUnit, float]] = [] for unit in self._units.values(): unit_words = set(unit.content.lower().split()) if not unit_words: continue overlap = len(query_words & unit_words) score = overlap / max(len(query_words | unit_words), 1) scored.append((unit, score)) scored.sort(key=lambda x: x[1], reverse=True) return scored[:top_k] def _evict_one(self) -> None: """Evict oldest unit (simple FIFO on first key).""" if not self._units: return uid = next(iter(self._units)) unit = self._units.pop(uid, None) if unit: self._by_type[unit.type] = [x for x in self._by_type[unit.type] if x != uid] self._outgoing.pop(uid, None) self._incoming.pop(uid, None) logger.debug("Semantic graph: evicted unit", extra={"unit_id": uid})