"""DB connector: query databases via configurable SQL drivers.""" from typing import Any from fusionagi._logger import logger from fusionagi.tools.connectors.base import BaseConnector class DBConnector(BaseConnector): """Database connector supporting SQLite (built-in) and Postgres (via psycopg). Provides read-only query access by default. Write operations require explicit ``allow_write=True`` at init. """ name = "db" def __init__( self, connection_string: str = ":memory:", driver: str = "sqlite", allow_write: bool = False, ) -> None: self._conn_str = connection_string self._driver = driver self._allow_write = allow_write self._conn: Any = None def _get_connection(self) -> Any: if self._conn is not None: return self._conn if self._driver == "sqlite": import sqlite3 self._conn = sqlite3.connect(self._conn_str) self._conn.row_factory = sqlite3.Row elif self._driver == "postgres": try: import psycopg self._conn = psycopg.connect(self._conn_str) except ImportError as e: raise ImportError("Install psycopg: pip install psycopg[binary]") from e else: raise ValueError(f"Unsupported driver: {self._driver}") return self._conn def invoke(self, action: str, params: dict[str, Any]) -> Any: if action == "query": return self._query(params.get("query", ""), params.get("params")) if action == "execute" and self._allow_write: return self._execute(params.get("query", ""), params.get("params")) if action == "tables": return self._list_tables() if action == "schema": return self._table_schema(params.get("table", "")) return {"error": f"Unknown or disallowed action: {action}"} def _query(self, sql: str, bind_params: Any = None) -> dict[str, Any]: if not sql.strip(): return {"rows": [], "error": "Empty query"} try: conn = self._get_connection() cur = conn.cursor() cur.execute(sql, bind_params or ()) rows = cur.fetchall() if self._driver == "sqlite": cols = [d[0] for d in (cur.description or [])] rows = [dict(zip(cols, r)) for r in rows] else: cols = [d.name for d in (cur.description or [])] rows = [dict(zip(cols, r)) for r in rows] cur.close() return {"rows": rows[:1000], "columns": cols, "count": len(rows), "error": None} except Exception as e: logger.warning("DBConnector query failed", extra={"error": str(e)}) return {"rows": [], "error": str(e)} def _execute(self, sql: str, bind_params: Any = None) -> dict[str, Any]: try: conn = self._get_connection() cur = conn.cursor() cur.execute(sql, bind_params or ()) conn.commit() affected = cur.rowcount cur.close() return {"affected_rows": affected, "error": None} except Exception as e: logger.warning("DBConnector execute failed", extra={"error": str(e)}) return {"affected_rows": 0, "error": str(e)} def _list_tables(self) -> dict[str, Any]: if self._driver == "sqlite": return self._query("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") return self._query("SELECT tablename AS name FROM pg_tables WHERE schemaname='public' ORDER BY tablename") def _table_schema(self, table: str) -> dict[str, Any]: if not table: return {"columns": [], "error": "Table name required"} if self._driver == "sqlite": return self._query(f"PRAGMA table_info('{table}')") return self._query( "SELECT column_name, data_type, is_nullable FROM information_schema.columns " "WHERE table_name = %s ORDER BY ordinal_position", (table,), ) def schema(self) -> dict[str, Any]: actions = ["query", "tables", "schema"] if self._allow_write: actions.append("execute") return { "name": self.name, "actions": actions, "parameters": {"query": "string", "params": "list", "table": "string"}, }