"""Lightweight database migration runner for FusionAGI. Usage: python -m migrations.migrate up # Apply all pending migrations python -m migrations.migrate down # Rollback last migration python -m migrations.migrate status # Show migration status """ from __future__ import annotations import os import sqlite3 import sys from pathlib import Path VERSIONS_DIR = Path(__file__).parent / "versions" DEFAULT_DB = os.environ.get("FUSIONAGI_DB_PATH", "fusionagi.db") def get_connection(db_path: str = DEFAULT_DB) -> sqlite3.Connection: """Get database connection and ensure migration tracking table exists.""" conn = sqlite3.connect(db_path) conn.execute( "CREATE TABLE IF NOT EXISTS _migrations " "(id INTEGER PRIMARY KEY AUTOINCREMENT, version TEXT NOT NULL UNIQUE, " "applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)" ) conn.commit() return conn def get_applied(conn: sqlite3.Connection) -> set[str]: """Get set of applied migration versions.""" rows = conn.execute("SELECT version FROM _migrations").fetchall() return {r[0] for r in rows} def get_migration_files() -> list[tuple[str, Path]]: """Get sorted list of (version, path) tuples.""" files = sorted(VERSIONS_DIR.glob("*.sql")) return [(f.stem, f) for f in files] def parse_migration(path: Path) -> tuple[str, str]: """Parse a migration file into (up_sql, down_sql).""" text = path.read_text() parts = text.split("-- DOWN") up_sql = parts[0].replace("-- UP", "").strip() down_sql = parts[1].strip() if len(parts) > 1 else "" return up_sql, down_sql def migrate_up(db_path: str = DEFAULT_DB) -> int: """Apply all pending migrations. Returns count applied.""" conn = get_connection(db_path) applied = get_applied(conn) count = 0 for version, path in get_migration_files(): if version not in applied: up_sql, _ = parse_migration(path) conn.executescript(up_sql) conn.execute("INSERT INTO _migrations (version) VALUES (?)", (version,)) conn.commit() print(f"Applied: {version}") count += 1 if count == 0: print("No pending migrations.") return count def migrate_down(db_path: str = DEFAULT_DB) -> bool: """Rollback the last applied migration.""" conn = get_connection(db_path) applied = get_applied(conn) if not applied: print("No migrations to rollback.") return False migrations = get_migration_files() applied_migrations = [(v, p) for v, p in migrations if v in applied] if not applied_migrations: print("No migrations to rollback.") return False version, path = applied_migrations[-1] _, down_sql = parse_migration(path) if not down_sql: print(f"No DOWN section in {version}. Cannot rollback.") return False conn.executescript(down_sql) try: conn.execute("DELETE FROM _migrations WHERE version = ?", (version,)) except Exception: pass conn.commit() print(f"Rolled back: {version}") return True def show_status(db_path: str = DEFAULT_DB) -> None: """Show migration status.""" conn = get_connection(db_path) applied = get_applied(conn) for version, _ in get_migration_files(): status = "applied" if version in applied else "pending" print(f" {version}: {status}") if __name__ == "__main__": cmd = sys.argv[1] if len(sys.argv) > 1 else "status" db = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_DB if cmd == "up": migrate_up(db) elif cmd == "down": migrate_down(db) elif cmd == "status": show_status(db) else: print(f"Unknown command: {cmd}. Use: up, down, status")