diff --git a/pyproject.toml b/pyproject.toml index 5dc2f7e4f..291503bff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,7 @@ module = [ "palace.manager.sqlalchemy.model.collection", "palace.manager.sqlalchemy.model.integration", "palace.manager.sqlalchemy.model.library", + "palace.manager.sqlalchemy.model.patron", "palace.manager.util.authentication_for_opds", "palace.manager.util.base64", "palace.manager.util.cache", diff --git a/src/palace/manager/api/annotations.py b/src/palace/manager/api/annotations.py index 022036e9c..32a1d5b88 100644 --- a/src/palace/manager/api/annotations.py +++ b/src/palace/manager/api/annotations.py @@ -11,6 +11,7 @@ ) from palace.manager.sqlalchemy.model.identifier import Identifier from palace.manager.sqlalchemy.model.patron import Annotation +from palace.manager.sqlalchemy.util import get_one_or_create from palace.manager.util.datetime_helpers import utc_now @@ -210,8 +211,9 @@ def parse(cls, _db, data, patron): # per target. extra_kwargs["target"] = target - annotation, ignore = Annotation.get_one_or_create( + annotation, ignore = get_one_or_create( _db, + Annotation, patron=patron, identifier=identifier, motivation=motivation, diff --git a/src/palace/manager/feed/acquisition.py b/src/palace/manager/feed/acquisition.py index f4db8b6c3..9450d201c 100644 --- a/src/palace/manager/feed/acquisition.py +++ b/src/palace/manager/feed/acquisition.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Iterable from typing import TYPE_CHECKING, Any from dependency_injector.wiring import Provide, inject @@ -56,7 +56,7 @@ def __init__( self, title: str, url: str, - works: list[Work], + works: Iterable[Work], annotator: CirculationManagerAnnotator, facets: FacetsWithEntryPoint | None = None, pagination: Pagination | None = None, diff --git a/src/palace/manager/sqlalchemy/model/collection.py b/src/palace/manager/sqlalchemy/model/collection.py index fa42d33fe..ca20612f1 100644 --- a/src/palace/manager/sqlalchemy/model/collection.py +++ b/src/palace/manager/sqlalchemy/model/collection.py @@ -315,7 +315,7 @@ def protocol(self, new_protocol: str) -> None: STANDARD_DEFAULT_LOAN_PERIOD = 21 def default_loan_period( - self, library: Library, medium: str = EditionConstants.BOOK_MEDIUM + self, library: Library | None, medium: str = EditionConstants.BOOK_MEDIUM ) -> int: """Until we hear otherwise from the license provider, we assume that someone who borrows a non-open-access item from this @@ -336,7 +336,7 @@ def loan_period_key(cls, medium: str = EditionConstants.BOOK_MEDIUM) -> str: def default_loan_period_setting( self, - library: Library, + library: Library | None, medium: str = EditionConstants.BOOK_MEDIUM, ) -> int | None: """Until we hear otherwise from the license provider, we assume diff --git a/src/palace/manager/sqlalchemy/model/edition.py b/src/palace/manager/sqlalchemy/model/edition.py index e44255bfe..002156e10 100644 --- a/src/palace/manager/sqlalchemy/model/edition.py +++ b/src/palace/manager/sqlalchemy/model/edition.py @@ -101,7 +101,7 @@ class Edition(Base, EditionConstants): # An Edition may be the presentation edition for many LicensePools. is_presentation_for: Mapped[list[LicensePool]] = relationship( - "LicensePool", backref="presentation_edition" + "LicensePool", back_populates="presentation_edition" ) title = Column(Unicode, index=True) diff --git a/src/palace/manager/sqlalchemy/model/licensing.py b/src/palace/manager/sqlalchemy/model/licensing.py index d0482feef..62a8c508f 100644 --- a/src/palace/manager/sqlalchemy/model/licensing.py +++ b/src/palace/manager/sqlalchemy/model/licensing.py @@ -32,8 +32,10 @@ if TYPE_CHECKING: from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.datasource import DataSource + from palace.manager.sqlalchemy.model.edition import Edition from palace.manager.sqlalchemy.model.identifier import Identifier from palace.manager.sqlalchemy.model.resource import Resource + from palace.manager.sqlalchemy.model.work import Work class PolicyException(BasePalaceException): @@ -199,6 +201,7 @@ class LicensePool(Base): # A LicensePool may be associated with a Work. (If it's not, no one # can check it out.) work_id = Column(Integer, ForeignKey("works.id"), index=True) + work: Mapped[Work] = relationship("Work", back_populates="license_pools") # Each LicensePool is associated with one DataSource and one # Identifier. @@ -224,6 +227,9 @@ class LicensePool(Base): # Each LicensePool has an Edition which contains the metadata used # to describe this book. presentation_edition_id = Column(Integer, ForeignKey("editions.id"), index=True) + presentation_edition: Mapped[Edition] = relationship( + "Edition", back_populates="is_presentation_for" + ) # If the source provides information about individual licenses, the # LicensePool may have many Licenses. diff --git a/src/palace/manager/sqlalchemy/model/patron.py b/src/palace/manager/sqlalchemy/model/patron.py index cfc9e446f..5300579e0 100644 --- a/src/palace/manager/sqlalchemy/model/patron.py +++ b/src/palace/manager/sqlalchemy/model/patron.py @@ -4,7 +4,8 @@ import datetime import logging import uuid -from typing import TYPE_CHECKING +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from psycopg2.extras import NumericRange from sqlalchemy import ( @@ -29,22 +30,28 @@ from palace.manager.sqlalchemy.hybrid import hybrid_property from palace.manager.sqlalchemy.model.base import Base from palace.manager.sqlalchemy.model.credential import Credential -from palace.manager.sqlalchemy.util import get_one_or_create, numericrange_to_tuple +from palace.manager.sqlalchemy.model.datasource import DataSource +from palace.manager.sqlalchemy.util import NumericRangeTuple, numericrange_to_tuple from palace.manager.util.datetime_helpers import utc_now if TYPE_CHECKING: from palace.manager.sqlalchemy.model.devicetokens import DeviceToken + from palace.manager.sqlalchemy.model.lane import Lane from palace.manager.sqlalchemy.model.library import Library from palace.manager.sqlalchemy.model.licensing import ( License, LicensePool, LicensePoolDeliveryMechanism, ) + from palace.manager.sqlalchemy.model.work import Work class LoanAndHoldMixin: + license_pool: LicensePool + patron: Patron + @property - def work(self): + def work(self) -> Work | None: """Try to find the corresponding work for this Loan/Hold.""" license_pool = self.license_pool if not license_pool: @@ -56,11 +63,11 @@ def work(self): return None @property - def library(self): + def library(self) -> Library | None: """Try to find the corresponding library for this Loan/Hold.""" if self.patron: return self.patron.library - # If this Loan/Hold belongs to a external patron, there may be no library. + # If this Loan/Hold belongs to an external patron, there may be no library. return None @@ -151,9 +158,10 @@ class Patron(Base, RedisKeyMixin): # is never _unintentionally_ written to the database. It has to # be an explicit decision of the ILS integration code. cached_neighborhood = Column(Unicode, default=None, index=True) + neighborhood: str | None = None loans: Mapped[list[Loan]] = relationship( - "Loan", backref="patron", cascade="delete", uselist=True + "Loan", back_populates="patron", cascade="delete", uselist=True ) holds: Mapped[list[Hold]] = relationship( "Hold", @@ -165,7 +173,7 @@ class Patron(Base, RedisKeyMixin): annotations: Mapped[list[Annotation]] = relationship( "Annotation", - backref="patron", + back_populates="patron", order_by="desc(Annotation.timestamp)", cascade="delete", ) @@ -188,12 +196,8 @@ class Patron(Base, RedisKeyMixin): # than this time. MAX_SYNC_TIME = datetime.timedelta(hours=12) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.neighborhood: str | None = None - - def __repr__(self): - def date(d): + def __repr__(self) -> str: + def date(d: datetime.datetime | datetime.date | None) -> datetime.date | None: """Format an object that might be a datetime as a date. This keeps a patron representation short. @@ -210,7 +214,11 @@ def date(d): date(self.last_external_sync), ) - def identifier_to_remote_service(self, remote_data_source, generator=None): + def identifier_to_remote_service( + self, + remote_data_source: DataSource | str, + generator: Callable[[], str] | None = None, + ) -> str: """Find or randomly create an identifier to use when identifying this patron to a remote service. :param remote_data_source: A DataSource object (or name of a @@ -218,7 +226,7 @@ def identifier_to_remote_service(self, remote_data_source, generator=None): """ _db = Session.object_session(self) - def refresh(credential): + def refresh(credential: Credential) -> None: if generator and callable(generator): identifier = generator() else: @@ -233,26 +241,26 @@ def refresh(credential): refresh, allow_persistent_token=True, ) + # Any way that we create a credential should result in a result that does not + # have credential.credential set to None. Mypy doesn't know that, so we assert + # it here. + assert credential.credential is not None return credential.credential - def works_on_loan(self): - db = Session.object_session(self) - loans = db.query(Loan).filter(Loan.patron == self) + def works_on_loan(self) -> list[Work]: return [loan.work for loan in self.loans if loan.work] - def works_on_loan_or_on_hold(self): - db = Session.object_session(self) - results = set() + def works_on_loan_or_on_hold(self) -> set[Work]: holds = [hold.work for hold in self.holds if hold.work] loans = self.works_on_loan() return set(holds + loans) @hybrid_property - def synchronize_annotations(self): + def synchronize_annotations(self) -> bool | None: return self._synchronize_annotations @synchronize_annotations.setter - def synchronize_annotations(self, value): + def synchronize_annotations(self, value: bool | None) -> None: """When a patron says they don't want their annotations to be stored on a library server, delete all their annotations. """ @@ -268,7 +276,7 @@ def synchronize_annotations(self, value): self._synchronize_annotations = value @property - def root_lane(self): + def root_lane(self) -> Lane | None: """Find the Lane, if any, to be used as the Patron's root lane. A patron with a root Lane can only access that Lane and the @@ -293,7 +301,7 @@ def root_lane(self): .filter(Lane.root_for_patron_type.any(self.external_type)) .order_by(Lane.id) ) - lanes = qu.all() + lanes: list[Lane] = qu.all() if len(lanes) < 1: # The most common situation -- this patron has no special # root lane. @@ -307,7 +315,9 @@ def root_lane(self): ) return lanes[0] - def work_is_age_appropriate(self, work_audience, work_target_age): + def work_is_age_appropriate( + self, work_audience: str, work_target_age: int | tuple[int, int] + ) -> bool: """Is the given audience and target age an age-appropriate match for this Patron? NOTE: What "age-appropriate" means depends on some policy questions @@ -347,8 +357,12 @@ def work_is_age_appropriate(self, work_audience, work_target_age): @classmethod def age_appropriate_match( - cls, work_audience, work_target_age, reader_audience, reader_age - ): + cls, + work_audience: str, + work_target_age: NumericRange | NumericRangeTuple | float, + reader_audience: str | None, + reader_age: NumericRange | NumericRangeTuple | float, + ) -> bool: """Match the audience and target age of a work with that of a reader, and see whether they are an age-appropriate match. @@ -390,7 +404,9 @@ def age_appropriate_match( # At this point we know that the patron is a juvenile. - def ensure_tuple(x): + def ensure_tuple( + x: NumericRange | NumericRangeTuple | float, + ) -> NumericRangeTuple | float: # Convert a potential NumericRange into a tuple. if isinstance(x, NumericRange): x = numericrange_to_tuple(x) @@ -400,22 +416,26 @@ def ensure_tuple(x): if isinstance(reader_age, tuple): # A range was passed in rather than a specific age. Assume # the reader is at the top edge of the range. - ignore, reader_age = reader_age + _, reader_age_max = reader_age + else: + reader_age_max = reader_age work_target_age = ensure_tuple(work_target_age) if isinstance(work_target_age, tuple): # Pick the _bottom_ edge of a work's target age range -- # the work is appropriate for anyone _at least_ that old. - work_target_age, ignore = work_target_age + work_target_age_min, _ = work_target_age + else: + work_target_age_min = work_target_age # A YA reader is treated as an adult (with no reading # restrictions) if they have no associated age range, or their # age range includes ADULT_AGE_CUTOFF. if reader_audience == Classifier.AUDIENCE_YOUNG_ADULT and ( - reader_age is None + reader_age_max is None or ( - isinstance(reader_age, int) - and reader_age >= Classifier.ADULT_AGE_CUTOFF + isinstance(reader_age_max, int) + and reader_age_max >= Classifier.ADULT_AGE_CUTOFF ) ): log.debug("YA reader to be treated as an adult.") @@ -447,13 +467,13 @@ def ensure_tuple(x): # a child patron with a children's book. It comes down to a # question of the reader's age vs. the work's target age. - if work_target_age is None: + if work_target_age_min is None: # This is a generic children's or YA book with no # particular target age. Assume it's age appropriate. log.debug("Juvenile book with no target age is presumed age-appropriate.") return True - if reader_age is None: + if reader_age_max is None: # We have no idea how old the patron is, so any work with # the appropriate audience is considered age-appropriate. log.debug( @@ -461,7 +481,7 @@ def ensure_tuple(x): ) return True - if reader_age < work_target_age: + if reader_age_max < work_target_age_min: # The audience for this book matches the patron's # audience, but the book has a target age that is too high # for the reader. @@ -490,7 +510,7 @@ class Loan(Base, LoanAndHoldMixin): id = Column(Integer, primary_key=True) patron_id = Column(Integer, ForeignKey("patrons.id"), index=True) - patron: Patron # typing + patron: Mapped[Patron] = relationship("Patron", back_populates="loans") # A Loan is always associated with a LicensePool. license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True) @@ -516,10 +536,14 @@ class Loan(Base, LoanAndHoldMixin): __table_args__ = (UniqueConstraint("patron_id", "license_pool_id"),) - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Loan) or self.id is None or other.id is None: + return NotImplemented return self.id < other.id - def until(self, default_loan_period): + def until( + self, default_loan_period: datetime.timedelta | None + ) -> datetime.datetime | None: """Give or estimate the time at which the loan will end.""" if self.end: return self.end @@ -550,18 +574,19 @@ class Hold(Base, LoanAndHoldMixin): "Patron", back_populates="holds", lazy="joined" ) - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: + if not isinstance(other, Hold) or self.id is None or other.id is None: + return NotImplemented return self.id < other.id - @classmethod + @staticmethod def _calculate_until( - self, - start, - queue_position, - total_licenses, - default_loan_period, - default_reservation_period, - ): + start: datetime.datetime, + queue_position: int, + total_licenses: int, + default_loan_period: datetime.timedelta, + default_reservation_period: datetime.timedelta, + ) -> datetime.datetime | None: """Helper method for `Hold.until` that can be tested independently. We have to wait for the available licenses to cycle a certain number of times before we get a turn. @@ -585,7 +610,7 @@ def _calculate_until( return None # If you are at the very front of the queue, the worst case - # time to get the book is is the time it takes for the person + # time to get the book is the time it takes for the person # in front of you to get a reservation notification, borrow # the book at the last minute, and keep the book for the # maximum allowable time. @@ -607,7 +632,11 @@ def _calculate_until( cycles -= 1 return start + (cycle_period * cycles) - def until(self, default_loan_period, default_reservation_period): + def until( + self, + default_loan_period: datetime.timedelta | None, + default_reservation_period: datetime.timedelta | None, + ) -> datetime.datetime | None: """Give or estimate the time at which the book will be available to this patron. This is a *very* rough estimate that should be treated more or @@ -641,7 +670,12 @@ def until(self, default_loan_period, default_reservation_period): default_reservation_period, ) - def update(self, start, end, position): + def update( + self, + start: datetime.datetime | None, + end: datetime.datetime | None, + position: int | None, + ) -> None: """When the book becomes available, position will be 0 and end will be set to the time at which point the patron will lose their place in line. @@ -676,6 +710,8 @@ class Annotation(Base): __tablename__ = "annotations" id = Column(Integer, primary_key=True) patron_id = Column(Integer, ForeignKey("patrons.id"), index=True) + patron: Mapped[Patron] = relationship("Patron", back_populates="annotations") + identifier_id = Column(Integer, ForeignKey("identifiers.id"), index=True) motivation = Column(Unicode, index=True) timestamp = Column(DateTime(timezone=True), index=True) @@ -683,14 +719,7 @@ class Annotation(Base): content = Column(Unicode) target = Column(Unicode) - @classmethod - def get_one_or_create(self, _db, patron, *args, **kwargs): - """Find or create an Annotation, but only if the patron has - annotation sync turned on. - """ - return get_one_or_create(_db, Annotation, patron=patron, *args, **kwargs) - - def set_inactive(self): + def set_inactive(self) -> None: self.active = False self.content = None self.timestamp = utc_now() @@ -701,7 +730,7 @@ class PatronProfileStorage(ProfileStorage): Protocol. """ - def __init__(self, patron, url_for=None): + def __init__(self, patron: Patron, url_for: Callable[..., str]) -> None: """Set up a storage interface for a specific Patron. :param patron: We are accessing the profile for this patron. """ @@ -709,16 +738,16 @@ def __init__(self, patron, url_for=None): self.url_for = url_for @property - def writable_setting_names(self): + def writable_setting_names(self) -> set[str]: """Return the subset of settings that are considered writable.""" return {self.SYNCHRONIZE_ANNOTATIONS} @property - def profile_document(self): + def profile_document(self) -> dict[str, Any]: """Create a Profile document representing the patron's current status. """ - doc = dict() + doc: dict[str, Any] = dict() patron = self.patron doc[self.AUTHORIZATION_IDENTIFIER] = patron.authorization_identifier if patron.authorization_expires: @@ -740,7 +769,7 @@ def profile_document(self): ] return doc - def update(self, settable, full): + def update(self, settable: dict[str, Any], full: dict[str, Any]) -> None: """Bring the Patron's status up-to-date with the given document. Right now this means making sure Patron.synchronize_annotations is up to date. diff --git a/src/palace/manager/sqlalchemy/model/work.py b/src/palace/manager/sqlalchemy/model/work.py index 220798093..38bae8304 100644 --- a/src/palace/manager/sqlalchemy/model/work.py +++ b/src/palace/manager/sqlalchemy/model/work.py @@ -140,7 +140,7 @@ class Work(Base, LoggerMixin): # One Work may have copies scattered across many LicensePools. license_pools: Mapped[list[LicensePool]] = relationship( - "LicensePool", backref="work", lazy="joined", uselist=True + "LicensePool", back_populates="work", lazy="joined", uselist=True ) # A Work takes its presentation metadata from a single Edition. diff --git a/src/palace/manager/sqlalchemy/util.py b/src/palace/manager/sqlalchemy/util.py index a25d4772d..a09ce0b0d 100644 --- a/src/palace/manager/sqlalchemy/util.py +++ b/src/palace/manager/sqlalchemy/util.py @@ -137,7 +137,7 @@ def get_one_or_create( return db.query(model).filter_by(**kwargs).one(), False -def numericrange_to_string(r): +def numericrange_to_string(r: NumericRange | None) -> str: """Helper method to convert a NumericRange to a human-readable string.""" if not r: return "" @@ -145,10 +145,14 @@ def numericrange_to_string(r): upper = r.upper if upper is None and lower is None: return "" - if lower and upper is None: + if upper is None: return str(lower) - if upper and lower is None: + if lower is None: return str(upper) + # Currently this function only supports integer ranges, but NumericRange + # supports floats as well, so we assert that the values are integers, so + # this function fails if we ever start using floats. + assert isinstance(lower, int) and isinstance(upper, int) if not r.upper_inc: upper -= 1 if not r.lower_inc: @@ -158,7 +162,10 @@ def numericrange_to_string(r): return f"{lower}-{upper}" -def numericrange_to_tuple(r): +NumericRangeTuple = tuple[float | None, float | None] + + +def numericrange_to_tuple(r: NumericRange | None) -> NumericRangeTuple: """Helper method to normalize NumericRange into a tuple.""" if r is None: return (None, None) @@ -171,7 +178,7 @@ def numericrange_to_tuple(r): return lower, upper -def tuple_to_numericrange(t): +def tuple_to_numericrange(t: NumericRangeTuple | None) -> NumericRange | None: """Helper method to convert a tuple to an inclusive NumericRange.""" if not t: return None diff --git a/src/palace/manager/util/notifications.py b/src/palace/manager/util/notifications.py index aaad09a31..1661d31f4 100644 --- a/src/palace/manager/util/notifications.py +++ b/src/palace/manager/util/notifications.py @@ -13,7 +13,6 @@ from palace.manager.sqlalchemy.model.devicetokens import DeviceToken, DeviceTokenTypes from palace.manager.sqlalchemy.model.identifier import Identifier from palace.manager.sqlalchemy.model.patron import Hold, Loan, Patron -from palace.manager.sqlalchemy.model.work import Work from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.log import LoggerMixin @@ -114,15 +113,19 @@ def send_loan_expiry_message( url = self.base_url edition = loan.license_pool.presentation_edition identifier = loan.license_pool.identifier - library_short_name = loan.library.short_name - library_name = loan.library.name + library = loan.library + # It shouldn't be possible to get here for a loan without a library, but for mypy + # and safety we will assert it anyway + assert library is not None + library_short_name = library.short_name + library_name = library.name title = f"Only {days_to_expiry} {'days' if days_to_expiry != 1 else 'day'} left on your loan!" body = f'Your loan for "{edition.title}" at {library_name} is expiring soon' data = dict( title=title, body=body, event_type=NotificationType.LOAN_EXPIRY, - loans_endpoint=f"{url}/{loan.library.short_name}/loans", + loans_endpoint=f"{url}/{library.short_name}/loans", type=identifier.type, identifier=identifier.identifier, library=library_short_name, @@ -156,16 +159,16 @@ def send_holds_notifications(self, holds: list[Hold]) -> list[str]: for hold in holds: try: tokens = self.notifiable_tokens(hold.patron) + work_title = hold.work.title # type: ignore[union-attr] self.log.info( f"Notifying patron {hold.patron.authorization_identifier or hold.patron.username} for " - f"hold: {hold.work.title}. Patron has {len(tokens)} device tokens." + f"hold: {work_title}. Patron has {len(tokens)} device tokens." ) loans_api = f"{url}/{hold.patron.library.short_name}/loans" - work: Work = hold.work identifier: Identifier = hold.license_pool.identifier library_name = hold.patron.library.name title = "Your hold is available!" - body = f'Your hold on "{work.title}" is available at {library_name}!' + body = f'Your hold on "{work_title}" is available at {library_name}!' data = dict( title=title, body=body, diff --git a/tests/manager/api/controller/test_annotation.py b/tests/manager/api/controller/test_annotation.py index d50d24465..fc9f6732b 100644 --- a/tests/manager/api/controller/test_annotation.py +++ b/tests/manager/api/controller/test_annotation.py @@ -166,7 +166,9 @@ def test_post_to_container(self, annotation_fixture: AnnotationFixture): patron.synchronize_annotations = True # The patron doesn't have any annotations yet. annotations = ( - annotation_fixture.db.session.query(Annotation).filter(Annotation.patron == patron).all() # type: ignore + annotation_fixture.db.session.query(Annotation) + .filter(Annotation.patron == patron) + .all() ) assert 0 == len(annotations) @@ -175,7 +177,9 @@ def test_post_to_container(self, annotation_fixture: AnnotationFixture): # The patron doesn't have the pool on loan yet, so the request fails. assert 400 == response.status_code annotations = ( - annotation_fixture.db.session.query(Annotation).filter(Annotation.patron == patron).all() # type: ignore + annotation_fixture.db.session.query(Annotation) + .filter(Annotation.patron == patron) + .all() ) assert 0 == len(annotations) @@ -185,7 +189,9 @@ def test_post_to_container(self, annotation_fixture: AnnotationFixture): assert 200 == response.status_code annotations = ( - annotation_fixture.db.session.query(Annotation).filter(Annotation.patron == patron).all() # type: ignore + annotation_fixture.db.session.query(Annotation) + .filter(Annotation.patron == patron) + .all() ) assert 1 == len(annotations) annotation = annotations[0] diff --git a/tests/manager/api/controller/test_profile.py b/tests/manager/api/controller/test_profile.py index 726560936..72d00a8ad 100644 --- a/tests/manager/api/controller/test_profile.py +++ b/tests/manager/api/controller/test_profile.py @@ -7,6 +7,7 @@ from palace.manager.api.authenticator import CirculationPatronProfileStorage from palace.manager.core.user_profile import ProfileController, ProfileStorage from palace.manager.sqlalchemy.model.patron import Annotation, Patron +from palace.manager.sqlalchemy.util import get_one_or_create from palace.manager.util.problem_detail import ProblemDetail from tests.fixtures.api_controller import ControllerFixture from tests.fixtures.database import DatabaseTransactionFixture @@ -94,8 +95,11 @@ def test_put(self, profile_fixture: ProfileFixture): # Now we can create an annotation for the patron who enabled # annotation sync. - Annotation.get_one_or_create( # type: ignore[unreachable] - profile_fixture.db.session, patron=request_patron, identifier=identifier + get_one_or_create( # type: ignore[unreachable] + profile_fixture.db.session, + Annotation, + patron=request_patron, + identifier=identifier, ) assert 1 == len(request_patron.annotations) diff --git a/tests/manager/api/test_monitor.py b/tests/manager/api/test_monitor.py index eac28b0d5..ba4086759 100644 --- a/tests/manager/api/test_monitor.py +++ b/tests/manager/api/test_monitor.py @@ -10,6 +10,7 @@ from palace.manager.api.opds_for_distributors import OPDSForDistributorsAPI from palace.manager.sqlalchemy.model.datasource import DataSource from palace.manager.sqlalchemy.model.patron import Annotation +from palace.manager.sqlalchemy.util import get_one_or_create from palace.manager.util.datetime_helpers import utc_now from tests.fixtures.database import DatabaseTransactionFixture @@ -199,8 +200,9 @@ def test_where_clause(self, db: DatabaseTransactionFixture): def _annotation( patron, pool, content, motivation=Annotation.IDLING, timestamp=very_old ): - annotation, ignore = Annotation.get_one_or_create( + annotation, _ = get_one_or_create( db.session, + Annotation, patron=patron, identifier=pool.identifier, motivation=motivation, diff --git a/tests/manager/api/test_opds_for_distributors.py b/tests/manager/api/test_opds_for_distributors.py index eee88afa2..80bd546ad 100644 --- a/tests/manager/api/test_opds_for_distributors.py +++ b/tests/manager/api/test_opds_for_distributors.py @@ -550,6 +550,7 @@ def test_import(self, opds_dist_api_fixture: OPDSForDistributorsAPIFixture): ) assert LicensePool.UNLIMITED_ACCESS == pool.licenses_owned assert LicensePool.UNLIMITED_ACCESS == pool.licenses_available + assert pool.work.last_update_time is not None assert (pool.work.last_update_time - now).total_seconds() <= 2 assert pool.should_track_playtime == False diff --git a/tests/manager/core/test_opds_import.py b/tests/manager/core/test_opds_import.py index aac53eef5..7e670d7e2 100644 --- a/tests/manager/core/test_opds_import.py +++ b/tests/manager/core/test_opds_import.py @@ -870,9 +870,11 @@ def test_import(self, opds_importer_fixture: OPDSImporterFixture): classifier = Classifier.classifiers.get(seven.subject.type, None) classifier.classify(seven.subject) - [crow_pool, mouse_pool] = sorted( - pools, key=lambda x: x.presentation_edition.title - ) + def sort_key(x: LicensePool) -> str: + assert x.presentation_edition.title is not None + return x.presentation_edition.title + + [crow_pool, mouse_pool] = sorted(pools, key=sort_key) assert db.default_collection() == crow_pool.collection assert db.default_collection() == mouse_pool.collection @@ -881,6 +883,7 @@ def test_import(self, opds_importer_fixture: OPDSImporterFixture): work = mouse_pool.work work.calculate_presentation() + assert work.quality is not None assert 0.4142 == round(work.quality, 4) assert Classifier.AUDIENCE_CHILDREN == work.audience assert NumericRange(7, 7, "[]") == work.target_age diff --git a/tests/manager/sqlalchemy/model/test_patron.py b/tests/manager/sqlalchemy/model/test_patron.py index 506f0f595..00e21ae80 100644 --- a/tests/manager/sqlalchemy/model/test_patron.py +++ b/tests/manager/sqlalchemy/model/test_patron.py @@ -1,7 +1,8 @@ import datetime -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock, call, patch import pytest +from freezegun import freeze_time from palace.manager.core.classifier import Classifier from palace.manager.sqlalchemy.constants import LinkRelations @@ -15,7 +16,11 @@ Patron, PatronProfileStorage, ) -from palace.manager.sqlalchemy.util import create, tuple_to_numericrange +from palace.manager.sqlalchemy.util import ( + create, + get_one_or_create, + tuple_to_numericrange, +) from palace.manager.util.datetime_helpers import datetime_utc, utc_now from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.library import LibraryFixture @@ -118,6 +123,7 @@ def test_work(self, db: DatabaseTransactionFixture): hold, is_new = pool.on_hold_to(patron) assert work == hold.work + @freeze_time() def test_until(self, db: DatabaseTransactionFixture): one_day = datetime.timedelta(days=1) two_days = datetime.timedelta(days=2) @@ -153,44 +159,27 @@ def test_until(self, db: DatabaseTransactionFixture): assert None == m(one_day, None) # Otherwise, the answer is determined by _calculate_until. - def _mock__calculate_until(self, *args): - """Track the arguments passed into _calculate_until.""" - self.called_with = args - return "mock until" - - old__calculate_until = hold._calculate_until - Hold._calculate_until = _mock__calculate_until - - assert "mock until" == m(one_day, two_days) - - ( - calculate_from, - position, - licenses_available, - default_loan_period, - default_reservation_period, - ) = hold.called_with - - assert (calculate_from - now).total_seconds() < 5 - assert hold.position == position - assert pool.licenses_available == licenses_available - assert one_day == default_loan_period - assert two_days == default_reservation_period - - # If we don't know the patron's position in the hold queue, we - # assume they're at the end. - hold.position = None - assert "mock until" == m(one_day, two_days) - ( - calculate_from, - position, - licenses_available, - default_loan_period, - default_reservation_period, - ) = hold.called_with - assert pool.patrons_in_hold_queue == position - - Hold._calculate_until = old__calculate_until + with patch.object(Hold, "_calculate_until") as _mock_calculate_until: + _mock_calculate_until.return_value = "mock until" + + assert "mock until" == m(one_day, two_days) + _mock_calculate_until.assert_called_once_with( + now, hold.position, pool.licenses_available, one_day, two_days + ) + + _mock_calculate_until.reset_mock() + + # If we don't know the patron's position in the hold queue, we + # assume they're at the end. + hold.position = None + assert "mock until" == m(one_day, two_days) + _mock_calculate_until.assert_called_once_with( + now, + pool.patrons_in_hold_queue, + pool.licenses_available, + one_day, + two_days, + ) def test_calculate_until(self): start = datetime_utc(2010, 1, 1) @@ -430,8 +419,9 @@ def test_set_synchronize_annotations(self, db: DatabaseTransactionFixture): patron.synchronize_annotations = True # Each patron gets one annotation. - annotation, ignore = Annotation.get_one_or_create( + annotation, ignore = get_one_or_create( db.session, + Annotation, patron=patron, identifier=identifier, motivation=Annotation.IDLING, @@ -448,8 +438,9 @@ def test_set_synchronize_annotations(self, db: DatabaseTransactionFixture): assert 0 == len(p1.annotations) # But patron #2 can use Annotation.get_one_or_create. - i2, is_new = Annotation.get_one_or_create( + i2, is_new = get_one_or_create( db.session, + Annotation, patron=p2, identifier=db.identifier(), motivation=Annotation.IDLING, @@ -570,8 +561,8 @@ def mock_age_appropriate( # If the patron has no root lane, age_appropriate_match is not # even called -- all works are age-appropriate. m = patron.work_is_age_appropriate - work_audience = object() - work_target_age = object() + work_audience = MagicMock() + work_target_age = MagicMock() assert True == m(work_audience, work_target_age) assert 0 == mock.call_count diff --git a/tests/manager/sqlalchemy/test_util.py b/tests/manager/sqlalchemy/test_util.py index 17f78bcc7..6ea152894 100644 --- a/tests/manager/sqlalchemy/test_util.py +++ b/tests/manager/sqlalchemy/test_util.py @@ -6,6 +6,7 @@ from palace.manager.sqlalchemy.model.edition import Edition from palace.manager.sqlalchemy.util import ( get_one, + numericrange_to_string, numericrange_to_tuple, pg_advisory_lock, tuple_to_numericrange, @@ -103,3 +104,9 @@ def test_exception_case(self, db: DatabaseTransactionFixture): def test_no_lock_id(self, db: DatabaseTransactionFixture): with pg_advisory_lock(db.session, None): assert self._lock_exists(db.session, self.TEST_LOCK_ID) is False + + +class TestNumericRangeToString: + def test_numericrange_to_string_float(self): + with pytest.raises(AssertionError): + numericrange_to_string(NumericRange(1.1, 1.8, "[]")) diff --git a/tests/manager/util/test_notifications.py b/tests/manager/util/test_notifications.py index 44694e256..d07a06139 100644 --- a/tests/manager/util/test_notifications.py +++ b/tests/manager/util/test_notifications.py @@ -95,6 +95,9 @@ def test_send_loan_notification(self, push_notf_fixture: PushNotificationsFixtur loan, 1, [device_token] ) + library = loan.library + assert library is not None + assert messaging.Message.call_count == 1 assert messaging.Message.call_args_list[0] == [ (), @@ -106,14 +109,14 @@ def test_send_loan_notification(self, push_notf_fixture: PushNotificationsFixtur ), "data": dict( title="Only 1 day left on your loan!", - body=f'Your loan for "{work.presentation_edition.title}" at {loan.library.name} is expiring soon', + body=f'Your loan for "{work.presentation_edition.title}" at {library.name} is expiring soon', event_type=NotificationType.LOAN_EXPIRY, loans_endpoint="http://localhost/default/loans", external_identifier=patron.external_identifier, authorization_identifier=patron.authorization_identifier, identifier=work.presentation_edition.primary_identifier.identifier, type=work.presentation_edition.primary_identifier.type, - library=loan.library.short_name, + library=library.short_name, days_to_expiry="1", ), }, @@ -223,6 +226,7 @@ def assert_message_call( hold: Hold, include_auth_id: bool = True, ) -> None: + assert hold.library is not None data = dict( title="Your hold is available!", body=f'Your hold on "{work.title}" is available at {hold.library.name}!',