Minimal updates to db import for tests to pass.
jgadling committed Oct 3, 2024
1 parent ca6612e commit f38cfaf
Showing 7 changed files with 195 additions and 21 deletions.
16 changes: 15 additions & 1 deletion apiv2/db_import/importer.py
@@ -8,6 +8,7 @@
from db_import.importers.annotation import (
AnnotationAuthorDBImporter,
AnnotationDBImporter,
AnnotationMethodLinkDBImporter,
StaleAnnotationDeletionDBImporter,
)
from db_import.importers.base_importer import DBImportConfig
@@ -33,6 +34,7 @@ def db_import_options(func):
options = []
options.append(click.option("--import-annotations", is_flag=True, default=False))
options.append(click.option("--import-annotation-authors", is_flag=True, default=False))
options.append(click.option("--import-annotation-method-links", is_flag=True, default=False))
options.append(click.option("--import-dataset-authors", is_flag=True, default=False))
options.append(click.option("--import-dataset-funding", is_flag=True, default=False))
options.append(click.option("--import-depositions", is_flag=True, default=False))
@@ -75,6 +77,7 @@ def load(
filter_dataset: list[str],
import_annotations: bool,
import_annotation_authors: bool,
import_annotation_method_links: bool,
import_dataset_authors: bool,
import_dataset_funding: bool,
import_depositions: bool,
@@ -96,6 +99,7 @@ def load(
filter_dataset,
import_annotations,
import_annotation_authors,
import_annotation_method_links,
import_dataset_authors,
import_dataset_funding,
import_depositions,
@@ -119,6 +123,7 @@ def load_func(
filter_dataset: list[str] | None = None,
import_annotations: bool = False,
import_annotation_authors: bool = False,
import_annotation_method_links: bool = False,
import_dataset_authors: bool = False,
import_dataset_funding: bool = False,
import_depositions: bool = False,
@@ -139,6 +144,7 @@ def load_func(
if import_everything:
import_annotations = True
import_annotation_authors = True
import_annotation_method_links = True
import_dataset_authors = True
import_dataset_funding = True
import_depositions = True
@@ -148,7 +154,7 @@ def load_func(
import_tomogram_authors = True
import_tomogram_voxel_spacing = True
else:
import_annotations = max(import_annotations, import_annotation_authors)
import_annotations = max(import_annotations, import_annotation_authors, import_annotation_method_links)
import_tomograms = max(import_tomograms, import_tomogram_authors)
import_tomogram_voxel_spacing = max(import_annotations, import_tomograms, import_tomogram_voxel_spacing)
import_runs = max(import_runs, import_tiltseries, import_tomogram_voxel_spacing)
@@ -209,6 +215,7 @@ def load_func(

if import_tomograms:
tomogram_cleaner = StaleTomogramDeletionDBImporter(voxel_spacing_obj.id, config)
TomogramDBImporter.load_deposition_map(config)
for tomogram in TomogramDBImporter.get_item(voxel_spacing_obj.id, run_id, voxel_spacing, config):
tomogram_obj = tomogram.import_to_db()
tomogram_cleaner.mark_as_active(tomogram_obj)
@@ -231,6 +238,13 @@ def load_func(
config,
)
annotation_authors.import_to_db()
if import_annotation_method_links:
anno_method_links = AnnotationMethodLinkDBImporter.get_item(
annotation_obj.id,
annotation,
config,
)
anno_method_links.import_to_db()
annotation_cleaner.remove_stale_objects()

voxel_spacing_cleaner.mark_as_active(voxel_spacing_obj)
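The new --import-annotation-method-links flag participates in the same dependency cascade as the other child-entity flags: requesting method links turns on annotations, which in turn pull in tomogram voxel spacings and runs. A minimal standalone sketch of that cascade (resolve_import_flags is a hypothetical helper written for illustration, not part of the repo):

def resolve_import_flags(
    import_annotations: bool = False,
    import_annotation_authors: bool = False,
    import_annotation_method_links: bool = False,
    import_tomogram_voxel_spacing: bool = False,
    import_runs: bool = False,
) -> dict[str, bool]:
    # A child flag switches on every parent it depends on, mirroring the
    # max(...) chain in load_func above.
    import_annotations = max(import_annotations, import_annotation_authors, import_annotation_method_links)
    import_tomogram_voxel_spacing = max(import_annotations, import_tomogram_voxel_spacing)
    import_runs = max(import_runs, import_tomogram_voxel_spacing)
    return {
        "annotations": import_annotations,
        "annotation_method_links": import_annotation_method_links,
        "tomogram_voxel_spacings": import_tomogram_voxel_spacing,
        "runs": import_runs,
    }

if __name__ == "__main__":
    # Requesting only method links still enables annotations, voxel spacings, and runs.
    print(resolve_import_flags(import_annotation_method_links=True))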
39 changes: 36 additions & 3 deletions apiv2/db_import/importers/annotation.py
@@ -1,4 +1,3 @@
import json
from typing import Any, Iterator

from database import models
@@ -34,7 +33,6 @@ def __init__(

def get_data_map(self) -> dict[str, Any]:
deposition = get_deposition(self.config, self.metadata.get("deposition_id"))
method_links = self.metadata.get("method_links")
return {
"s3_metadata_path": self.join_path(self.config.s3_prefix, self.metadata_path),
"https_metadata_path": self.join_path(self.config.https_prefix, self.metadata_path),
@@ -56,7 +54,6 @@ def get_data_map(self) -> dict[str, Any]:
"is_curator_recommended": ["is_curator_recommended"],
"method_type": ["method_type"],
"deposition_id": deposition.id,
"method_links": json.dumps(method_links) if method_links else None,
}

def import_to_db(self) -> Base:
@@ -227,6 +224,42 @@ def get_item(
return cls(annotation_id, parent, config)


class AnnotationMethodLinkDBImporter(StaleDeletionDBImporter):
def __init__(self, annotation_id: int, parent: AnnotationDBImporter, config: DBImportConfig):
self.annotation_id = annotation_id
self.parent = parent
self.config = config
self.metadata = parent.metadata.get("method_links", [])

def get_data_map(self) -> dict[str, Any]:
return {
"annotation_id": self.annotation_id,
"link_type": ["link_type"],
"name": ["custom_name"],
"link": ["link"],
}

@classmethod
def get_id_fields(cls) -> list[str]:
return ["annotation_id", "link"]

@classmethod
def get_db_model_class(cls) -> type[Base]:
return models.AnnotationMethodLink

def get_filters(self) -> dict[str, Any]:
return {"annotation_id": self.annotation_id}

@classmethod
def get_item(
cls,
annotation_id: int,
parent: AnnotationDBImporter,
config: DBImportConfig,
) -> "AnnotationAuthorDBImporter":
return cls(annotation_id, parent, config)


class StaleAnnotationDeletionDBImporter(StaleParentDeletionDBImporter):
ref_klass = AnnotationDBImporter

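AnnotationMethodLinkDBImporter reads the annotation metadata's method_links list (previously serialized into the annotation row as JSON, per the removed lines above) and writes one AnnotationMethodLink row per entry, identified by ("annotation_id", "link"). A rough illustration of how its get_data_map output could be resolved against a single method_links entry, assuming list values are key paths into the entry and scalars pass through unchanged (resolve_data_map is a hypothetical stand-in for the base importer's mapping logic):

from typing import Any

def resolve_data_map(data_map: dict[str, Any], entry: dict[str, Any]) -> dict[str, Any]:
    resolved = {}
    for field, source in data_map.items():
        if isinstance(source, list):
            # Walk the key path into the metadata entry.
            value: Any = entry
            for key in source:
                value = value.get(key) if isinstance(value, dict) else None
            resolved[field] = value
        else:
            # Precomputed scalars (e.g. annotation_id) are passed through unchanged.
            resolved[field] = source
    return resolved

entry = {"link_type": "documentation", "custom_name": "Method Documentation", "link": "https://example.com/method.pdf"}
data_map = {"annotation_id": 602, "link_type": ["link_type"], "name": ["custom_name"], "link": ["link"]}
print(resolve_data_map(data_map, entry))
# {'annotation_id': 602, 'link_type': 'documentation', 'name': 'Method Documentation', 'link': 'https://example.com/method.pdf'}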
18 changes: 17 additions & 1 deletion apiv2/db_import/importers/tomogram.py
@@ -1,6 +1,7 @@
import json
from typing import Any, Iterator

import sqlalchemy as sa
from database import models
from db_import.common.normalize_fields import normalize_fiducial_alignment
from db_import.importers.base_importer import (
@@ -16,6 +17,7 @@


class TomogramDBImporter(BaseDBImporter):
deposition_map = {}
def __init__(
self,
voxel_spacing_id: int,
@@ -37,6 +39,12 @@ def get_metadata_file_path(self) -> str:
def get_data_map(self) -> dict[str, Any]:
return {**self.get_direct_mapped_fields(), **self.get_computed_fields()}

@classmethod
def load_deposition_map(cls, config) -> None:
session = config.get_db_session()
for item in session.scalars(sa.select(models.Deposition)).all():
cls.deposition_map[item.id] = item

@classmethod
def get_id_fields(cls) -> list[str]:
return ["name", "tomogram_voxel_spacing_id"]
@@ -60,6 +68,9 @@ def get_direct_mapped_fields(cls) -> dict[str, Any]:
"offset_y": ["offset", "y"],
"offset_z": ["offset", "z"],
"deposition_id": ["deposition_id"],
"deposition_date": ["deposition_date"],
"release_date": ["release_date"],
"last_modified_date": ["last_modified_date"],
}

def normalize_to_unknown_str(self, value: str) -> str:
@@ -95,8 +106,13 @@ def get_computed_fields(self) -> dict[str, Any]:
"key_photo_thumbnail_url": None,
"neuroglancer_config": self.generate_neuroglancer_data(),
"type": self.get_tomogram_type(),
"is_standardized": self.metadata.get("is_standardized") or False,
"is_portal_standard": self.metadata.get("is_standardized") or False,
}
date_fields = ["deposition_date", "release_date", "last_modified_date"]
if not self.metadata.get("deposition_date"):
deposition = self.deposition_map[self.metadata["deposition_id"]]
for date_field in date_fields:
extra_data[date_field] = getattr(deposition, date_field)
if key_photos := self.metadata.get("key_photo"):
extra_data["key_photo_url"] = self.join_path(https_prefix, key_photos.get("snapshot"))
extra_data["key_photo_thumbnail_url"] = self.join_path(https_prefix, key_photos.get("thumbnail"))
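The tomogram importer now preloads every Deposition row into a class-level deposition_map, and get_computed_fields falls back to the deposition's dates when the tomogram metadata carries no deposition_date of its own. A self-contained sketch of that fallback using plain dataclasses rather than the SQLAlchemy models:

from dataclasses import dataclass
from datetime import date

@dataclass
class Deposition:
    id: int
    deposition_date: date
    release_date: date
    last_modified_date: date

# Stand-in for TomogramDBImporter.deposition_map after load_deposition_map(config).
deposition_map = {300: Deposition(300, date(2022, 4, 2), date(2022, 6, 1), date(2022, 9, 2))}

def fill_dates(metadata: dict, extra_data: dict) -> dict:
    date_fields = ["deposition_date", "release_date", "last_modified_date"]
    if not metadata.get("deposition_date"):
        # The tomogram metadata has no dates of its own: copy them from the parent deposition.
        deposition = deposition_map[metadata["deposition_id"]]
        for date_field in date_fields:
            extra_data[date_field] = getattr(deposition, date_field)
    return extra_data

print(fill_dates({"deposition_id": 300}, {}))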
58 changes: 51 additions & 7 deletions apiv2/db_import/tests/populate_db.py
@@ -1,10 +1,11 @@
from datetime import datetime
from datetime import date, datetime

import sqlalchemy as sa
from database.models import (
Annotation,
AnnotationAuthor,
AnnotationFile,
AnnotationMethodLink,
AnnotationShape,
Dataset,
DatasetAuthor,
@@ -34,6 +35,7 @@
ANNOTATION_ID = 602
ANNOTATION_FILE_ID = 701
ANNOTATION_AUTHOR_ID = 702
ANNOTATION_METHOD_LINK_ID = 802

STALE_RUN_ID = 902
STALE_TOMOGRAM_ID = 903
@@ -49,9 +51,9 @@ def stale_deposition_metadata() -> dict:
"id": DEPOSITION_ID1,
"title": "Test Deposition",
"description": "Test Description",
"deposition_date": datetime.now().date(),
"release_date": datetime.now().date(),
"last_modified_date": datetime.now().date(),
"deposition_date": date(2022, 4, 2),
"release_date": date(2022, 6, 1),
"last_modified_date": date(2022, 9, 2),
"deposition_publications": "Publications",
}

@@ -233,7 +235,10 @@ def populate_stale_tomogram_voxel_spacing(session: sa.orm.Session, run_id: int =
offset_x=0,
offset_y=0,
offset_z=0,
is_standardized=True,
is_portal_standard=True,
deposition_date=datetime.min,
release_date=datetime.min,
last_modified_date=datetime.min,
)
session.add(stale_tomogram)
session.add(TomogramAuthor(tomogram=stale_tomogram, name="Jane Smith", author_list_order=1))
@@ -284,7 +289,10 @@ def populate_tomograms(session: sa.orm.Session) -> Tomogram:
offset_x=0,
offset_y=0,
offset_z=0,
is_standardized=True,
is_portal_standard=True,
deposition_date=datetime.min,
release_date=datetime.min,
last_modified_date=datetime.min,
)


@@ -314,7 +322,10 @@ def populate_stale_tomograms(session: sa.orm.Session) -> Tomogram:
offset_x=0,
offset_y=0,
offset_z=0,
is_standardized=True,
is_portal_standard=True,
deposition_date=datetime.min,
release_date=datetime.min,
last_modified_date=datetime.min,
)


@@ -530,6 +541,39 @@ def populate_annotation_authors(session: sa.orm.Session) -> None:
session.add(author2)


@write_data
def populate_stale_annotation_method_links(session: sa.orm.Session) -> None:
populate_stale_annotations(session)
session.add(AnnotationMethodLink(annotation_id=STALE_ANNOTATION_ID, name="Stale Link 0", link_type="other", link="https://some-link.com"))
session.add(
AnnotationMethodLink(
annotation_id=STALE_ANNOTATION_ID,
name="Stale link",
link_type="source_code",
link="https://stale-link.com",
),
)

@write_data
def populate_annotation_method_links(session: sa.orm.Session) -> None:
populate_annotations(session)
row = AnnotationMethodLink(
id=ANNOTATION_METHOD_LINK_ID,
annotation_id=ANNOTATION_ID,
link="https://fake-link.com/resources/100-foo-1.0_method.pdf",
link_type="documentation",
name="Method Documentation",
)
session.add(row)
row2 = AnnotationMethodLink(
annotation_id=ANNOTATION_ID,
link="https://another-link.com",
link_type="website",
name="Stale Link",
)
session.add(row2)


@write_data
def populate_stale_annotation_authors(session: sa.orm.Session) -> None:
populate_stale_annotations(session)
65 changes: 65 additions & 0 deletions apiv2/db_import/tests/test_db_annotation_import.py
@@ -7,13 +7,16 @@
ANNOTATION_AUTHOR_ID,
ANNOTATION_FILE_ID,
ANNOTATION_ID,
ANNOTATION_METHOD_LINK_ID,
DATASET_ID,
RUN1_ID,
TOMOGRAM_VOXEL_ID1,
populate_annotation_authors,
populate_annotation_files,
populate_annotation_method_links,
populate_stale_annotation_authors,
populate_stale_annotation_files,
populate_stale_annotation_method_links,
)
from sqlalchemy.orm import Session

@@ -123,6 +126,25 @@ def expected_annotation_authors() -> list[dict[str, Any]]:
]


@pytest.fixture
def expected_annotation_method_links() -> list[dict[str, Any]]:
return [
{
"id": ANNOTATION_METHOD_LINK_ID,
"annotation_id": ANNOTATION_ID,
"link": "https://fake-link.com/resources/100-foo-1.0_method.pdf",
"link_type": "documentation",
"name": "Method Documentation",
},
{
"annotation_id": ANNOTATION_ID,
"link": "https://another-link.com/100-foo-1.0_code.zip",
"link_type": "source_code",
"name": "Source Code",
},
]


# Tests addition and update of annotations and annotation files
def test_import_annotations(
sync_db_session: Session,
@@ -207,3 +229,46 @@ def test_import_annotation_authors_removes_stale(
assert len(annotation.authors) == len(expected_annotation_authors)
for author in annotation.authors.order_by(models.AnnotationAuthor.author_list_order):
verify_model(author, next(expected_annotations_authors_iter))


# Tests update of existing annotation method links, addition of new method links
def test_import_annotation_method_links(
sync_db_session: Session,
verify_dataset_import: Callable[[list[str]], models.Dataset],
verify_model: Callable[[Base, dict[str, Any]], None],
expected_annotations: list[dict[str, Any]],
expected_annotation_method_links: list[dict[str, Any]],
) -> None:
populate_annotation_method_links(sync_db_session)
sync_db_session.commit()
verify_dataset_import(import_annotation_method_links=True)
expected_iter = iter(expected_annotation_method_links)
actual_runs = sync_db_session.get(models.Run, RUN1_ID)
for annotation in sorted(actual_runs.annotations, key=lambda x: x.s3_metadata_path):
assert len(annotation.method_links) == len(expected_annotation_method_links)
# for item in annotation.method_links.order_by(models.AnnotationMethodLink.link):
for item in sorted(annotation.method_links, key=lambda x: x.link):
verify_model(item, next(expected_iter))


# Tests deletion of stale annotation and annotation method links
def test_import_annotation_method_links_removes_stale(
sync_db_session: Session,
verify_dataset_import: Callable[[list[str]], models.Dataset],
verify_model: Callable[[Base, dict[str, Any]], None],
expected_annotations: list[dict[str, Any]],
expected_annotation_method_links: list[dict[str, Any]],
) -> None:
populate_annotation_method_links(sync_db_session)
populate_stale_annotation_method_links(sync_db_session)
sync_db_session.commit()
verify_dataset_import(import_annotation_method_links=True)
expected_iter = iter(expected_annotation_method_links)
actual_runs = sync_db_session.get(models.Run, RUN1_ID)
for annotation in sorted(actual_runs.annotations, key=lambda x: x.s3_metadata_path):
if annotation.id != ANNOTATION_ID:
continue
assert len(annotation.method_links) == len(expected_annotation_method_links)
# for item in annotation.method_links.order_by(models.AnnotationMethodLink.link):
for item in sorted(annotation.method_links, key=lambda x: x.id, reverse=True):
verify_model(item, next(expected_iter))
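The removes-stale test relies on the importer deleting any existing AnnotationMethodLink whose identifying fields ("annotation_id", "link") no longer appear in the imported metadata. A toy sketch of that bookkeeping, independent of the actual StaleDeletionDBImporter API:

existing_rows = [
    {"annotation_id": 602, "link": "https://fake-link.com/resources/100-foo-1.0_method.pdf"},
    {"annotation_id": 602, "link": "https://another-link.com"},  # absent from metadata -> stale
]
imported_keys = {
    (602, "https://fake-link.com/resources/100-foo-1.0_method.pdf"),
    (602, "https://another-link.com/100-foo-1.0_code.zip"),
}
kept = [row for row in existing_rows if (row["annotation_id"], row["link"]) in imported_keys]
stale = [row for row in existing_rows if (row["annotation_id"], row["link"]) not in imported_keys]
print(kept)   # rows whose key matches an imported link survive
print(stale)  # unmatched rows would be removed by the stale cleanup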