# Source code for nextorm.database

"""Database — binds a provider and owns a set of entities.

Usage::

    from nextorm import Database, Entity, Req
    from nextorm.session import db_session


    class User(Entity):
        name: Req[str]
        email: Req[str]


    db = Database(entities=[User])
    db.bind("sqlite", ":memory:")
    db.generate_mapping(create_tables=True)

    with db_session:
        User(name="alice", email="alice@example.com")  # auto-INSERTed at session end
    # to get the PK immediately, call db.flush() inside the session
"""

from __future__ import annotations

import contextlib
import time
import typing
from typing import TYPE_CHECKING, Any, TypeVar, cast

import nextorm.providers  # noqa: F401  # pyright: ignore[reportUnusedImport]
from nextorm.debug import QueryStat, _global_stats_lock, _print_sql, global_stats
from nextorm.entity import Entity, EntityMeta, _entity_registry
from nextorm.exceptions import MappingError, OptimisticCheckError
from nextorm.fields import RelationKind, _generate_ulid, _generate_uuid7, _serialize_value
from nextorm.pool import ConnectionPool
from nextorm.providers.base import (
    _PROVIDER_REGISTRY,
    SyncConnection,
    SyncProvider,
    get_sync_provider,
)
from nextorm.schema.builder import build_schema
from nextorm.schema.core import Table
from nextorm.schema.ddl import DDLRenderer, MariaDBRenderer, PostgresRenderer, SQLiteRenderer
from nextorm.schema.diff import diff_schemas
from nextorm.sql.builder import PARAM_STYLE_BUILDERS, SQLBuilder, SQLiteBuilder
from nextorm.sql.nodes import BinOp, ColumnRef, Delete, Insert, Literal, Param, Select, Update

if TYPE_CHECKING:
    from nextorm.query import QuerySet

__all__ = ["Database"]

T = TypeVar("T", bound=Entity)

# ---------------------------------------------------------------------------
# Global database registry
# ---------------------------------------------------------------------------
# Each Database / AsyncDatabase instance is appended here when it runs
# generate_mapping(). Entity.__init__ searches this list for the database that
# owns its class, so applications need not keep the db in a module global.
# NOTE(review): typed list[Any], presumably so AsyncDatabase need not be
# imported here (circular import) — confirm.
_database_registry: list[Any] = []

#: Provider name -> DDL renderer class; bind() instantiates the matching one.
_DDL_RENDERERS: dict[str, type[DDLRenderer]] = dict(
    sqlite=SQLiteRenderer,
    postgres=PostgresRenderer,
    mariadb=MariaDBRenderer,
)


# ---------------------------------------------------------------------------
# Composite-PK helpers (shared by Database and the async counterpart)
# ---------------------------------------------------------------------------


def _pk_col_for(entity_cls: type[Entity], field_name: str) -> str:
    """Return the SQL column name that stores the PK part *field_name*.

    * Scalar field: ``FieldSpec.column`` when set, otherwise *field_name*.
    * Relation field: the relation spec's ``column`` override when set,
      otherwise ``<field_name>_id``. This is the same rule the INSERT path
      uses when it writes the FK value, so WHERE clauses built from this
      helper address the column the row was actually stored in.
    """
    fields = entity_cls._fields_
    if field_name in fields:
        return fields[field_name].spec.column or field_name
    # Relation field: honour an explicit column override on the relation
    # spec (previously ignored, producing the wrong column name).
    for rel in entity_cls._relations_.values():
        if rel.name == field_name:
            return rel.spec.column or f"{field_name}_id"
    return f"{field_name}_id"


def _pk_val_for(entity: Entity, field_name: str) -> Any:
    """Read the current value of PK part *field_name* from *entity*.

    Names declared in the class's ``_fields_`` go through normal attribute
    access. Any other name is treated as a relation whose foreign-key id is
    cached in the instance ``__dict__`` under ``_<field_name>_id``; ``None``
    is returned when that key is absent.
    """
    if field_name not in type(entity)._fields_:
        return entity.__dict__.get(f"_{field_name}_id")
    return getattr(entity, field_name)


def _get_pk_val(entity: Entity) -> Any:
    """Return *entity*'s primary key: a scalar for one PK field, else a tuple.

    ``None`` is returned when the class declares no PK fields, or when any
    PK part is still ``None`` (the row has not been written yet).
    """
    pk_fields = type(entity)._pk_fields_
    if not pk_fields:
        return None
    parts = [_pk_val_for(entity, name) for name in pk_fields]
    for part in parts:
        if part is None:
            return None
    if len(parts) == 1:
        return parts[0]
    return tuple(parts)


def _build_pk_where(entity_cls: type[Entity], pk_val: Any) -> BinOp:
    """Build the WHERE expression that matches exactly one row by primary key.

    Parameters
    ----------
    entity_cls:
        Entity class whose ``_pk_fields_`` define the key columns.
    pk_val:
        The scalar key (single PK) or a sequence of values (composite PK) in
        ``entity_cls._pk_fields_`` order.

    Returns
    -------
    BinOp
        ``col = ?`` for a single PK, or a left-nested
        ``(col1 = ? AND col2 = ?) AND ...`` chain for a composite PK.

    Raises
    ------
    ValueError
        If the class has no PK fields, or a composite *pk_val* does not supply
        exactly one value per PK field. A short tuple would otherwise silently
        drop conditions and widen the WHERE clause, so e.g. a DELETE could
        hit more rows than intended.
    """
    pk_fields = entity_cls._pk_fields_
    if not pk_fields:
        raise ValueError(f"Entity {entity_cls.__name__!r} has no primary-key field.")
    if len(pk_fields) == 1:
        col = _pk_col_for(entity_cls, pk_fields[0])
        return BinOp(ColumnRef(col), "=", Param(value=pk_val))
    vals = tuple(pk_val)
    if len(vals) != len(pk_fields):
        raise ValueError(
            f"Composite primary key of {entity_cls.__name__!r} needs "
            f"{len(pk_fields)} values, got {len(vals)}."
        )
    conditions = [
        BinOp(ColumnRef(_pk_col_for(entity_cls, fname)), "=", Param(value=fval))
        for fname, fval in zip(pk_fields, vals, strict=True)
    ]
    # AND-chain the per-column conditions (same left-nested shape as before).
    node = conditions[0]
    for cond in conditions[1:]:
        node = BinOp(node, "AND", cond)
    return node


class Database:
    """Bind a provider, manage its connection(s), and hold the entity schema.

    With the default single-connection setup (``pool_max=1`` in :meth:`bind`)
    one connection stays open for the whole binding, which keeps in-memory
    SQLite databases (``":memory:"``) usable across operations. With
    ``pool_max > 1`` or ``pool_min > 0`` each operation instead checks a
    connection out of a :class:`~nextorm.pool.ConnectionPool`. Release
    connections with :meth:`close`, or use the instance as a context manager.

    Parameters
    ----------
    entities:
        Entity classes in this database's scope. If omitted, every
        :class:`~nextorm.entity.Entity` subclass registered in
        :data:`~nextorm.entity._entity_registry` is used.
    """

    #: Lets ``FieldDescriptor.__get__`` tell this class apart from
    #: :class:`~nextorm.async_database.AsyncDatabase` (without a circular
    #: import) and decide whether a lazy-field load may run synchronously.
    _is_async: bool = False
def __init__(self, entities: list[type[Entity]] | None = None) -> None:
    """Create an unbound database.

    Parameters
    ----------
    entities:
        Entity classes this database manages. ``None`` defers to the global
        entity registry. A given list is copied, so later changes to the
        caller's list do not affect this database.
    """
    self._entities: list[type[Entity]] | None = (
        list(entities) if entities is not None else None
    )
    # Provider binding: filled in by bind(), reset by unbind().
    self._provider: str | None = None
    self._connect_args: tuple[Any, ...] = ()
    self._connect_kwargs: dict[str, Any] = {}
    self._renderer: DDLRenderer | None = None
    self._sync_provider_instance: SyncProvider | None = None
    self._builder: SQLBuilder | None = None
    # Either one lazily opened persistent connection, or a pool.
    self._connection: SyncConnection | None = None
    self._pool: ConnectionPool | None = None
    # Schema state: filled in by generate_mapping().
    self._schema: dict[str, Table] = {}
    # Declared here rather than created lazily by
    # generate_mapping(create_tables=True); get_ddl() returns a copy.
    self._ddl_statements: list[str] = []
    self._bound: bool = False
    # Debugging / statistics.
    self._last_sql: str = ""
    self._local_stats: dict[str, QueryStat] = {}
# ------------------------------------------------------------------
# Context manager
# ------------------------------------------------------------------
def __enter__(self) -> Database:
    """Return this database so it can be used in a ``with`` statement."""
    return self

def __exit__(self, *_: object) -> None:
    """Close the connection(s) when the ``with`` block exits."""
    self.close()

def __del__(self) -> None:
    """Best-effort cleanup when the instance is garbage collected.

    ``__del__`` can run on a partially initialised instance (``__init__``
    raised before all attributes existed) or during interpreter shutdown.
    Python cannot propagate an exception out of ``__del__`` anyway (it only
    prints "Exception ignored in ..."), so errors from :meth:`close` are
    suppressed here.
    """
    with contextlib.suppress(Exception):
        self.close()

# ------------------------------------------------------------------
# Entity registration
# ------------------------------------------------------------------
def register(self, *entity_classes: type[Entity]) -> None:
    """Add one or more entity classes to this database's scope.

    Entities register themselves globally at import time. Call this method
    to opt specific classes into a database that was constructed with an
    explicit ``entities=[...]`` list, or to add entities after construction.
    Registering an entity that is already in scope is a no-op.

    On a database constructed *without* ``entities`` (which falls back to
    the global registry), registering classes switches it to an explicit
    list containing only the classes registered this way.

    Raises
    ------
    TypeError
        If any argument is not an :class:`~nextorm.entity.Entity` subclass.
        All arguments are validated before any is added, so a failing call
        leaves this database's scope unchanged.

    Example::

        db = Database(entities=[User])
        db.register(Post, Comment)  # add more after the fact
        db.bind("sqlite", ":memory:")
        db.generate_mapping(create_tables=True)
    """
    # Validate everything first so a bad argument cannot leave a partial
    # registration behind.
    for cls in entity_classes:
        if not isinstance(cls, EntityMeta):  # pyright: ignore[reportUnnecessaryIsInstance]
            raise TypeError(f"{cls!r} is not an Entity subclass")
    if not entity_classes:
        # Nothing to add: keep a None scope (global-registry fallback) as is.
        return
    if self._entities is None:
        self._entities = []
    for cls in entity_classes:
        if cls not in self._entities:
            self._entities.append(cls)
@property
def entities(self) -> dict[str, type[Entity]]:
    """Map each in-scope entity's class name to the class itself."""
    by_name: dict[str, type[Entity]] = {}
    for klass in self._effective_entities():
        by_name[klass.__name__] = cast("type[Entity]", klass)
    return by_name

def _effective_entities(self) -> list[EntityMeta]:
    """Return the entities in scope.

    This is a copy of the explicit list when one was given. Otherwise it is
    every globally registered entity, sorted by class name.
    """
    explicit = self._entities
    if explicit is None:
        return sorted(_entity_registry, key=lambda klass: klass.__name__)
    return list(explicit)

# ------------------------------------------------------------------
# Provider binding
# ------------------------------------------------------------------
def bind(
    self,
    provider: str,
    *args: Any,
    pool_min: int = 0,
    pool_max: int = 1,
    pool_timeout: float = 30.0,
    **kwargs: Any,
) -> None:
    """Bind the database to a named provider.

    Calling ``bind`` again re-binds. The connection or pool opened for the
    previous binding is closed first, so later operations use the new
    provider and connection arguments.

    Parameters
    ----------
    provider:
        Registered provider name (``"sqlite"``, ``"postgres"``, ``"mariadb"``).
    \\*args / \\*\\*kwargs:
        Connection arguments forwarded to the provider (e.g. the database
        path for SQLite, the DSN for PostgreSQL).
    pool_min:
        Minimum (pre-created) pool connections. ``0`` means no connections
        are opened until the first query.
    pool_max:
        Maximum pool size. ``1`` (the default) keeps the legacy
        single-persistent-connection behaviour. Values ``> 1`` enable true
        pooling: each operation checks out a connection from the pool and
        returns it afterwards.
    pool_timeout:
        Seconds to wait for a free connection before
        :exc:`~nextorm.pool.PoolTimeoutError` is raised.

    Raises
    ------
    ValueError
        If *provider* is not registered. Any existing binding is left
        untouched in that case.
    """
    if provider not in _PROVIDER_REGISTRY:
        raise ValueError(
            f"Unknown provider {provider!r}. Available: {sorted(_PROVIDER_REGISTRY)}"
        )
    # Re-binding: release what the previous binding opened. Otherwise
    # _ensure_connection() keeps returning the OLD persistent connection,
    # and an old pool would survive a switch to pool_max=1.
    if self._pool is not None:
        self._pool.close_all()
        self._pool = None
    if self._connection is not None:
        self._connection.close()
        self._connection = None
    self._bound = False

    self._provider = provider
    self._connect_args = args
    self._connect_kwargs = kwargs
    # Always reset the renderer, so a provider without DDL support does not
    # inherit the previous provider's renderer.
    renderer_cls = _DDL_RENDERERS.get(provider)
    self._renderer = renderer_cls() if renderer_cls is not None else None
    provider_inst = get_sync_provider(provider)()
    self._sync_provider_instance = provider_inst
    builder_cls = PARAM_STYLE_BUILDERS.get(provider_inst.param_style, SQLiteBuilder)
    self._builder = builder_cls()
    if pool_max > 1 or pool_min > 0:

        def _factory() -> SyncConnection:
            return provider_inst.connect(*args, **kwargs)

        self._pool = ConnectionPool(
            _factory, min_size=pool_min, max_size=pool_max, timeout=pool_timeout
        )
    # Mark bound only after everything, including a pool that may pre-open
    # pool_min connections, was set up successfully.
    self._bound = True
def unbind(self) -> None:
    """Close the connection(s) and forget the provider binding and schema.

    After this call :attr:`is_bound` is ``False``, :attr:`schema` is empty
    and :meth:`get_ddl` returns ``[]``. Call :meth:`bind` and
    :meth:`generate_mapping` again before further use. Query statistics and
    :attr:`last_sql` are kept.
    """
    self.close()  # close() also removes this instance from _database_registry
    self._provider = None
    self._connect_args = ()
    self._connect_kwargs = {}
    self._renderer = None
    self._sync_provider_instance = None
    self._builder = None
    self._pool = None
    self._schema = {}
    # The DDL belonged to the schema discarded above; don't report it.
    self._ddl_statements = []
    self._bound = False
@property
def is_bound(self) -> bool:
    """``True`` while a provider is bound via :meth:`bind`."""
    return self._bound

# ------------------------------------------------------------------
# Connection management
# ------------------------------------------------------------------
def _ensure_connection(self) -> SyncConnection:
    """Hand out a connection to run a statement on.

    In pooled mode this checks a connection out of the pool, and the caller
    must give it back via :meth:`_release_connection`. Otherwise it returns
    the single persistent connection, opening it on first use.

    Raises :exc:`RuntimeError` when the database is not bound.
    """
    if not self._bound:
        raise RuntimeError("Database is not bound to a provider. Call bind() first.")
    pool = self._pool
    if pool is not None:
        return cast("SyncConnection", pool.acquire())
    provider = self._sync_provider_instance
    assert provider is not None
    if self._connection is None:
        self._connection = provider.connect(*self._connect_args, **self._connect_kwargs)
    return self._connection

def _release_connection(self, conn: SyncConnection) -> None:
    """Give a pooled connection back; a no-op for the persistent connection."""
    pool = self._pool
    if pool is None:
        return
    pool.release(conn)
def close(self) -> None:
    """Close the persistent connection (if open) or all pool connections.

    Also removes this instance from the module-level database registry, so
    entity constructors no longer auto-locate it. Cleanup runs in ``finally``
    blocks: even if closing the connection raises, the connection reference
    is dropped and the instance is deregistered before the error propagates.
    """
    try:
        if self._pool is not None:
            self._pool.close_all()
        elif self._connection is not None:
            try:
                self._connection.close()
            finally:
                # Never keep a reference to a connection we tried to close.
                self._connection = None
    finally:
        with contextlib.suppress(ValueError):
            _database_registry.remove(self)
def commit(self) -> None:
    """Write pending session changes, then COMMIT the transaction.

    This is :meth:`flush` followed by the raw commit, so callers never need
    a separate flush before committing (PonyORM-style ``commit()``).
    """
    # Order matters: everything pending must be written before COMMIT.
    self.flush()
    self._commit_transaction()
def _commit_transaction(self) -> None:
    """Issue a raw COMMIT on a connection, without flushing first.

    The staged-commit logic in the module-level :func:`~nextorm.commit` and
    the session-exit path use this to separate the phases: flush every
    database, then commit the primary, then the secondaries.

    NOTE(review): in pooled mode this checks out *some* pooled connection,
    which need not be the one that ran the session's DML. Confirm that
    transactional semantics hold with ``pool_max > 1``.
    """
    connection = self._ensure_connection()
    try:
        connection.commit()
    finally:
        self._release_connection(connection)

def _rollback_transaction(self) -> None:
    """Issue a raw ROLLBACK, leaving the session cache untouched.

    :meth:`rollback` additionally clears the session identity map. This
    low-level variant is for partial-commit cleanup, where the cache is
    cleared separately once every database has been handled.
    """
    connection = self._ensure_connection()
    try:
        connection.rollback()
    finally:
        self._release_connection(connection)
def rollback(self) -> None:
    """Roll back the current transaction and clear the session cache.

    Uncommitted work on the active connection is discarded, and the identity
    map of the current :func:`~nextorm.session.db_session` is wiped. Outside
    a session only the ROLLBACK happens.
    """
    self._rollback_transaction()
    # Wipe the identity map as well, so objects carrying rolled-back state
    # are not handed out again.
    from nextorm.session import _get_session_stack  # noqa: PLC0415

    session_cache = _get_session_stack().current
    if session_cache is None:
        return
    session_cache.clear()
def flush(self) -> None:
    """Write this database's pending session changes: INSERTs, then UPDATEs.

    Within the current :func:`~nextorm.session.db_session`, :meth:`save` is
    called on every object scheduled for INSERT, then on every dirty object.
    Only objects bound to *this* database, or not yet bound to any database,
    are saved; the latter lets :meth:`save` report unmapped entities.

    Flushing does **not** commit. DML issued inside a session is not
    auto-committed; it is committed by :meth:`commit` or when the session
    ends. Outside a session this method is a no-op.
    """
    from nextorm.session import _get_session_stack  # noqa: PLC0415

    cache = _get_session_stack().current
    if cache is None:
        return

    def _owned(entity: Entity) -> bool:
        # Save if this is the right DB, or if the entity has no DB yet
        # (unmapped entity; _require_mapped will raise the appropriate error).
        entity_db = vars(entity).get("_db_")
        return entity_db is None or entity_db is self

    # Snapshot each queue before iterating, because save() mutates them. The
    # dirty snapshot is taken only after the INSERT pass, since entities just
    # inserted are normally un-marked dirty by _post_save.
    for entity in list(cache.objects_to_save):
        if _owned(entity):
            self.save(entity)
    for entity in list(cache.dirty_objects):
        if _owned(entity):
            self.save(entity)
# ------------------------------------------------------------------ # Schema mapping # ------------------------------------------------------------------
def generate_mapping(
    self, *, create_tables: bool = False, validate_relations: bool = False
) -> None:
    """Build the internal schema and optionally create database tables.

    Must be called after :meth:`bind`. Registers this database instance so
    that :func:`~nextorm.entity._find_db_for_entity` can locate it.

    Parameters
    ----------
    create_tables:
        When ``True``, execute ``CREATE TABLE`` DDL for every entity in the
        schema. The rendered statements are kept for :meth:`get_ddl`. The
        connection is kept open so that in-memory SQLite databases
        (``":memory:"``) remain accessible after the call.
    validate_relations:
        When ``True``, every declared ``Set[T]`` relation is checked to have
        a matching ``Single[T]`` back-reference on the target entity.
        Ambiguous cases (multiple candidates with no ``reverse=``) raise
        :exc:`~nextorm.exceptions.MappingError`.

    Raises
    ------
    RuntimeError
        If the database is not bound, or ``create_tables=True`` is requested
        for a provider that has no DDL renderer.

    Example::

        db = Database(entities=[User, Post])
        db.bind("sqlite", ":memory:")
        db.generate_mapping(create_tables=True)
    """
    if not self._bound:
        raise RuntimeError("Database is not bound to a provider. Call bind() first.")
    # Checked before the schema is built or the instance registered, and
    # raised explicitly: an ``assert`` would vanish under ``python -O``.
    if create_tables and self._renderer is None:
        raise RuntimeError(
            f"Provider {self._provider!r} has no DDL renderer; cannot create tables."
        )
    if validate_relations:
        self._validate_relations()
    self._schema = build_schema(cast("list[type[Entity]]", self._effective_entities()))
    # Register so Entity.__init__ can auto-locate this database.
    if self not in _database_registry:
        _database_registry.append(self)
    if create_tables:
        renderer = self._renderer
        assert renderer is not None  # guaranteed by the check above
        self._ddl_statements = [renderer.create_table(table) for table in self._schema.values()]
        provider = self._sync_provider_instance
        assert provider is not None
        conn = self._ensure_connection()
        try:
            provider.execute_ddl(conn, self._ddl_statements)
        finally:
            self._release_connection(conn)
@property
def schema(self) -> dict[str, Table]:
    """Shallow copy of the table-name to :class:`Table` mapping.

    Empty until :meth:`generate_mapping` has been called.
    """
    return {**self._schema}
def migrate(self) -> list[str]:
    """Apply pending schema changes to the live database.

    Introspects the current database schema and diffs it against the
    entity-derived target schema (built by :meth:`generate_mapping`). It
    then executes each pending DDL operation (``CREATE TABLE``,
    ``ALTER TABLE … ADD/DROP COLUMN``, ``CREATE/DROP INDEX``, etc.).

    Returns the list of SQL statements that were executed. An empty list
    means the database is already up to date.

    Raises
    ------
    RuntimeError
        If the database is not bound, :meth:`generate_mapping` has not been
        called yet, or the bound provider has no DDL renderer.
    """
    if not self._bound:
        raise RuntimeError("Database is not bound to a provider. Call bind() first.")
    if not self._schema:
        raise RuntimeError("Schema is empty. Call generate_mapping() before migrate().")
    # Raised explicitly rather than asserted (asserts vanish under -O).
    renderer = self._renderer
    if renderer is None:
        raise RuntimeError(
            f"Provider {self._provider!r} has no DDL renderer; cannot migrate."
        )
    provider = self._sync_provider_instance
    assert provider is not None
    conn = self._ensure_connection()
    try:
        current = provider.introspect(conn)
        stmts = [renderer.render(op) for op in diff_schemas(current, self._schema)]
        if stmts:
            provider.execute_ddl(conn, stmts)
    finally:
        self._release_connection(conn)
    return stmts
def get_ddl(self) -> list[str]:
    """Return a copy of the CREATE TABLE statements from the last
    :meth:`generate_mapping` call that created tables (``[]`` if none)."""
    return [*getattr(self, "_ddl_statements", ())]
# ------------------------------------------------------------------ # Query API # ------------------------------------------------------------------
def select(self, entity_class: type[T]) -> QuerySet[T]:
    """Return a :class:`~nextorm.query.QuerySet` over *entity_class*.

    For a single-table-inheritance child entity the queryset is pre-filtered
    on the parent's discriminator column, so only that subclass's rows are
    returned.

    Raises
    ------
    RuntimeError
        If :meth:`generate_mapping` has not been called, or *entity_class* is
        not part of the mapped schema.
    """
    if not self._schema:
        raise RuntimeError("Schema is empty. Call generate_mapping() before select().")
    table = self._schema.get(entity_class._table_name_)
    if table is None:
        # bind() takes no entities; the scope comes from the constructor,
        # register(), or the global registry at generate_mapping() time.
        raise RuntimeError(
            f"Entity {entity_class.__name__!r} is not in the mapped schema. "
            "Ensure it was passed to Database(entities=...) or register() "
            "(or, without an explicit entity list, imported) before "
            "generate_mapping() was called."
        )
    assert self._builder is not None
    # Local import, presumably to avoid a circular import with nextorm.query.
    from nextorm.query import QuerySet as _QuerySet  # noqa: PLC0415

    qs: QuerySet[T] = _QuerySet(entity_class, table, self, self._builder)
    # STI: automatically filter by discriminator value for child entities.
    # BinOp / ColumnRef / Param come from the module-level nextorm.sql.nodes import.
    disc_val = entity_class._discriminator_val_
    if disc_val is not None and entity_class._sti_parent_ is not None:
        disc_col = entity_class._sti_parent_._discriminator_col_
        if disc_col is not None:  # pragma: no branch
            qs = qs.filter(BinOp(ColumnRef(disc_col), "=", Param(value=disc_val)))
    return qs
# ------------------------------------------------------------------ # Persistence # ------------------------------------------------------------------
def save(self, entity: Entity) -> None:
    """Persist *entity*: INSERT when it is new, otherwise UPDATE.

    The entity counts as new when its primary key is unset (any PK part is
    ``None``). For composite-PK classes it also counts as new while it has
    no ``_dbvals_`` snapshot, because composite keys are user-supplied
    before the first write. The matching ``before_*`` / ``after_*`` hooks
    wrap the write.

    On failure the entity is removed from the active session's INSERT and
    dirty queues (so :meth:`flush` will not retry it) and the exception is
    re-raised. On success it is recorded in the session identity map.

    NOTE(review): a single-column PK assigned before the first save is
    treated as an existing row (UPDATE); :meth:`insert` exists for that case.
    """
    entity_cls = type(entity)
    self._require_mapped(entity_cls)
    table = self._schema[entity_cls._table_name_]
    pk_val = _get_pk_val(entity)
    if pk_val is None:
        is_new = True
    else:
        # Entity.__init__ may set _db_ inside a db_session before any row
        # exists, so the _dbvals_ snapshot (not _db_) marks a stored
        # composite-PK row.
        is_new = entity_cls._pk_field_ is None and "_dbvals_" not in vars(entity)

    try:
        if is_new:
            entity.before_insert()
            self._do_insert(entity, entity_cls, table)
            entity.after_insert()
        else:
            entity.before_update()
            self._do_update(entity, entity_cls, table, pk_val)
            entity.after_update()
    except Exception:
        # Drop the failed entity from session tracking so flush() won't retry.
        from nextorm.session import _get_session_stack  # noqa: PLC0415

        session_cache = _get_session_stack().current
        if session_cache is not None:
            session_cache.unmark_dirty(entity)
            session_cache.unschedule_save(entity)
        raise

    self._post_save(entity, _get_pk_val(entity) if pk_val is None else pk_val)
def insert(self, entity: Entity) -> None:
    """Always INSERT *entity*, even when its primary key is already set.

    Unlike :meth:`save`, which uses the PK value to choose between INSERT and
    UPDATE, this method always performs an INSERT. Use it for:

    * Entities with user-assigned (non-auto) primary keys.
    * Data loading / migration where you know the row does not yet exist.
    * Inserting a copy of an entity with a freshly assigned identity.

    When the PK is auto-generated (``FieldSpec(primary_key=True, auto=True)``):

    * PK is ``None``: the database generates one, which is written back.
    * PK is a concrete value: that value is passed in the INSERT, overriding
      the auto-increment counter (useful for data migration).

    Lifecycle hooks ``before_insert`` / ``after_insert`` are called. As with
    :meth:`save`, a failed insert removes the entity from the active
    session's INSERT and dirty queues before the exception propagates.
    """
    entity_cls = type(entity)
    self._require_mapped(entity_cls)
    table = self._schema[entity_cls._table_name_]
    try:
        entity.before_insert()
        self._do_insert(entity, entity_cls, table, include_auto_pk=True)
        entity.after_insert()
    except Exception:
        # Mirror save(): a failed write must not stay queued, or flush() at
        # session end would retry it (possibly as an UPDATE).
        from nextorm.session import _get_session_stack  # noqa: PLC0415

        cache = _get_session_stack().current
        if cache is not None:
            cache.unmark_dirty(entity)
            cache.unschedule_save(entity)
        raise
    self._post_save(entity, _get_pk_val(entity))
# ------------------------------------------------------------------
# Session / identity-map bookkeeping (shared by save + insert)
# ------------------------------------------------------------------
def _post_save(self, entity: Entity, pk_val: Any) -> None:
    """Record a successful write of *entity* in the active session.

    The entity is stored in the identity map under *pk_val* and taken off
    the dirty set and the INSERT queue. Nothing happens outside a session or
    when *pk_val* is ``None``.
    """
    from nextorm.session import _get_session_stack  # noqa: PLC0415

    session_cache = _get_session_stack().current
    if session_cache is None or pk_val is None:
        return
    session_cache.put(entity, pk_val)
    session_cache.unmark_dirty(entity)
    session_cache.unschedule_save(entity)  # written now; leave the INSERT queue
def delete_instance(self, entity: Entity) -> None:
    """Delete *entity*'s row and clear its primary key.

    ``before_delete`` / ``after_delete`` wrap the DELETE, and every PK part
    is reset to ``None``. The entity is then removed from the active
    session: from its identity map, and from the INSERT and dirty queues.
    The queue removal stops a later :meth:`flush` from re-INSERTing the
    now key-less object.

    Raises
    ------
    RuntimeError
        If the entity class has no primary-key field.
    ValueError
        If the entity has not been saved (its primary key is ``None``).
    """
    entity_cls = type(entity)
    self._require_mapped(entity_cls)
    pk_val = _get_pk_val(entity)
    pk_fields = entity_cls._pk_fields_
    if not pk_fields:
        raise RuntimeError(f"Entity {entity_cls.__name__!r} has no primary-key field.")
    if pk_val is None:
        raise ValueError(
            f"Cannot delete {entity_cls.__name__!r}: primary key is None "
            "(entity has not been saved)."
        )
    table = self._schema[entity_cls._table_name_]
    stmt = Delete(
        table=table.name,
        where=_build_pk_where(entity_cls, pk_val),
    )
    assert self._builder is not None
    sql, params = self._builder.render(stmt)
    entity.before_delete()
    self._execute_dml(sql, params)
    # Clear all PK fields (scalar attributes and cached relation FK ids).
    for fname in pk_fields:
        if fname in entity_cls._fields_:
            setattr(entity, fname, None)
        else:
            vars(entity).pop(f"_{fname}_id", None)
    entity.after_delete()
    from nextorm.session import _get_session_stack  # noqa: PLC0415

    cache = _get_session_stack().current
    if cache is not None:
        cache.remove(entity, pk_val)
        # With its PK cleared, a still-queued entity would look new to save()
        # and be re-INSERTed on the next flush(). save() and _post_save()
        # call these on untracked entities too, so this is safe.
        cache.unmark_dirty(entity)
        cache.unschedule_save(entity)
# ------------------------------------------------------------------ # Low-level SQL execution (used by QuerySet + save/delete) # ------------------------------------------------------------------
def get_connection(self) -> Any:
    """Return the raw DBAPI connection behind this database (escape hatch).

    Use it for provider-specific features the NextORM API does not cover.
    The type depends on the provider (e.g. ``sqlite3.Connection`` for
    SQLite). If the connection object wraps the DBAPI connection in a
    ``_conn`` attribute, that inner object is returned; otherwise the
    connection object itself is returned.

    Raises
    ------
    RuntimeError
        If the database is not bound (:meth:`bind` has not been called).

    Warning
    -------
    When a connection pool is configured (``pool_max > 1`` or
    ``pool_min > 0`` in :meth:`bind`), the connection is checked out of the
    pool and never returned. Every call therefore permanently consumes one
    pool slot.
    """
    conn = self._ensure_connection()
    # Unwrap the thin SyncConnection wrapper, if any, to reach the DBAPI object.
    return getattr(conn, "_conn", conn)
def execute(self, sql: str, *args: Any) -> int:
    """Run arbitrary SQL (DDL or DML) and return the affected-row count.

    Positional *args* are bound to the statement's placeholders (``?`` or
    ``%s``, depending on the provider). Handy for statements the QuerySet
    API cannot express::

        db.execute("CREATE INDEX idx_user_email ON user (email)")
        db.execute("DELETE FROM session WHERE expires_at < ?", cutoff)
    """
    params = [*args]
    return self._execute_dml(sql, params)
def select_raw(self, sql: str, *args: Any) -> list[dict[str, Any]]:
    """Execute a raw SELECT and return the results as a list of dicts.

    Column names are taken from the cursor description. Use this when you
    need a join result or aggregate that doesn't map to a single entity
    class::

        rows = db.select_raw(
            "SELECT u.name, COUNT(o.id) AS order_count "
            'FROM user u LEFT JOIN "order" o ON o.user_id = u.id '
            "GROUP BY u.id"
        )
        for row in rows:
            print(row["name"], row["order_count"])

    Like every other statement, the query updates :attr:`last_sql`, is echoed
    by the SQL debug printer and is counted in :attr:`local_stats`.
    """
    params = list(args)
    # Track raw SELECTs like execute() does (it goes through _execute_dml);
    # previously they bypassed last_sql, SQL echo and statistics.
    self._last_sql = sql
    _print_sql(sql, params)
    conn = self._ensure_connection()
    try:
        cur = conn.cursor()
        t0 = time.perf_counter()
        cur.execute(sql, params)
        desc = cur.description
        # Only fetch when the statement produced a result set; some drivers
        # raise on fetchall() otherwise.
        rows = cur.fetchall() if desc is not None else []
        self._record_stat(sql, time.perf_counter() - t0)
    finally:
        self._release_connection(conn)
    if desc is None:  # pragma: no cover
        return []
    col_names = [d[0] for d in desc]
    return [dict(zip(col_names, row, strict=False)) for row in rows]
@contextlib.contextmanager
def _cursor_scope(
    self, sql: str, params: list[Any]
) -> typing.Iterator[tuple[SyncConnection, Any]]:
    """Record and echo *sql*, check out a connection, and yield ``(conn, cursor)``.

    The connection is released when the ``with`` block exits, normally or
    by exception.
    """
    self._last_sql = sql
    _print_sql(sql, params)
    conn = self._ensure_connection()
    try:
        yield conn, conn.cursor()
    finally:
        self._release_connection(conn)

def _execute(self, sql: str, params: list[Any]) -> list[tuple[Any, ...]]:
    """Run *sql* with *params* and return every result row."""
    with self._cursor_scope(sql, params) as (_conn, cur):
        started = time.perf_counter()
        cur.execute(sql, params)
        rows = list(cur.fetchall())
        self._record_stat(sql, time.perf_counter() - started)
    return rows

def _execute_described(
    self, sql: str, params: list[Any]
) -> tuple[list[tuple[Any, ...]], list[str]]:
    """Run *sql* and return ``(rows, column_names)``."""
    with self._cursor_scope(sql, params) as (_conn, cur):
        started = time.perf_counter()
        cur.execute(sql, params)
        rows = list(cur.fetchall())
        self._record_stat(sql, time.perf_counter() - started)
        description = cur.description
        names = [column[0] for column in description] if description else []
    return rows, names

def _execute_dml(self, sql: str, params: list[Any]) -> int:
    """Run a DML statement and return the cursor's affected-row count."""
    with self._cursor_scope(sql, params) as (conn, cur):
        started = time.perf_counter()
        cur.execute(sql, params)
        affected = cur.rowcount
        # Outside a db_session each statement is committed right away;
        # inside one, the commit happens at session end (PonyORM semantics).
        from nextorm.session import _get_session_stack  # noqa: PLC0415

        if _get_session_stack().current is None:
            conn.commit()
        self._record_stat(sql, time.perf_counter() - started)
    return affected

def _execute_insert(self, sql: str, params: list[Any]) -> int | None:
    """Run an INSERT and return the cursor's ``lastrowid``."""
    with self._cursor_scope(sql, params) as (conn, cur):
        started = time.perf_counter()
        cur.execute(sql, params)
        new_rowid = cur.lastrowid
        # Same auto-commit rule as _execute_dml: only outside a db_session.
        from nextorm.session import _get_session_stack  # noqa: PLC0415

        if _get_session_stack().current is None:
            conn.commit()
        self._record_stat(sql, time.perf_counter() - started)
    return new_rowid

@property
def last_sql(self) -> str:
    """The most recent SQL text sent to the database.

    Set by ``_execute``, ``_execute_described``, ``_execute_dml`` and
    ``_execute_insert``; ``""`` until the first such statement runs. Handy
    when debugging::

        db.select(User).filter(User.age > 18).fetch_all()
        print(db.last_sql)
    """
    return self._last_sql

# ------------------------------------------------------------------
# Query statistics
# ------------------------------------------------------------------
@property
def local_stats(self) -> dict[str, QueryStat]:
    """Snapshot of this database's query statistics, keyed by SQL text.

    Covers everything since the last :meth:`clear_local_stats`. Each
    :class:`~nextorm.debug.QueryStat` exposes ``count``, ``sum_time``,
    ``min_time``, ``max_time`` and ``avg_time``. The returned dict is a copy.
    """
    return {**self._local_stats}
def clear_local_stats(self) -> None:
    """Forget every query statistic collected by this database instance."""
    stats = self._local_stats
    stats.clear()
def merge_local_stats(self) -> None:
    """Fold this database's statistics into :data:`~nextorm.debug.global_stats`.

    Call it periodically (e.g. at the end of a request) to accumulate
    statistics across several :class:`Database` instances. The global map is
    updated under its lock; the local statistics are left as they are.
    """
    with _global_stats_lock:
        for sql_text, local in self._local_stats.items():
            try:
                target = global_stats[sql_text]
            except KeyError:
                target = global_stats[sql_text] = QueryStat()
            target._merge(local)
def _record_stat(self, sql: str, elapsed: float) -> None:
    """Record one query execution in :attr:`_local_stats`.

    Args:
        sql: The SQL text; used as the statistics key.
        elapsed: Seconds spent executing the statement.
    """
    if sql not in self._local_stats:
        self._local_stats[sql] = QueryStat()
    self._local_stats[sql]._record(elapsed)

def _load_lazy_field(self, entity: Entity, field_name: str) -> Any:
    """Execute a per-field SELECT to load a single lazy field value.

    Called automatically by :class:`~nextorm.entity.FieldDescriptor` when a
    lazy field is accessed on an entity that was loaded by this database.

    Args:
        entity: The entity whose field is being loaded.
        field_name: Python attribute name of the lazy field.

    Returns:
        The column value from the first matching row as returned by the
        driver, or ``None`` when the entity has no primary-key value or no
        row matches.
    """
    entity_cls = type(entity)
    fi = entity_cls._fields_[field_name]
    col_name = fi.spec.column or field_name
    pk_val = _get_pk_val(entity)
    if pk_val is None:
        # No PK value -> no row to look up.
        return None
    where = _build_pk_where(entity_cls, pk_val)
    stmt = Select(
        columns=(ColumnRef(col_name),),
        from_table=entity_cls._table_name_,
        where=where,
    )
    assert self._builder is not None
    sql, params = self._builder.render(stmt)
    rows = self._execute(sql, params)
    return rows[0][0] if rows else None

# ------------------------------------------------------------------
# INSERT / UPDATE helpers
# ------------------------------------------------------------------

def _do_insert(
    self,
    entity: Entity,
    entity_cls: type[Entity],
    table: Table,
    *,
    include_auto_pk: bool = False,
) -> None:
    """INSERT *entity* into *table* and initialise its change-tracking state.

    Steps:

    1. Fill in client-generated UUID v7 / UUID v4 / ULID values for fields
       with ``uuid_auto`` set whose current value is ``None``.
    2. Build the column list from scalar fields (auto PK columns are left
       out unless *include_auto_pk* is true and a value is present), the
       FK columns of owning-side ``Single`` relations, and the STI
       discriminator for child entities.
    3. Execute the INSERT and, for a single auto PK, write the driver's
       ``lastrowid`` back onto the entity.
    4. Attach ``_db_``, ``_dbvals_`` and ``_read_cols_`` to the instance.

    Args:
        entity: The instance to persist.
        entity_cls: Its entity class (source of the mapping metadata).
        table: Target table from the mapped schema.
        include_auto_pk: Keep an auto PK column in the INSERT when the
            entity already carries a non-``None`` value for it.
    """
    import uuid as _uuid  # noqa: PLC0415

    # Determine the single auto-PK field name (None for composite PKs)
    single_pk_field = entity_cls._pk_field_  # None when composite

    # Auto-generate UUID / ULID PK values before the column list is built.
    for fi in entity_cls._fields_.values():
        if fi.spec.uuid_auto is not None and getattr(entity, fi.name) is None:
            if fi.spec.uuid_auto == "v7":
                setattr(entity, fi.name, _generate_uuid7())
            elif fi.spec.uuid_auto == "v4":
                setattr(entity, fi.name, _uuid.uuid4())
            else:  # "ulid"
                setattr(entity, fi.name, _generate_ulid())

    # Exclude auto-PK from column list unless include_auto_pk=True AND it has a value.
    cols_and_vals: list[tuple[str, Any]] = [
        (fi.spec.column or fi.name, _serialize_value(getattr(entity, fi.name)))
        for fi in entity_cls._fields_.values()
        if not (
            fi.spec.primary_key
            and fi.spec.auto
            and (not include_auto_pk or getattr(entity, fi.name) is None)
        )
    ]

    # Also persist FK values from Single relations (owning side)
    for ri in entity_cls._relations_.values():
        if ri.spec.kind == RelationKind.SINGLE:
            fk_col = ri.spec.column or f"{ri.name}_id"
            fk_val = entity.__dict__.get(f"_{ri.name}_id")
            cols_and_vals.append((fk_col, fk_val))

    # STI: inject the discriminator column value for child entities
    disc_val = entity_cls._discriminator_val_
    if disc_val is not None and entity_cls._sti_parent_ is not None:
        disc_col = entity_cls._sti_parent_._discriminator_col_
        if disc_col is not None:  # pragma: no branch
            cols_and_vals.append((disc_col, disc_val))

    assert self._builder is not None
    if cols_and_vals:
        stmt = Insert(
            table=table.name,
            columns=tuple(c for c, _ in cols_and_vals),
            values=tuple(Param(value=v) for _, v in cols_and_vals),
        )
        sql, params = self._builder.render(stmt)
    else:
        # Entity has only an auto-pk — syntax varies by backend:
        #   sqlite:   INSERT INTO t VALUES (NULL)
        #   mariadb:  INSERT INTO t VALUES ()
        #   postgres: INSERT INTO t DEFAULT VALUES  (standard SQL)
        # NOTE(review): these raw statements interpolate table.name without
        # identifier quoting, unlike the builder path above; a table named
        # after a reserved word (e.g. "user" on PostgreSQL) would fail here.
        if self._provider == "sqlite":
            sql, params = f"INSERT INTO {table.name} VALUES (NULL)", []
        elif self._provider == "mariadb":
            sql, params = f"INSERT INTO {table.name} VALUES ()", []
        else:  # pragma: no cover — PostgreSQL DEFAULT VALUES; no live PG in test suite
            sql, params = f"INSERT INTO {table.name} DEFAULT VALUES", []

    rowid = self._execute_insert(sql, params)

    # Write the auto-generated PK back only for DB-side auto-increment PKs.
    # UUID / ULID PKs (spec.uuid_auto set, spec.auto=False) are already written
    # above before the INSERT — do not overwrite them with lastrowid.
    if single_pk_field is not None and rowid is not None:
        pk_fi = entity_cls._fields_.get(single_pk_field)
        if pk_fi is not None and pk_fi.spec.auto:
            setattr(entity, single_pk_field, rowid)

    # Attach database context so lazy-load works on newly saved entities
    vars(entity)["_db_"] = self

    # Set _dbvals_ so that subsequent attribute mutations are auto-dirtied.
    # Values are stored un-serialized (Python-level), keyed by column name.
    # NOTE(review): getattr() here goes through the field descriptors; if
    # they record reads into _read_cols_, every column ends up "read" after
    # an insert — confirm against FieldDescriptor.__get__.
    dbvals: dict[str, Any] = {}
    for fi in entity_cls._fields_.values():
        if not fi.spec.primary_key and not fi.spec.volatile:
            col = fi.spec.column or fi.name
            dbvals[col] = getattr(entity, fi.name, None)
    for ri in entity_cls._relations_.values():
        if ri.spec.kind == RelationKind.SINGLE:
            dbvals[ri.spec.column or f"{ri.name}_id"] = vars(entity).get(f"_{ri.name}_id")
    vars(entity)["_dbvals_"] = dbvals
    if "_read_cols_" not in vars(entity):
        vars(entity)["_read_cols_"] = set()

def _do_update(
    self,
    entity: Entity,
    entity_cls: type[Entity],
    table: Table,
    pk_val: Any,
) -> None:
    """UPDATE the row for *entity*, optionally with an optimistic check.

    Every non-PK, non-volatile scalar column and every owning-side
    ``Single`` FK column (except relations that are part of a composite PK)
    is written.  When the entity has read-tracked columns and a
    ``_dbvals_`` snapshot, and the current ``db_session`` does not disable
    optimistic checks, the WHERE clause additionally requires each read
    column to still hold its snapshot value.

    Args:
        entity: The instance whose state is written.
        entity_cls: Its entity class (source of the mapping metadata).
        table: Target table from the mapped schema.
        pk_val: Primary-key value identifying the row.

    Raises:
        OptimisticCheckError: The optimistic check was active and the
            UPDATE matched no row.
    """
    # Snapshot _read_cols_ immediately — the assignments listcomp below calls
    # getattr on every field descriptor, which would otherwise pollute
    # _read_cols_ with all SET-clause columns before we inspect it.
    read_cols_snapshot: frozenset[str] = frozenset(vars(entity).get("_read_cols_") or ())
    dbvals: dict[str, Any] | None = vars(entity).get("_dbvals_")

    # For composite PKs, scalar PK fields are not marked fi.spec.primary_key.
    # Exclude them explicitly so the UPDATE SET clause only touches non-PK columns.
    scalar_pk_names: frozenset[str] = frozenset(
        f for f in entity_cls._pk_fields_ if f in entity_cls._fields_
    )
    assignments: list[tuple[str, Param]] = [
        (fi.spec.column or fi.name, Param(value=_serialize_value(getattr(entity, fi.name))))
        for fi in entity_cls._fields_.values()
        if not fi.spec.primary_key and not fi.spec.volatile and fi.name not in scalar_pk_names
    ]
    # Also update FK values from Single relations (owning side),
    # but skip relation fields that are part of the composite PK (they are identity)
    pk_rel_names = {f for f in entity_cls._pk_fields_ if f not in entity_cls._fields_}
    for ri in entity_cls._relations_.values():
        if ri.spec.kind == RelationKind.SINGLE and ri.name not in pk_rel_names:
            fk_col = ri.spec.column or f"{ri.name}_id"
            fk_val = entity.__dict__.get(f"_{ri.name}_id")
            assignments.append((fk_col, Param(value=fk_val)))

    # Option A — per-field optimistic concurrency check.
    # Only columns the caller actually READ (via descriptor __get__) and that
    # exist in _dbvals_ are checked; PK and untracked columns are silently skipped.
    where = _build_pk_where(entity_cls, pk_val)
    use_optimistic = bool(read_cols_snapshot) and dbvals is not None
    if use_optimistic:
        # Respect db_session(optimistic=False) to let callers opt out.
        from nextorm.session import _get_session_stack  # noqa: PLC0415

        session = _get_session_stack().current
        if session is not None and not session.optimistic:
            use_optimistic = False
    if use_optimistic:
        for col in read_cols_snapshot:
            if col not in dbvals:  # type: ignore[operator]  # PK / unknown — skip
                continue
            orig_val = dbvals[col]  # type: ignore[index]
            if orig_val is None:
                where = BinOp(where, "AND", BinOp(ColumnRef(col), "IS", Literal(None)))
            else:
                # Serialize exactly like the SET-clause values above. _dbvals_
                # can hold Python-level values (see _do_insert), and binding
                # those raw would either be rejected by the driver or be
                # compared in a different textual form than the stored value,
                # causing a spurious OptimisticCheckError.
                where = BinOp(
                    where,
                    "AND",
                    BinOp(ColumnRef(col), "=", Param(value=_serialize_value(orig_val))),
                )

    stmt = Update(
        table=table.name,
        assignments=tuple(assignments),
        where=where,
    )
    assert self._builder is not None
    sql, params = self._builder.render(stmt)
    rows_affected = self._execute_dml(sql, params)
    if use_optimistic and rows_affected == 0:
        raise OptimisticCheckError(
            f"Concurrent update detected for {entity_cls.__name__!r} "
            f"pk={pk_val!r}: row was modified by another transaction."
        )

# ------------------------------------------------------------------
# Guard
# ------------------------------------------------------------------

def _require_mapped(self, entity_cls: type[Entity]) -> None:
    """Raise :exc:`RuntimeError` if *entity_cls* is not in the schema.

    Membership is checked by the entity's ``_table_name_`` against
    ``self._schema``.
    """
    if entity_cls._table_name_ not in self._schema:
        raise RuntimeError(
            f"Entity {entity_cls.__name__!r} is not in the mapped schema. "
            "Call generate_mapping() first."
        )

# ------------------------------------------------------------------
# Relation validation
# ------------------------------------------------------------------

def _validate_relations(self) -> None:
    """Validate that all ``Set[T]`` relations have proper back-references.

    Raises :exc:`~nextorm.exceptions.MappingError` on:

    * A ``Set[T]`` with no back-reference on ``T`` (missing ``Single`` or
      ``Set`` on the other side).
    * Multiple relations between the same two entities when neither side
      carries ``reverse=`` to disambiguate which back-reference to use.
    """
    entities = self._effective_entities()
    # Case-insensitive lookup used to resolve string / ForwardRef targets.
    entity_by_name: dict[str, EntityMeta] = {e.__name__.lower(): e for e in entities}

    def _resolve(target: type[Entity] | str | typing.ForwardRef | None) -> EntityMeta | None:
        """Map a relation target spec to an entity class, or ``None``."""
        if target is None:  # pragma: no cover
            return None
        if isinstance(target, str):  # pragma: no cover
            return entity_by_name.get(target.lower())
        if isinstance(target, typing.ForwardRef):
            return entity_by_name.get(target.__forward_arg__.lower())
        return cast("EntityMeta", target)

    for entity_cls in entities:
        for ri in entity_cls._relations_.values():
            if ri.spec.kind != RelationKind.SET:
                continue
            target_cls = _resolve(ri.spec.target)
            if target_cls is None:
                # Unresolvable forward reference — not a validation error
                continue
            # Collect all back-refs on target_cls pointing at entity_cls
            back_refs = [
                r for r in target_cls._relations_.values() if _resolve(r.spec.target) is entity_cls
            ]
            if not back_refs:
                from nextorm.entity import _target_name  # noqa: PLC0415

                target_name = _target_name(ri.spec.target) or repr(ri.spec.target)
                raise MappingError(
                    f"{entity_cls.__name__}.{ri.name}: Set[{target_name}] requires a "
                    f"back-reference on {target_name} pointing at {entity_cls.__name__}. "
                    f"Add Single[{entity_cls.__name__}] or "
                    f"Set[{entity_cls.__name__}] to {target_name}."
                )
            # Ambiguity: multiple back-refs on the same target without reverse=
            if len(back_refs) > 1 and ri.spec.reverse is None:
                ambiguous = [r.name for r in back_refs if r.spec.reverse is None]
                if len(ambiguous) > 1:  # pragma: no branch
                    raise MappingError(
                        f"{entity_cls.__name__}.{ri.name}: ambiguous — "
                        f"{target_cls.__name__} has multiple back-references to "
                        f"{entity_cls.__name__} ({', '.join(ambiguous)}). "
                        f"Add reverse='<attr>' on one side to disambiguate."
                    )