# Source code for nextorm.session

"""Session management — identity map, dirty tracking, db_session context manager.

Usage (synchronous)::

    from nextorm.session import db_session

    with db_session as cache:
        user = cache.get((User, 1))

Usage (async)::

    from nextorm.session import db_session

    async with db_session:
        ...

Usage as a decorator::

    @db_session
    def my_view(): ...


    @db_session
    async def my_async_view(): ...

Parametrised form::

    @db_session(retry=3)
    def my_view(): ...


    with db_session(sql_debug=True):
        ...
"""

from __future__ import annotations

import functools
import inspect
import threading
from contextlib import asynccontextmanager, contextmanager, suppress
from typing import TYPE_CHECKING, Any, Self, overload

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Callable, Generator

from nextorm.entity import Entity
from nextorm.exceptions import TransactionError

__all__ = ["SessionCache", "db_session"]


def _collect_dbs(cache: SessionCache | None) -> list[Any]:
    """Collect unique Database / AsyncDatabase instances referenced in *cache*."""
    assert cache is not None  # always called within an active session
    seen: set[int] = set()
    dbs: list[Any] = []
    # Check identity map, pending-save queue, and dirty set
    all_entities = list(cache._objects.values()) + list(cache._to_save) + list(cache._dirty)
    for entity in all_entities:
        db = vars(entity).get("_db_")
        if db is not None and id(db) not in seen:
            seen.add(id(db))
            dbs.append(db)
    return dbs


class SessionCache:
    """Identity map and dirty-object tracker for a single database session.

    The cache is keyed by ``(entity_class, primary_key_value)`` which ensures
    that at most one Python object exists per database row inside a session.

    Attributes
    ----------
    _objects:
        The identity map: ``{(entity_cls, pk): entity_instance}``.
    _dirty:
        Set of entity instances that have been modified and need flushing.
    _to_save:
        Instances scheduled for ``INSERT`` (not yet persisted), in the order
        they were scheduled.
    _modified_collections:
        Pending M2M collection changes keyed by ``(owner, attr)``; each value
        is ``{"add": [...], "remove": [...]}``.
    """

    def __init__(self) -> None:
        self._objects: dict[tuple[type[Entity], Any], Entity] = {}
        self._dirty: set[Entity] = set()
        self._to_save: list[Entity] = []
        self._modified_collections: dict[tuple[Entity, str], dict[str, list[Entity]]] = {}
        #: Set by :class:`DBSessionManager` when ``serializable=True`` was requested.
        self.serializable: bool = False
        #: Set by :class:`DBSessionManager` when ``immediate=True`` was requested.
        self.immediate: bool = False
        #: Set to ``False`` by :class:`DBSessionManager` when ``optimistic=False`` was requested.
        #: When ``True`` (the default), :meth:`~nextorm.database.Database.save` adds
        #: ``AND col = orig_val`` clauses for the columns the caller has READ since load,
        #: implementing Pony-style per-field optimistic concurrency detection.
        self.optimistic: bool = True
        #: When ``True``, entity objects should not be accessed after the session ends.
        #: Currently the cache is always cleared; ``strict=False`` retains the same
        #: behaviour — no cross-session lazy-load is supported yet.
        self.strict: bool = False

    # ------------------------------------------------------------------
    # Identity map
    # ------------------------------------------------------------------

    def get(self, key: tuple[type[Entity], Any]) -> Entity | None:
        """Return the cached instance for *key*, or ``None``."""
        return self._objects.get(key)

    def put(self, entity: Entity, pk: Any) -> None:
        """Add *entity* to the identity map under its ``(type, pk)`` key."""
        self._objects[(type(entity), pk)] = entity

    def remove(self, entity: Entity, pk: Any) -> None:
        """Remove *entity* from the identity map and from the dirty set.

        Missing entries are ignored.
        """
        self._objects.pop((type(entity), pk), None)
        self._dirty.discard(entity)

    def clear(self) -> None:
        """Wipe all cached state (called on session commit / rollback)."""
        self._objects.clear()
        self._dirty.clear()
        self._to_save.clear()
        self._modified_collections.clear()

    # ------------------------------------------------------------------
    # Dirty tracking
    # ------------------------------------------------------------------

    def mark_dirty(self, entity: Entity) -> None:
        """Mark *entity* as modified so it will be flushed on commit."""
        self._dirty.add(entity)

    def unmark_dirty(self, entity: Entity) -> None:
        """Remove *entity* from the dirty set (e.g. after successful flush)."""
        self._dirty.discard(entity)

    @property
    def dirty_objects(self) -> frozenset[Entity]:
        """Snapshot of currently-dirty entity instances."""
        return frozenset(self._dirty)

    # ------------------------------------------------------------------
    # New-object queue
    # ------------------------------------------------------------------

    def schedule_save(self, entity: Entity) -> None:
        """Enqueue *entity* for INSERT on the next commit (at most once)."""
        if entity not in self._to_save:
            self._to_save.append(entity)

    def unschedule_save(self, entity: Entity) -> None:
        """Remove *entity* from the INSERT queue (called after a successful save).

        Entities that are not queued are ignored.
        """
        with suppress(ValueError):
            self._to_save.remove(entity)

    @property
    def objects_to_save(self) -> list[Entity]:
        """Ordered list (a copy) of entities pending INSERT."""
        return list(self._to_save)

    # ------------------------------------------------------------------
    # M2M collection tracking
    # ------------------------------------------------------------------

    def track_collection_change(self, owner: Entity, attr: str, action: str, related: Entity) -> None:
        """Record that *related* was added to or removed from *owner*.*attr*.

        Parameters
        ----------
        action:
            ``"add"`` or ``"remove"``.

        Raises
        ------
        ValueError
            If *action* is anything else. Validation happens before any
            bookkeeping, so an invalid call leaves the cache untouched.
        """
        if action not in ("add", "remove"):
            raise ValueError(f"action must be 'add' or 'remove', got {action!r}")
        key = (owner, attr)
        if key not in self._modified_collections:
            self._modified_collections[key] = {"add": [], "remove": []}
        self._modified_collections[key][action].append(related)

    @property
    def modified_collections(
        self,
    ) -> dict[tuple[Entity, str], dict[str, list[Entity]]]:
        """Snapshot of pending M2M collection mutations.

        Both the outer mapping and the inner ``{"add": [...], "remove": [...]}``
        lists are copied, so mutating the result never alters the cache.
        """
        return {
            key: {action: list(related) for action, related in changes.items()}
            for key, changes in self._modified_collections.items()
        }
# --------------------------------------------------------------------------- # Thread-local / task-local session stack # --------------------------------------------------------------------------- class _SessionStack: """Nesting-aware session holder for a single thread. Only ONE :class:`SessionCache` exists per thread at any time. Nested ``db_session`` calls (decorator-on-decorator or nested ``with`` blocks) increment and decrement a depth counter instead of pushing separate caches. This ensures that entities created at any nesting level land in the same identity map, mirroring PonyORM's behaviour. """ def __init__(self) -> None: self._cache: SessionCache | None = None self._depth: int = 0 def push(self, cache: SessionCache) -> None: """Push *cache* as the outermost session (must be called at depth 0). For re-entrant / nested entries, call :meth:`enter_nested` instead. """ assert self._depth == 0, "push() called while a session is already active" self._cache = cache self._depth = 1 def enter_nested(self) -> None: """Increment nesting depth for an inner ``db_session`` call.""" self._depth += 1 def pop(self) -> SessionCache: """Decrement depth; return the shared cache. When depth reaches 0 the cache reference is cleared so that the next :meth:`push` starts fresh. """ self._depth -= 1 cache = self._cache if self._depth == 0: self._cache = None return cache # type: ignore[return-value] @property def current(self) -> SessionCache | None: """The active :class:`SessionCache`, or ``None`` when no session is open.""" return self._cache @property def depth(self) -> int: """Current nesting depth — 0 outside a session, ≥1 inside.""" return self._depth # Per-thread session stack — each thread gets its own independent stack so that # concurrent threads cannot interfere with each other's session state. _tls: threading.local = threading.local() def _get_session_stack() -> _SessionStack: """Return the :class:`_SessionStack` for the current thread. 
Created lazily on first access so that threads that never open a session pay no overhead. """ stack: _SessionStack | None = getattr(_tls, "stack", None) if stack is None: stack = _SessionStack() _tls.stack = stack return stack def get_current_session() -> SessionCache: """Return the active :class:`SessionCache`. Raises ------ RuntimeError If called outside of a ``db_session`` block. """ cache = _get_session_stack().current if cache is None: raise RuntimeError("No active db_session. Use 'with db_session:' or '@db_session'.") return cache # --------------------------------------------------------------------------- # DBSessionManager — the object exposed as ``db_session`` # ---------------------------------------------------------------------------
[docs] class DBSessionManager: """Unified sync/async context manager and decorator for database sessions. The same ``db_session`` object can be used in four ways:: with db_session: # sync context manager (no-arg) ... with db_session(): # sync context manager (call form) ... @db_session # sync decorator def f(): ... @db_session # async decorator (detected at call time) async def f(): ... Parametrised form:: with db_session(sql_debug=True): ... @db_session(retry=3) def f(): ... Nesting is supported — only the outermost context actually clears the cache on exit. Parameters ---------- retry: Number of *additional* attempts to make after the first one fails with :exc:`~nextorm.exceptions.TransactionError` (e.g. a deadlock or serialisation failure). ``retry=3`` means up to 4 total attempts. Applies only when used as a decorator. sql_debug: When ``True``, enables SQL debug logging for the duration of the session and restores the previous state on exit. Equivalent to wrapping the body with :class:`~nextorm.debug.sql_debugging`. serializable: Hint that the session should use ``SERIALIZABLE`` transaction isolation. Recorded on the session cache; enforcement is the responsibility of the provider / middleware layer. immediate: Hint that the session should lock immediately (SQLite ``BEGIN IMMEDIATE``). Recorded on the session cache; enforcement is the responsibility of the provider / middleware layer. """
[docs] def __init__( self, *, retry: int = 0, sql_debug: bool = False, show_values: bool = False, serializable: bool = False, immediate: bool = False, optimistic: bool = True, strict: bool = False, allowed_exceptions: list[type[BaseException]] | None = None, retry_exceptions: list[type[BaseException]] | None = None, ) -> None: self._retry = retry self._sql_debug = sql_debug self._show_values = show_values self._serializable = serializable self._immediate = immediate self._optimistic = optimistic self._strict = strict self._allowed_exceptions: tuple[type[BaseException], ...] = tuple(allowed_exceptions or ()) # Default retry list is [TransactionError], matching PonyORM semantics self._retry_exceptions: tuple[type[BaseException], ...] = tuple( retry_exceptions if retry_exceptions is not None else [TransactionError] )
# --- sync context manager ----------------------------------------------- def _push_cache(self) -> SessionCache: """Return the active :class:`SessionCache`, creating one only for the outermost entry. Nested ``db_session`` calls increment the depth counter and reuse the existing cache, ensuring a single identity map for the whole nesting tree — matching PonyORM semantics. """ stack = _get_session_stack() if stack.current is not None: # Nested entry — share the existing outermost cache. stack.enter_nested() return stack.current cache = SessionCache() cache.serializable = self._serializable cache.immediate = self._immediate cache.optimistic = self._optimistic cache.strict = self._strict stack.push(cache) return cache
[docs] def __enter__(self) -> SessionCache: if self._sql_debug: from nextorm.debug import set_sql_debug # noqa: PLC0415 set_sql_debug(True) return self._push_cache()
[docs] def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool | None: if self._sql_debug: from nextorm.debug import set_sql_debug # noqa: PLC0415 set_sql_debug(False) # Check outermost BEFORE popping (depth 1 = this is the only active session) is_outermost = _get_session_stack().depth == 1 # allowed_exceptions: treat as success (commit, then suppress). suppress = False if ( exc_type is not None and self._allowed_exceptions and issubclass(exc_type, self._allowed_exceptions) ): suppress = True exc_type = None # treat as clean exit for commit dbs: list[Any] = [] commit_exc: BaseException | None = None if exc_type is None: # Clean exit: flush all pending changes then commit, primary-first. if is_outermost: dbs = _collect_dbs(_get_session_stack().current) try: for db in dbs: db.flush() except Exception as _fe: commit_exc = _fe for _db in dbs: try: # noqa: SIM105 _db._rollback_transaction() except Exception: pass else: primary, secondaries = (dbs[0] if dbs else None), dbs[1:] if primary is not None: try: primary._commit_transaction() except Exception as _ce: commit_exc = _ce for _db in secondaries: try: # noqa: SIM105 _db._rollback_transaction() except Exception: pass else: for _db in secondaries: try: # noqa: SIM105 _db._commit_transaction() except Exception as _se: if commit_exc is None: commit_exc = _se else: # Unhandled exception: roll back all databases. if is_outermost: dbs = _collect_dbs(_get_session_stack().current) for db in dbs: try: # noqa: SIM105 db._rollback_transaction() except Exception: pass # ALWAYS pop and clear — even if commit raised above. cache = _get_session_stack().pop() if is_outermost: cache.clear() if commit_exc is not None: raise commit_exc # re-raise after cleanup return bool(suppress) or None
# Calling ``with db_session():`` creates a temporary manager so that both # ``with db_session:`` and ``with db_session():`` work identically. @overload def __call__[T, **P](self, func: Callable[P, T], /) -> Callable[P, T]: ... @overload def __call__( self, func: None = None, /, *, retry: int | None = None, sql_debug: bool | None = None, show_values: bool | None = None, serializable: bool | None = None, immediate: bool | None = None, optimistic: bool | None = None, strict: bool | None = None, allowed_exceptions: list[type[BaseException]] | None = None, retry_exceptions: list[type[BaseException]] | None = None, ) -> Self: ...
[docs] def __call__( self, func: Callable[..., Any] | None = None, /, *, retry: int | None = None, sql_debug: bool | None = None, show_values: bool | None = None, serializable: bool | None = None, immediate: bool | None = None, optimistic: bool | None = None, strict: bool | None = None, allowed_exceptions: list[type[BaseException]] | None = None, retry_exceptions: list[type[BaseException]] | None = None, ) -> Callable[..., Any] | Self: """Called as ``db_session(...)`` to create a parametrised manager, or as a decorator. When called **without** a callable argument (e.g. ``db_session(retry=3)``), returns a new :class:`DBSessionManager` configured with the given parameters. When called **with** a callable (``@db_session`` bare decorator), wraps it. """ if func is None: # ``with db_session(retry=n, ...)`` or ``@db_session(retry=n)`` return DBSessionManager( retry=retry if retry is not None else self._retry, sql_debug=sql_debug if sql_debug is not None else self._sql_debug, show_values=show_values if show_values is not None else self._show_values, serializable=serializable if serializable is not None else self._serializable, immediate=immediate if immediate is not None else self._immediate, optimistic=optimistic if optimistic is not None else self._optimistic, strict=strict if strict is not None else self._strict, allowed_exceptions=( allowed_exceptions if allowed_exceptions is not None else list(self._allowed_exceptions) ), retry_exceptions=( retry_exceptions if retry_exceptions is not None else list(self._retry_exceptions) ), ) # Bare ``@db_session`` / ``@db_session()`` with a callable → wrap it if inspect.iscoroutinefunction(func): return self._async_wrapper(func) return self._sync_wrapper(func)
def _sync_wrapper(self, func: Callable[..., Any]) -> Callable[..., Any]: retry = self._retry retry_exc = self._retry_exceptions # already a tuple @functools.wraps(func) def _inner(*args: Any, **kwargs: Any) -> Any: attempts = 0 while True: try: with self: return func(*args, **kwargs) except retry_exc: attempts += 1 if attempts > retry: raise return _inner def _async_wrapper(self, func: Callable[..., Any]) -> Callable[..., Any]: retry = self._retry retry_exc = self._retry_exceptions # already a tuple @functools.wraps(func) async def _inner(*args: Any, **kwargs: Any) -> Any: attempts = 0 while True: try: async with self: return await func(*args, **kwargs) except retry_exc: attempts += 1 if attempts > retry: raise return _inner # Allow ``@db_session`` (without call) on both sync and async functions def __get__(self, obj: Any, objtype: Any = None) -> Any: # pragma: no cover return self # --- async context manager ----------------------------------------------
[docs] async def __aenter__(self) -> SessionCache: if self._sql_debug: from nextorm.debug import set_sql_debug # noqa: PLC0415 set_sql_debug(True) return self._push_cache()
[docs] async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool | None: if self._sql_debug: from nextorm.debug import set_sql_debug # noqa: PLC0415 set_sql_debug(False) is_outermost = _get_session_stack().depth == 1 suppress = False if ( exc_type is not None and self._allowed_exceptions and issubclass(exc_type, self._allowed_exceptions) ): suppress = True exc_type = None dbs: list[Any] = [] commit_exc: BaseException | None = None if exc_type is None: # Clean exit: flush all pending changes then commit, primary-first. if is_outermost: dbs = _collect_dbs(_get_session_stack().current) try: for db in dbs: if getattr(db, "_is_async", False): await db.aflush() else: db.flush() except Exception as _fe: commit_exc = _fe for _db in dbs: try: # noqa: SIM105 if getattr(_db, "_is_async", False): await _db._arollback_transaction() else: _db._rollback_transaction() except Exception: pass else: primary, secondaries = (dbs[0] if dbs else None), dbs[1:] if primary is not None: try: if getattr(primary, "_is_async", False): await primary._acommit_transaction() else: primary._commit_transaction() except Exception as _ce: commit_exc = _ce for _db in secondaries: try: # noqa: SIM105 if getattr(_db, "_is_async", False): await _db._arollback_transaction() else: _db._rollback_transaction() except Exception: pass else: for _db in secondaries: try: # noqa: SIM105 if getattr(_db, "_is_async", False): await _db._acommit_transaction() else: _db._commit_transaction() except Exception as _se: if commit_exc is None: commit_exc = _se else: # Unhandled exception: roll back all databases. if is_outermost: dbs = _collect_dbs(_get_session_stack().current) for db in dbs: try: # noqa: SIM105 if getattr(db, "_is_async", False): await db._arollback_transaction() else: db._rollback_transaction() except Exception: pass cache = _get_session_stack().pop() if is_outermost: cache.clear() if commit_exc is not None: raise commit_exc return bool(suppress) or None
# --- generator-based helpers (for library code) -------------------------
[docs] @contextmanager def as_context(self) -> Generator[SessionCache, None, None]: """Explicit generator-based context (useful in pytest with ``with`` blocks).""" with self as cache: yield cache
[docs] @asynccontextmanager async def as_async_context(self) -> AsyncGenerator[SessionCache, None]: """Explicit async generator-based context.""" async with self as cache: yield cache
# Introspection helpers used in tests and framework internals @property def depth(self) -> int: """Current nesting depth (0 = not inside any session).""" return _get_session_stack().depth
[docs] @staticmethod def current() -> SessionCache | None: """Return the innermost active :class:`SessionCache`, or ``None``.""" return _get_session_stack().current
#: The module-wide default manager. Use it bare (``with db_session:`` /
#: ``@db_session``) or call it to derive a parametrised manager.
db_session: DBSessionManager
db_session = DBSessionManager()