"""API key rotation mechanism for FusionAGI.""" from __future__ import annotations import hashlib import secrets import time from typing import Any from pydantic import BaseModel, Field class APIKeyRecord(BaseModel): """Record for a rotatable API key.""" key_hash: str created_at: float = Field(default_factory=time.time) expires_at: float | None = None label: str = "default" active: bool = True class SecretRotator: """Manages API key lifecycle: generation, rotation, and expiry. Keys are stored as SHA-256 hashes for security. Supports multiple active keys for zero-downtime rotation. """ def __init__(self, max_active_keys: int = 3) -> None: self._keys: list[APIKeyRecord] = [] self._max_active = max_active_keys @staticmethod def _hash_key(key: str) -> str: """Hash a key using SHA-256.""" return hashlib.sha256(key.encode()).hexdigest() def generate_key(self, label: str = "default", ttl_seconds: float | None = None) -> str: """Generate a new API key and register it. Returns the plaintext key.""" key = secrets.token_urlsafe(32) record = APIKeyRecord( key_hash=self._hash_key(key), label=label, expires_at=time.time() + ttl_seconds if ttl_seconds else None, ) self._keys.append(record) self._enforce_max_active() return key def validate_key(self, key: str) -> bool: """Check if a key is valid (active and not expired).""" key_hash = self._hash_key(key) now = time.time() for record in self._keys: if record.key_hash == key_hash and record.active: if record.expires_at and now > record.expires_at: record.active = False return False return True return False def rotate(self, label: str = "default", ttl_seconds: float | None = None) -> str: """Rotate keys: generate new, keep previous active for overlap period.""" return self.generate_key(label=label, ttl_seconds=ttl_seconds) def revoke(self, key: str) -> bool: """Revoke a specific key.""" key_hash = self._hash_key(key) for record in self._keys: if record.key_hash == key_hash: record.active = False return True return False def revoke_expired(self) -> int: """Deactivate all expired keys.""" now = time.time() count = 0 for record in self._keys: if record.active and record.expires_at and now > record.expires_at: record.active = False count += 1 return count def _enforce_max_active(self) -> None: """Ensure we don't exceed max active keys.""" active = [k for k in self._keys if k.active] while len(active) > self._max_active: active[0].active = False active = active[1:] def list_keys(self) -> list[dict[str, Any]]: """List all keys (without hashes).""" return [ { "label": k.label, "active": k.active, "created_at": k.created_at, "expires_at": k.expires_at, } for k in self._keys ]