diff --git a/src/diracx/db/sql/utils.py b/src/diracx/db/sql/utils.py index 8f8a293e..60e15183 100644 --- a/src/diracx/db/sql/utils.py +++ b/src/diracx/db/sql/utils.py @@ -6,9 +6,10 @@ import logging import os from abc import ABCMeta +from contextvars import ContextVar from datetime import datetime, timedelta, timezone from functools import partial -from typing import TYPE_CHECKING, AsyncIterator, Self +from typing import TYPE_CHECKING, AsyncIterator, Self, cast from pydantic import parse_obj_as from sqlalchemy import Column as RawColumn @@ -73,7 +74,12 @@ class BaseSQLDB(metaclass=ABCMeta): metadata: MetaData def __init__(self, db_url: str) -> None: - self._conn = None + # We use a ContextVar to make sure that self._conn + # is specific to each context, and avoid parallel + # route executions to overlap + self._conn: ContextVar[AsyncConnection | None] = ContextVar( + "_conn", default=None + ) self._db_url = db_url self._engine: AsyncEngine | None = None @@ -121,7 +127,13 @@ def transaction(cls) -> Self: def engine(self) -> AsyncEngine: """The engine to use for database operations. + It is normally not necessary to use the engine directly, + unless you are doing something special, like writing a + test fixture that gives you a db. + + Requires that the engine_context has been entered. + """ assert self._engine is not None, "engine_context must be entered" return self._engine @@ -129,8 +141,8 @@ def engine(self) -> AsyncEngine: @contextlib.asynccontextmanager async def engine_context(self) -> AsyncIterator[None]: """Context manage to manage the engine lifecycle. - - Tables are automatically created upon entering + This is called once at the application startup + (see ``lifetime_functions``) """ assert self._engine is None, "engine_context cannot be nested" @@ -144,19 +156,30 @@ async def engine_context(self) -> AsyncIterator[None]: @property def conn(self) -> AsyncConnection: - if self._conn is None: + if self._conn.get() is None: raise RuntimeError(f"{self.__class__} was used before entering") - return self._conn + return cast(AsyncConnection, self._conn.get()) async def __aenter__(self): - self._conn = await self.engine.connect().__aenter__() + """ + Create a connection. + This is called by the Dependency mechanism (see ``db_transaction``), + It will create a new connection/transaction for each route call. + """ + + self._conn.set(await self.engine.connect().__aenter__()) return self async def __aexit__(self, exc_type, exc, tb): + """ + This is called when exciting a route. + If there was no exception, the changes in the DB are committed. + Otherwise, they are rollbacked. + """ if exc_type is None: - await self._conn.commit() - await self._conn.__aexit__(exc_type, exc, tb) - self._conn = None + await self._conn.get().commit() + await self._conn.get().__aexit__(exc_type, exc, tb) + self._conn.set(None) def apply_search_filters(table, stmt, search):