Source code for nextorm.migrations.core

"""Core logic for file-based migrations — snapshot diffing, file writing, and running.

Migration files are plain Python modules stored under a project directory
(e.g. ``migrations/``).  Each file is numbered sequentially:

    migrations/
        0001_initial.py
        0002_add_email_to_user.py
        ...

Each migration module **must** define an ``upgrade(db)`` function.
``downgrade(db)`` is optional.

Example migration file::

    def upgrade(db):
        db._execute_dml("ALTER TABLE user ADD COLUMN email TEXT NOT NULL DEFAULT ''", [])
        db._execute_dml("ALTER TABLE user ALTER COLUMN email DROP DEFAULT", [])


    def downgrade(db):
        db._execute_dml("ALTER TABLE user DROP COLUMN email", [])
"""

from __future__ import annotations

import importlib.util
import json
import re
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

from nextorm.fields import AttrValue

if TYPE_CHECKING:
    from nextorm.database import Database

# Public API of this module: the names exported by ``from nextorm.migrations.core import *``.
__all__ = ["MigrationRunner", "MigrationStatus", "makemigrations", "migrate", "showmigrations"]

# ---------------------------------------------------------------------------
# Public types
# ---------------------------------------------------------------------------


class MigrationStatus:
    """Describe whether one migration file has been applied.

    Attributes
    ----------
    name:
        File name of the migration (e.g. ``0001_initial.py``).
    version:
        File name without its ``.py`` suffix (e.g. ``0001_initial``).
    applied:
        ``True`` when the migration has a row in the tracking table.
    applied_at:
        ISO-8601 timestamp stored when the migration was applied, or an
        empty string for a pending migration.
    """

    __slots__ = ("name", "version", "applied", "applied_at")

    def __init__(self, name: str, version: str, applied: bool, applied_at: str) -> None:
        """Store the four status fields exactly as given."""
        self.name, self.version = name, version
        self.applied, self.applied_at = applied, applied_at

    def __repr__(self) -> str:
        """Render as a checklist line: ``[x] name (timestamp)`` or ``[ ] name``."""
        if not self.applied:
            return f"[ ] {self.name}"
        # The timestamp is only shown for applied migrations that have one.
        suffix = f" ({self.applied_at})" if self.applied_at else ""
        return f"[x] {self.name}{suffix}"
# ---------------------------------------------------------------------------
# Schema snapshot helpers
# ---------------------------------------------------------------------------

_TRACKING_TABLE = "_nextorm_migrations"


def _schema_to_dict(schema: dict[str, Any]) -> dict[str, Any]:
    """Serialize a ``{table_name: Table}`` schema to a JSON-safe dict.

    For each table this records column names and flags, foreign keys and
    indexes. Column types are not part of the snapshot.
    """
    result: dict[str, Any] = {}
    for table_name, table in schema.items():
        result[table_name] = {
            "columns": [
                {
                    "name": c.name,
                    "nullable": c.nullable,
                    "primary_key": c.primary_key,
                    "unique": c.unique,
                }
                for c in table.columns
            ],
            "foreign_keys": [
                {
                    "name": fk.name,
                    "column": fk.column,
                    "ref_table": fk.ref_table,
                    "ref_column": fk.ref_column,
                    "on_delete": fk.on_delete,
                }
                for fk in table.foreign_keys
            ],
            "indexes": [
                {"name": idx.name, "columns": idx.columns, "unique": idx.unique}
                for idx in table.indexes
            ],
        }
    return result


def _snapshot_path(directory: Path) -> Path:
    """Return the location of the schema snapshot inside *directory*."""
    return directory / ".schema_snapshot.json"


def _load_snapshot(directory: Path) -> dict[str, Any]:
    """Load the last saved schema snapshot, or ``{}`` if none exists yet."""
    snap = _snapshot_path(directory)
    if not snap.exists():
        return {}
    return cast("dict[str, Any]", json.loads(snap.read_text(encoding="utf-8")))


def _save_snapshot(directory: Path, schema: dict[str, Any]) -> None:
    """Write *schema* (as produced by :func:`_schema_to_dict`) to the snapshot file."""
    _snapshot_path(directory).write_text(json.dumps(schema, indent=2), encoding="utf-8")


# ---------------------------------------------------------------------------
# Migration file helpers
# ---------------------------------------------------------------------------

# At least four digits: new files are zero-padded to four, but numbering must
# keep working past 9999 (``{num:04d}`` then yields five digits).
_MIGRATION_RE = re.compile(r"^(\d{4,})_(.+)\.py$")


def _next_migration_number(directory: Path) -> int:
    """Return the next sequential migration number."""
    existing = [int(m.group(1)) for f in directory.iterdir() if (m := _MIGRATION_RE.match(f.name))]
    return max(existing, default=0) + 1


def _migration_files(directory: Path) -> list[Path]:
    """Return the migration .py files in *directory*, ordered by number then name."""
    numbered = [
        (int(m.group(1)), f.name, f)
        for f in directory.iterdir()
        if (m := _MIGRATION_RE.match(f.name))
    ]
    # Sort numerically so "10000_x.py" follows "9999_y.py"; for zero-padded
    # four-digit names this is identical to a plain lexical sort.
    numbered.sort(key=lambda item: (item[0], item[1]))
    return [f for _, _, f in numbered]


def _load_migration_module(path: Path) -> Any:
    """Import a migration file as a module named ``_migration_<stem>``.

    Raises
    ------
    ImportError
        If no import spec/loader can be created for *path*.
    Exception
        Anything raised while executing the migration module propagates
        unchanged.
    """
    spec = importlib.util.spec_from_file_location(f"_migration_{path.stem}", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load migration: {path}")
    mod = importlib.util.module_from_spec(spec)
    # Register before executing, mirroring importlib's recipe for importing a
    # source file directly.
    sys.modules[mod.__name__] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Don't leave a half-initialised module cached under this name.
        sys.modules.pop(mod.__name__, None)
        raise
    return mod


def _generate_migration_body(ops: list[str]) -> str:
    """Generate the Python source for a migration file from a list of SQL statements.

    Each statement becomes one ``db._execute_dml(<sql>, [])`` call in
    ``upgrade``; ``downgrade`` is emitted as a no-op.
    """
    lines: list[str] = [
        '"""Auto-generated migration."""',
        "",
        "",
        "def upgrade(db):",
    ]
    if ops:
        # repr() produces a valid Python string literal for any SQL text,
        # including carriage returns, NUL and other control characters that
        # hand-rolled escaping would leave raw (and thus uncompilable).
        lines.extend(f"    db._execute_dml({sql!r}, [])" for sql in ops)
    else:
        lines.append("    pass  # no schema changes")
    lines += [
        "",
        "",
        "def downgrade(db):",
        "    pass  # downgrade not auto-generated",
    ]
    return "\n".join(lines) + "\n"


# ---------------------------------------------------------------------------
# Tracking table
# ---------------------------------------------------------------------------


def _ensure_tracking_table(db: Database) -> None:
    """Create the _nextorm_migrations tracking table if it doesn't exist."""
    db._execute_dml(
        f"CREATE TABLE IF NOT EXISTS {_TRACKING_TABLE} "
        "(version TEXT NOT NULL PRIMARY KEY, applied_at TEXT NOT NULL DEFAULT '')",
        [],
    )


def _applied_versions(db: Database) -> dict[str, str]:
    """Return ``{version: applied_at}`` for every recorded migration."""
    rows = db._execute(f"SELECT version, applied_at FROM {_TRACKING_TABLE}", [])
    # A NULL/empty applied_at is normalised to "".
    return {row[0]: (row[1] or "") for row in rows}


def _record_version(db: Database, version: str) -> None:
    """Insert *version* into the tracking table with the current UTC time."""
    import datetime  # noqa: PLC0415

    now = datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds")
    db._execute_dml(
        f"INSERT INTO {_TRACKING_TABLE} (version, applied_at) VALUES (?, ?)",
        [version, now],
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _migration_slug(name: str) -> str:
    """Return a filesystem-safe filename suffix for the migration *name*."""
    # Lower-case, then map every character that is not a word character or a
    # hyphen to "_" one-for-one (so "Add Email" -> "add_email" as before).
    # This stops path separators such as "/" or ".." from escaping the
    # migrations directory.
    slug = re.sub(r"[^\w\-]", "_", name.lower())
    # An empty suffix would give "0001_.py", which the migration filename
    # pattern does not match, making the file invisible to migrate().
    return slug or "migration"


def _tables_from_snapshot(snapshot: dict[str, Any]) -> dict[str, Any]:
    """Rebuild minimal ``Table`` objects from a JSON snapshot so they can be diffed."""
    from nextorm.schema.core import Column, Index, Table  # noqa: PLC0415

    tables: dict[str, Any] = {}
    for tname, tdata in snapshot.items():
        table = Table(name=tname)
        for col_data in tdata.get("columns", []):
            table.columns.append(
                Column(
                    name=col_data["name"],
                    # sentinel: column types are not stored in the snapshot
                    py_type=cast("type[AttrValue]", object),
                    nullable=col_data.get("nullable", False),
                    primary_key=col_data.get("primary_key", False),
                    unique=col_data.get("unique", False),
                )
            )
        for idx_data in tdata.get("indexes", []):
            table.indexes.append(
                Index(
                    name=idx_data["name"],
                    columns=idx_data["columns"],
                    unique=idx_data.get("unique", False),
                )
            )
        # NOTE(review): the snapshot also stores "foreign_keys", but they are
        # not rebuilt here. If diff_schemas compares foreign keys they may be
        # re-emitted on every run — confirm against nextorm.schema.diff.
        tables[tname] = table
    return tables


def makemigrations(
    db: Database,
    name: str = "migration",
    *,
    directory: str | Path = "migrations",
) -> Path | None:
    """Diff the entity schema against the last snapshot and write a migration file.

    Parameters
    ----------
    db:
        A bound :class:`~nextorm.database.Database` with schema already built.
    name:
        Human-readable suffix for the migration filename. Characters other
        than letters, digits, ``_`` and ``-`` are replaced by ``_``; an empty
        name falls back to ``"migration"``.
    directory:
        Filesystem directory where migration files are stored. It is created
        if missing.

    Returns
    -------
    Path
        The path of the newly written migration file, or ``None`` if there
        are no schema changes.

    Raises
    ------
    RuntimeError
        If ``db`` has no built schema, or if there are changes to render but
        ``db`` has no DDL renderer.
    """
    # Validate before touching the filesystem so a misconfigured call does not
    # leave an empty migrations directory behind.
    if not db._schema:
        raise RuntimeError(
            "Database schema is empty. Call generate_mapping() before makemigrations()."
        )

    from nextorm.schema.diff import diff_schemas  # noqa: PLC0415

    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)

    snapshot_tables = _tables_from_snapshot(_load_snapshot(directory))
    ops = diff_schemas(snapshot_tables, db._schema)
    if not ops:
        return None

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if db._renderer is None:
        raise RuntimeError("Database has no DDL renderer; cannot render migration SQL.")
    sql_stmts = [db._renderer.render(op) for op in ops]

    num = _next_migration_number(directory)
    path = directory / f"{num:04d}_{_migration_slug(name)}.py"
    # Python reads source files as UTF-8 by default; don't rely on the
    # platform's locale encoding when writing one.
    path.write_text(_generate_migration_body(sql_stmts), encoding="utf-8")

    # Advance the snapshot only after the migration file has been written.
    _save_snapshot(directory, _schema_to_dict(db._schema))
    return path
def migrate(
    db: Database,
    *,
    directory: str | Path = "migrations",
    fake: bool = False,
) -> list[str]:
    """Apply all pending migration files in *directory*.

    Pending migrations are processed in the order given by the migration file
    listing. Each one is recorded in the tracking table right after its
    ``upgrade(db)`` returns. If ``upgrade`` raises, the exception propagates
    and that migration is not recorded.

    Parameters
    ----------
    db:
        A bound, connected :class:`~nextorm.database.Database`.
    directory:
        Directory containing migration files. A missing directory means
        there is nothing to apply.
    fake:
        If ``True``, record each migration as applied without executing the SQL.

    Returns
    -------
    list[str]
        The names of migration files that were applied (or faked).

    Raises
    ------
    AttributeError
        If a pending migration (with *fake* false) does not define a callable
        ``upgrade``.
    """
    directory = Path(directory)
    if not directory.exists():
        return []

    _ensure_tracking_table(db)
    applied = _applied_versions(db)

    applied_names: list[str] = []
    for path in _migration_files(directory):
        version = path.stem
        if version in applied:
            continue
        if not fake:
            upgrade = getattr(_load_migration_module(path), "upgrade", None)
            # Every migration must define upgrade(db). Silently recording a
            # file without one (e.g. a misspelled function name) would mark it
            # as applied forever without running anything.
            if not callable(upgrade):
                raise AttributeError(
                    f"Migration {path.name} does not define an upgrade(db) function"
                )
            upgrade(db)
        _record_version(db, version)
        applied_names.append(path.name)
    return applied_names
def showmigrations(
    db: Database,
    *,
    directory: str | Path = "migrations",
) -> list[MigrationStatus]:
    """Report, for every migration file in *directory*, whether it has been applied.

    Parameters
    ----------
    db:
        A bound, connected :class:`~nextorm.database.Database`.
    directory:
        Directory containing migration files.

    Returns
    -------
    list[MigrationStatus]
        One :class:`MigrationStatus` per migration file, in ascending version
        order. An empty list is returned when *directory* does not exist. If
        the tracking table cannot be prepared or read, every migration is
        reported as pending.
    """
    migrations_dir = Path(directory)
    if not migrations_dir.exists():
        return []

    # NOTE: this does issue CREATE TABLE IF NOT EXISTS for the tracking table.
    # Any error while preparing or reading it is deliberately treated as
    # "nothing applied yet" so the report is best-effort.
    try:
        _ensure_tracking_table(db)
        applied = _applied_versions(db)
    except Exception:  # noqa: BLE001
        applied = {}

    return [
        MigrationStatus(
            name=path.name,
            version=path.stem,
            applied=path.stem in applied,
            applied_at=applied.get(path.stem, ""),
        )
        for path in _migration_files(migrations_dir)
    ]
# --------------------------------------------------------------------------- # MigrationRunner — object-oriented interface # ---------------------------------------------------------------------------
class MigrationRunner:
    """Bind a database and a migrations directory behind method calls.

    Each method forwards to the module-level function of the same name,
    passing the stored database and directory.

    Example::

        runner = MigrationRunner(db, directory="migrations/")
        runner.makemigrations(name="add_tags")
        runner.migrate()
    """

    def __init__(
        self,
        db: Database,
        *,
        directory: str | Path = "migrations",
    ) -> None:
        """Remember *db* and normalise *directory* to a :class:`~pathlib.Path`."""
        self.directory = Path(directory)
        self._db = db

    def makemigrations(self, name: str = "migration") -> Path | None:
        """Write a new migration file if the schema changed; return its path or ``None``."""
        target_dir = self.directory
        return makemigrations(self._db, name=name, directory=target_dir)

    def migrate(self, *, fake: bool = False) -> list[str]:
        """Apply (or, with *fake*, just record) every pending migration; return their file names."""
        target_dir = self.directory
        return migrate(self._db, fake=fake, directory=target_dir)

    def showmigrations(self) -> list[MigrationStatus]:
        """List the applied/pending status of each migration file in the directory."""
        target_dir = self.directory
        return showmigrations(self._db, directory=target_dir)