"""Response cache with TTL for the FusionAGI API. Provides both in-memory and Redis-backed implementations with a common interface. """ from __future__ import annotations import hashlib import json import time from abc import ABC, abstractmethod from typing import Any from fusionagi._logger import logger class CacheBackend(ABC): """Abstract cache backend interface.""" @abstractmethod def get(self, key: str) -> Any | None: """Get value by key, or None if missing/expired.""" ... @abstractmethod def set(self, key: str, value: Any, ttl: float | None = None) -> None: """Set key/value with optional TTL.""" ... @abstractmethod def delete(self, key: str) -> bool: """Delete a key. Returns True if existed.""" ... @abstractmethod def clear(self) -> int: """Clear all entries. Returns count cleared.""" ... @abstractmethod def stats(self) -> dict[str, Any]: """Return backend stats.""" ... class MemoryCacheBackend(CacheBackend): """In-memory LRU cache with TTL.""" def __init__(self, max_size: int = 1000, default_ttl: float = 300.0) -> None: self._cache: dict[str, tuple[float, float, Any]] = {} # key -> (set_time, ttl, value) self._max_size = max_size self._default_ttl = default_ttl def get(self, key: str) -> Any | None: entry = self._cache.get(key) if entry is None: return None set_time, ttl, value = entry if time.time() - set_time > ttl: del self._cache[key] return None return value def set(self, key: str, value: Any, ttl: float | None = None) -> None: if len(self._cache) >= self._max_size: oldest = min(self._cache, key=lambda k: self._cache[k][0]) del self._cache[oldest] self._cache[key] = (time.time(), ttl or self._default_ttl, value) def delete(self, key: str) -> bool: return self._cache.pop(key, None) is not None def clear(self) -> int: count = len(self._cache) self._cache.clear() return count def stats(self) -> dict[str, Any]: now = time.time() active = sum(1 for st, ttl, _ in self._cache.values() if now - st <= ttl) return {"backend": "memory", "total": len(self._cache), "active": active, "max_size": self._max_size} class RedisCacheBackend(CacheBackend): """Redis-backed cache. Requires the ``redis`` package. Falls back to memory cache if Redis is unavailable. """ def __init__(self, redis_url: str = "redis://localhost:6379/0", default_ttl: float = 300.0) -> None: self._default_ttl = default_ttl self._prefix = "fusionagi:cache:" self._redis: Any = None try: import redis self._redis = redis.from_url(redis_url, decode_responses=True) self._redis.ping() logger.info("Redis cache connected", extra={"url": redis_url}) except Exception as e: logger.warning("Redis unavailable, cache operations will be no-ops", extra={"error": str(e)}) self._redis = None @property def available(self) -> bool: """Check if Redis is connected.""" return self._redis is not None def _key(self, key: str) -> str: return f"{self._prefix}{key}" def get(self, key: str) -> Any | None: if not self._redis: return None try: raw = self._redis.get(self._key(key)) if raw is None: return None return json.loads(raw) except Exception: return None def set(self, key: str, value: Any, ttl: float | None = None) -> None: if not self._redis: return try: ttl_seconds = int(ttl or self._default_ttl) self._redis.setex(self._key(key), ttl_seconds, json.dumps(value)) except Exception as e: logger.warning("Redis set failed", extra={"error": str(e)}) def delete(self, key: str) -> bool: if not self._redis: return False try: return bool(self._redis.delete(self._key(key))) except Exception: return False def clear(self) -> int: if not self._redis: return 0 try: keys = self._redis.keys(f"{self._prefix}*") if keys: return self._redis.delete(*keys) return 0 except Exception: return 0 def stats(self) -> dict[str, Any]: if not self._redis: return {"backend": "redis", "available": False} try: info = self._redis.info("keyspace") return {"backend": "redis", "available": True, "info": info} except Exception: return {"backend": "redis", "available": False} class ResponseCache: """High-level response cache with pluggable backend. Uses MemoryCacheBackend by default. Pass a RedisCacheBackend for production multi-worker deployments. """ def __init__( self, backend: CacheBackend | None = None, max_size: int = 1000, ttl_seconds: float = 300.0, ) -> None: self._backend = backend or MemoryCacheBackend(max_size=max_size, default_ttl=ttl_seconds) self._ttl = ttl_seconds @staticmethod def _make_key(prompt: str, session_id: str, tenant_id: str = "default") -> str: """Generate a cache key from prompt + session context.""" raw = json.dumps({"prompt": prompt, "session": session_id, "tenant": tenant_id}, sort_keys=True) return hashlib.sha256(raw.encode()).hexdigest() def get(self, prompt: str, session_id: str, tenant_id: str = "default") -> Any | None: """Get cached response if it exists and hasn't expired.""" key = self._make_key(prompt, session_id, tenant_id) return self._backend.get(key) def set(self, prompt: str, session_id: str, value: Any, tenant_id: str = "default") -> None: """Cache a response.""" key = self._make_key(prompt, session_id, tenant_id) self._backend.set(key, value, self._ttl) def invalidate(self, prompt: str, session_id: str, tenant_id: str = "default") -> bool: """Remove a specific cache entry.""" key = self._make_key(prompt, session_id, tenant_id) return self._backend.delete(key) def clear(self) -> int: """Clear all cache entries.""" return self._backend.clear() def stats(self) -> dict[str, Any]: """Return cache statistics.""" return self._backend.stats()