Skip to content

Commit

Permalink
tests: fix the mock_osdb
Browse files Browse the repository at this point in the history
  • Loading branch information
chaen committed Jan 16, 2025
1 parent 2bf4e28 commit 8c365fd
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 57 deletions.
104 changes: 51 additions & 53 deletions diracx-testing/src/diracx/testing/mock_osdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,17 @@ async def create_index_template(self) -> None:
await conn.run_sync(self._sql_db.metadata.create_all)

async def upsert(self, doc_id, document) -> None:
async with self:
values = {}
for key, value in document.items():
if key in self.fields:
values[key] = value
else:
values.setdefault("extra", {})[key] = value
values = {}
for key, value in document.items():
if key in self.fields:
values[key] = value
else:
values.setdefault("extra", {})[key] = value

stmt = sqlite_insert(self._table).values(doc_id=doc_id, **values)
# TODO: Upsert the JSON blob properly
stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values)
await self._sql_db.conn.execute(stmt)
stmt = sqlite_insert(self._table).values(doc_id=doc_id, **values)
# TODO: Upsert the JSON blob properly
stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values)
await self._sql_db.conn.execute(stmt)

async def search(
self,
Expand All @@ -105,48 +104,47 @@ async def search(
distinct: bool = False,
per_page: int = 100,
page: int | None = None,
) -> tuple[int, list[dict[Any, Any]]]:
async with self:
# Apply selection
if parameters:
columns = []
for p in parameters:
if p in self.fields:
columns.append(self._table.columns[p])
else:
columns.append(self._table.columns["extra"][p].label(p))
else:
columns = self._table.columns
stmt = select(*columns)
if distinct:
stmt = stmt.distinct()

# Apply filtering
stmt = sql_utils.apply_search_filters(
self._table.columns.__getitem__, stmt, search
)

# Apply sorting
stmt = sql_utils.apply_sort_constraints(
self._table.columns.__getitem__, stmt, sorts
)

# Apply pagination
if page is not None:
stmt = stmt.offset((page - 1) * per_page).limit(per_page)

results = []
async for row in await self._sql_db.conn.stream(stmt):
result = dict(row._mapping)
result.pop("doc_id", None)
if "extra" in result:
result.update(result.pop("extra"))
for k, v in list(result.items()):
if isinstance(v, datetime) and v.tzinfo is None:
result[k] = v.replace(tzinfo=timezone.utc)
if v is None:
result.pop(k)
results.append(result)
) -> list[dict[Any, Any]]:
# Apply selection
if parameters:
columns = []
for p in parameters:
if p in self.fields:
columns.append(self._table.columns[p])
else:
columns.append(self._table.columns["extra"][p].label(p))
else:
columns = self._table.columns
stmt = select(*columns)
if distinct:
stmt = stmt.distinct()

# Apply filtering
stmt = sql_utils.apply_search_filters(
self._table.columns.__getitem__, stmt, search
)

# Apply sorting
stmt = sql_utils.apply_sort_constraints(
self._table.columns.__getitem__, stmt, sorts
)

# Apply pagination
if page is not None:
stmt = stmt.offset((page - 1) * per_page).limit(per_page)

results = []
async for row in await self._sql_db.conn.stream(stmt):
result = dict(row._mapping)
result.pop("doc_id", None)
if "extra" in result:
result.update(result.pop("extra"))
for k, v in list(result.items()):
if isinstance(v, datetime) and v.tzinfo is None:
result[k] = v.replace(tzinfo=timezone.utc)
if v is None:
result.pop(k)
results.append(result)
return results

async def ping(self):
Expand Down
18 changes: 14 additions & 4 deletions diracx-testing/src/diracx/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def configure(self, enabled_dependencies):
assert (
self.app.dependency_overrides == {} and self.app.lifetime_functions == []
), "configure cannot be nested"

for k, v in self.all_dependency_overrides.items():

class_name = k.__self__.__name__
Expand Down Expand Up @@ -284,17 +285,26 @@ async def create_db_schemas(self):
import sqlalchemy
from sqlalchemy.util.concurrency import greenlet_spawn

from diracx.db.os.utils import BaseOSDB
from diracx.db.sql.utils import BaseSQLDB
from diracx.testing.mock_osdb import MockOSDBMixin

for k, v in self.app.dependency_overrides.items():
# Ignore dependency overrides which aren't BaseSQLDB.transaction
if (
isinstance(v, UnavailableDependency)
or k.__func__ != BaseSQLDB.transaction.__func__
# Ignore dependency overrides which aren't BaseSQLDB.transaction or BaseOSDB.session
if isinstance(v, UnavailableDependency) or k.__func__ not in (
BaseSQLDB.transaction.__func__,
BaseOSDB.session.__func__,
):

continue

# The first argument of the overridden BaseSQLDB.transaction is the DB object
db = v.args[0]
# We expect the OS DB to be mocked with sqlite, so use the
# internal DB
if isinstance(db, MockOSDBMixin):
db = db._sql_db

assert isinstance(db, BaseSQLDB), (k, db)

# set PRAGMA foreign_keys=ON if sqlite
Expand Down

0 comments on commit 8c365fd

Please sign in to comment.