"""Postgres-backed persistence for production deployments. Uses psycopg2 (or asyncpg when available) for connection pooling. Falls back gracefully to in-memory if Postgres is unavailable. """ from __future__ import annotations import json import threading from typing import Any from fusionagi._logger import logger from fusionagi.core.persistence import StateBackend from fusionagi.schemas.task import Task, TaskState _CREATE_SCHEMA = """ CREATE TABLE IF NOT EXISTS tasks ( task_id TEXT PRIMARY KEY, data JSONB NOT NULL, state TEXT NOT NULL DEFAULT 'pending', created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW() ); CREATE TABLE IF NOT EXISTS traces ( id SERIAL PRIMARY KEY, task_id TEXT NOT NULL REFERENCES tasks(task_id) ON DELETE CASCADE, entry JSONB NOT NULL, created_at TIMESTAMPTZ DEFAULT NOW() ); CREATE INDEX IF NOT EXISTS idx_traces_task_id ON traces(task_id); CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state); """ class PostgresStateBackend(StateBackend): """Postgres-backed implementation of StateBackend. Args: dsn: PostgreSQL connection string (e.g., "postgresql://user:pass@host/db"). pool_size: Connection pool size (min connections kept open). max_overflow: Maximum extra connections beyond pool_size. """ def __init__( self, dsn: str = "postgresql://localhost/fusionagi", pool_size: int = 5, max_overflow: int = 10, ) -> None: self._dsn = dsn self._pool_size = pool_size self._max_overflow = max_overflow self._lock = threading.Lock() self._pool: Any = None self._available = False self._init_pool() def _init_pool(self) -> None: """Initialize connection pool and create schema.""" try: from psycopg2 import pool as pg_pool self._pool = pg_pool.ThreadedConnectionPool( minconn=1, maxconn=self._pool_size + self._max_overflow, dsn=self._dsn, ) conn = self._pool.getconn() try: with conn.cursor() as cur: cur.execute(_CREATE_SCHEMA) conn.commit() finally: self._pool.putconn(conn) self._available = True logger.info("PostgresStateBackend: connected", extra={"dsn": self._dsn.split("@")[-1]}) except ImportError: logger.warning("PostgresStateBackend: psycopg2 not installed, operating as no-op") except Exception as e: logger.warning("PostgresStateBackend: connection failed, operating as no-op", extra={"error": str(e)}) def _get_conn(self) -> Any: if not self._available or self._pool is None: return None return self._pool.getconn() def _put_conn(self, conn: Any) -> None: if self._pool is not None and conn is not None: self._pool.putconn(conn) def get_task(self, task_id: str) -> Task | None: """Load task by id from Postgres.""" conn = self._get_conn() if conn is None: return None try: with conn.cursor() as cur: cur.execute("SELECT data FROM tasks WHERE task_id = %s", (task_id,)) row = cur.fetchone() if row is None: return None return Task.model_validate(row[0] if isinstance(row[0], dict) else json.loads(row[0])) finally: self._put_conn(conn) def set_task(self, task: Task) -> None: """Upsert task into Postgres.""" if not self._available: return conn = self._get_conn() if conn is None: return try: with self._lock: with conn.cursor() as cur: cur.execute( """INSERT INTO tasks (task_id, data, state) VALUES (%s, %s, %s) ON CONFLICT (task_id) DO UPDATE SET data = EXCLUDED.data, state = EXCLUDED.state, updated_at = NOW()""", (task.task_id, task.model_dump_json(), task.state.value), ) conn.commit() finally: self._put_conn(conn) def get_task_state(self, task_id: str) -> TaskState | None: """Return current task state.""" conn = self._get_conn() if conn is None: return None try: with conn.cursor() as cur: cur.execute("SELECT state FROM tasks WHERE task_id = %s", (task_id,)) row = cur.fetchone() return TaskState(row[0]) if row else None finally: self._put_conn(conn) def set_task_state(self, task_id: str, state: TaskState) -> None: """Update task state in Postgres.""" if not self._available: return conn = self._get_conn() if conn is None: return try: with self._lock: with conn.cursor() as cur: cur.execute( "UPDATE tasks SET state = %s, updated_at = NOW() WHERE task_id = %s", (state.value, task_id), ) conn.commit() finally: self._put_conn(conn) def append_trace(self, task_id: str, entry: dict[str, Any]) -> None: """Append trace entry to Postgres.""" if not self._available: return conn = self._get_conn() if conn is None: return try: with self._lock: with conn.cursor() as cur: cur.execute( "INSERT INTO traces (task_id, entry) VALUES (%s, %s)", (task_id, json.dumps(entry)), ) conn.commit() finally: self._put_conn(conn) def get_trace(self, task_id: str) -> list[dict[str, Any]]: """Load trace entries from Postgres.""" conn = self._get_conn() if conn is None: return [] try: with conn.cursor() as cur: cur.execute( "SELECT entry FROM traces WHERE task_id = %s ORDER BY id", (task_id,), ) return [ row[0] if isinstance(row[0], dict) else json.loads(row[0]) for row in cur.fetchall() ] finally: self._put_conn(conn) def list_tasks(self, state: TaskState | None = None, limit: int = 100) -> list[Task]: """List tasks from Postgres.""" conn = self._get_conn() if conn is None: return [] try: with conn.cursor() as cur: if state is not None: cur.execute("SELECT data FROM tasks WHERE state = %s ORDER BY updated_at DESC LIMIT %s", (state.value, limit)) else: cur.execute("SELECT data FROM tasks ORDER BY updated_at DESC LIMIT %s", (limit,)) return [ Task.model_validate(row[0] if isinstance(row[0], dict) else json.loads(row[0])) for row in cur.fetchall() ] finally: self._put_conn(conn) def delete_task(self, task_id: str) -> bool: """Delete task and its traces from Postgres.""" if not self._available: return False conn = self._get_conn() if conn is None: return False try: with self._lock: with conn.cursor() as cur: cur.execute("DELETE FROM tasks WHERE task_id = %s", (task_id,)) deleted = cur.rowcount > 0 conn.commit() return deleted finally: self._put_conn(conn) def count_tasks(self) -> int: """Count tasks in Postgres.""" conn = self._get_conn() if conn is None: return 0 try: with conn.cursor() as cur: cur.execute("SELECT COUNT(*) FROM tasks") row = cur.fetchone() return row[0] if row else 0 finally: self._put_conn(conn) def close(self) -> None: """Close the connection pool.""" if self._pool is not None: self._pool.closeall() self._available = False