From 8c365fdb6c035edde8ba4f9e679b6c1b5fe2d041 Mon Sep 17 00:00:00 2001 From: Christophe Haen Date: Thu, 16 Jan 2025 13:02:40 +0100 Subject: [PATCH] tests: fix the mock_osdb --- .../src/diracx/testing/mock_osdb.py | 104 +++++++++--------- diracx-testing/src/diracx/testing/utils.py | 18 ++- 2 files changed, 65 insertions(+), 57 deletions(-) diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 6e181a79..66561710 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -83,18 +83,17 @@ async def create_index_template(self) -> None: await conn.run_sync(self._sql_db.metadata.create_all) async def upsert(self, doc_id, document) -> None: - async with self: - values = {} - for key, value in document.items(): - if key in self.fields: - values[key] = value - else: - values.setdefault("extra", {})[key] = value + values = {} + for key, value in document.items(): + if key in self.fields: + values[key] = value + else: + values.setdefault("extra", {})[key] = value - stmt = sqlite_insert(self._table).values(doc_id=doc_id, **values) - # TODO: Upsert the JSON blob properly - stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values) - await self._sql_db.conn.execute(stmt) + stmt = sqlite_insert(self._table).values(doc_id=doc_id, **values) + # TODO: Upsert the JSON blob properly + stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values) + await self._sql_db.conn.execute(stmt) async def search( self, @@ -105,48 +104,47 @@ async def search( distinct: bool = False, per_page: int = 100, page: int | None = None, - ) -> tuple[int, list[dict[Any, Any]]]: - async with self: - # Apply selection - if parameters: - columns = [] - for p in parameters: - if p in self.fields: - columns.append(self._table.columns[p]) - else: - columns.append(self._table.columns["extra"][p].label(p)) - else: - columns = self._table.columns - stmt = select(*columns) - if distinct: - stmt = stmt.distinct() - - # Apply filtering - stmt = sql_utils.apply_search_filters( - self._table.columns.__getitem__, stmt, search - ) - - # Apply sorting - stmt = sql_utils.apply_sort_constraints( - self._table.columns.__getitem__, stmt, sorts - ) - - # Apply pagination - if page is not None: - stmt = stmt.offset((page - 1) * per_page).limit(per_page) - - results = [] - async for row in await self._sql_db.conn.stream(stmt): - result = dict(row._mapping) - result.pop("doc_id", None) - if "extra" in result: - result.update(result.pop("extra")) - for k, v in list(result.items()): - if isinstance(v, datetime) and v.tzinfo is None: - result[k] = v.replace(tzinfo=timezone.utc) - if v is None: - result.pop(k) - results.append(result) + ) -> list[dict[Any, Any]]: + # Apply selection + if parameters: + columns = [] + for p in parameters: + if p in self.fields: + columns.append(self._table.columns[p]) + else: + columns.append(self._table.columns["extra"][p].label(p)) + else: + columns = self._table.columns + stmt = select(*columns) + if distinct: + stmt = stmt.distinct() + + # Apply filtering + stmt = sql_utils.apply_search_filters( + self._table.columns.__getitem__, stmt, search + ) + + # Apply sorting + stmt = sql_utils.apply_sort_constraints( + self._table.columns.__getitem__, stmt, sorts + ) + + # Apply pagination + if page is not None: + stmt = stmt.offset((page - 1) * per_page).limit(per_page) + + results = [] + async for row in await self._sql_db.conn.stream(stmt): + result = dict(row._mapping) + result.pop("doc_id", None) + if "extra" in result: + result.update(result.pop("extra")) + for k, v in list(result.items()): + if isinstance(v, datetime) and v.tzinfo is None: + result[k] = v.replace(tzinfo=timezone.utc) + if v is None: + result.pop(k) + results.append(result) return results async def ping(self): diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index c895f4d4..73bfef24 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -252,6 +252,7 @@ def configure(self, enabled_dependencies): assert ( self.app.dependency_overrides == {} and self.app.lifetime_functions == [] ), "configure cannot be nested" + for k, v in self.all_dependency_overrides.items(): class_name = k.__self__.__name__ @@ -284,17 +285,26 @@ async def create_db_schemas(self): import sqlalchemy from sqlalchemy.util.concurrency import greenlet_spawn + from diracx.db.os.utils import BaseOSDB from diracx.db.sql.utils import BaseSQLDB + from diracx.testing.mock_osdb import MockOSDBMixin for k, v in self.app.dependency_overrides.items(): - # Ignore dependency overrides which aren't BaseSQLDB.transaction - if ( - isinstance(v, UnavailableDependency) - or k.__func__ != BaseSQLDB.transaction.__func__ + # Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session + if isinstance(v, UnavailableDependency) or k.__func__ not in ( + BaseSQLDB.transaction.__func__, + BaseOSDB.session.__func__, ): + continue + # The first argument of the overridden BaseSQLDB.transaction is the DB object db = v.args[0] + # We expect the OS DB to be mocked with sqlite, so use the + # internal DB + if isinstance(db, MockOSDBMixin): + db = db._sql_db + assert isinstance(db, BaseSQLDB), (k, db) # set PRAGMA foreign_keys=ON if sqlite