Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a ContextVar to manage the sql connection instance as an instance… #169

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions src/diracx/db/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import logging
import os
from abc import ABCMeta
from contextvars import ContextVar
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import TYPE_CHECKING, AsyncIterator, Self
from typing import TYPE_CHECKING, AsyncIterator, Self, cast

from pydantic import parse_obj_as
from sqlalchemy import Column as RawColumn
Expand Down Expand Up @@ -73,7 +74,12 @@ class BaseSQLDB(metaclass=ABCMeta):
metadata: MetaData

def __init__(self, db_url: str) -> None:
self._conn = None
# We use a ContextVar to make sure that self._conn
# is specific to each context, and avoid parallel
# route executions to overlap
self._conn: ContextVar[AsyncConnection | None] = ContextVar(
"_conn", default=None
)
self._db_url = db_url
self._engine: AsyncEngine | None = None

Expand Down Expand Up @@ -121,16 +127,22 @@ def transaction(cls) -> Self:
def engine(self) -> AsyncEngine:
"""The engine to use for database operations.

It is normally not necessary to use the engine directly,
unless you are doing something special, like writing a
test fixture that gives you a db.


Requires that the engine_context has been entered.

"""
assert self._engine is not None, "engine_context must be entered"
return self._engine

@contextlib.asynccontextmanager
async def engine_context(self) -> AsyncIterator[None]:
"""Context manage to manage the engine lifecycle.

Tables are automatically created upon entering
This is called once at the application startup
(see ``lifetime_functions``)
"""
assert self._engine is None, "engine_context cannot be nested"

Expand All @@ -144,19 +156,30 @@ async def engine_context(self) -> AsyncIterator[None]:

@property
def conn(self) -> AsyncConnection:
if self._conn is None:
if self._conn.get() is None:
raise RuntimeError(f"{self.__class__} was used before entering")
return self._conn
return cast(AsyncConnection, self._conn.get())

async def __aenter__(self):
self._conn = await self.engine.connect().__aenter__()
"""
Create a connection.
This is called by the Dependency mechanism (see ``db_transaction``),
It will create a new connection/transaction for each route call.
"""

self._conn.set(await self.engine.connect().__aenter__())
return self

async def __aexit__(self, exc_type, exc, tb):
"""
This is called when exciting a route.
If there was no exception, the changes in the DB are committed.
Otherwise, they are rollbacked.
"""
if exc_type is None:
await self._conn.commit()
await self._conn.__aexit__(exc_type, exc, tb)
self._conn = None
await self._conn.get().commit()
await self._conn.get().__aexit__(exc_type, exc, tb)
self._conn.set(None)


def apply_search_filters(table, stmt, search):
Expand Down