"""SQLite-backed state backend for task persistence. Uses synchronous sqlite3 wrapped in a thread pool for async compatibility. For production Postgres, swap with asyncpg or SQLAlchemy async. """ from __future__ import annotations import json import sqlite3 import threading from typing import Any from fusionagi._logger import logger from fusionagi.core.persistence import StateBackend from fusionagi.schemas.task import Task, TaskState class SQLiteStateBackend(StateBackend): """SQLite-backed implementation of StateBackend. Stores tasks, task states, and traces in a local SQLite database. Thread-safe via a threading lock on write operations. """ def __init__(self, db_path: str = "fusionagi_state.db") -> None: self._db_path = db_path self._lock = threading.Lock() self._init_schema() def _get_conn(self) -> sqlite3.Connection: """Get a new connection (sqlite3 connections are not thread-safe).""" conn = sqlite3.connect(self._db_path) conn.row_factory = sqlite3.Row return conn def _init_schema(self) -> None: """Create tables if they don't exist.""" conn = self._get_conn() try: conn.executescript(""" CREATE TABLE IF NOT EXISTS tasks ( task_id TEXT PRIMARY KEY, data TEXT NOT NULL, state TEXT NOT NULL DEFAULT 'pending', created_at TEXT, updated_at TEXT ); CREATE TABLE IF NOT EXISTS traces ( id INTEGER PRIMARY KEY AUTOINCREMENT, task_id TEXT NOT NULL, entry TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (task_id) REFERENCES tasks(task_id) ); CREATE INDEX IF NOT EXISTS idx_traces_task ON traces(task_id); """) conn.commit() finally: conn.close() logger.info("SQLiteStateBackend initialized", extra={"db_path": self._db_path}) def get_task(self, task_id: str) -> Task | None: """Load task by id.""" conn = self._get_conn() try: row = conn.execute("SELECT data FROM tasks WHERE task_id = ?", (task_id,)).fetchone() if row is None: return None return Task.model_validate_json(row["data"]) finally: conn.close() def set_task(self, task: Task) -> None: """Save or update a task.""" with self._lock: conn = self._get_conn() try: data = task.model_dump_json() conn.execute( "INSERT OR REPLACE INTO tasks (task_id, data, state, created_at, updated_at) " "VALUES (?, ?, ?, ?, ?)", ( task.task_id, data, task.state.value, task.created_at.isoformat() if task.created_at else None, task.updated_at.isoformat() if task.updated_at else None, ), ) conn.commit() finally: conn.close() def get_task_state(self, task_id: str) -> TaskState | None: """Return current task state or None if task unknown.""" conn = self._get_conn() try: row = conn.execute("SELECT state FROM tasks WHERE task_id = ?", (task_id,)).fetchone() if row is None: return None return TaskState(row["state"]) finally: conn.close() def set_task_state(self, task_id: str, state: TaskState) -> None: """Update task state; creates no task if missing.""" with self._lock: conn = self._get_conn() try: task = self.get_task(task_id) if task is not None: conn.execute( "UPDATE tasks SET state = ?, updated_at = CURRENT_TIMESTAMP WHERE task_id = ?", (state.value, task_id), ) # Also update the JSON data blob updated = task.model_copy(update={"state": state}) conn.execute( "UPDATE tasks SET data = ? WHERE task_id = ?", (updated.model_dump_json(), task_id), ) conn.commit() finally: conn.close() def append_trace(self, task_id: str, entry: dict[str, Any]) -> None: """Append trace entry.""" with self._lock: conn = self._get_conn() try: conn.execute( "INSERT INTO traces (task_id, entry) VALUES (?, ?)", (task_id, json.dumps(entry)), ) conn.commit() finally: conn.close() def get_trace(self, task_id: str) -> list[dict[str, Any]]: """Load trace for task.""" conn = self._get_conn() try: rows = conn.execute( "SELECT entry FROM traces WHERE task_id = ? ORDER BY id", (task_id,), ).fetchall() return [json.loads(row["entry"]) for row in rows] finally: conn.close() def list_tasks(self, state: TaskState | None = None, limit: int = 100) -> list[Task]: """List tasks, optionally filtered by state.""" conn = self._get_conn() try: if state is not None: rows = conn.execute( "SELECT data FROM tasks WHERE state = ? ORDER BY rowid DESC LIMIT ?", (state.value, limit), ).fetchall() else: rows = conn.execute( "SELECT data FROM tasks ORDER BY rowid DESC LIMIT ?", (limit,), ).fetchall() return [Task.model_validate_json(row["data"]) for row in rows] finally: conn.close() def delete_task(self, task_id: str) -> bool: """Delete a task and its traces.""" with self._lock: conn = self._get_conn() try: conn.execute("DELETE FROM traces WHERE task_id = ?", (task_id,)) cursor = conn.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,)) conn.commit() return cursor.rowcount > 0 finally: conn.close() def count_tasks(self) -> int: """Return total task count.""" conn = self._get_conn() try: row = conn.execute("SELECT COUNT(*) as cnt FROM tasks").fetchone() return row["cnt"] if row else 0 finally: conn.close()