From faea04df1ac5db0c38f5015de9b316af69904858 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Thu, 22 Aug 2024 09:17:19 -0300 Subject: [PATCH 1/7] Redis json lock added --- README.md | 2 +- docker-compose.yml | 2 +- .../manager/service/redis/models/lock.py | 204 +++++++++++++++++- src/palace/manager/service/redis/redis.py | 13 ++ src/palace/manager/service/storage/s3.py | 2 + tests/fixtures/database.py | 59 ++++- .../manager/service/redis/models/test_lock.py | 127 ++++++++++- tox.ini | 2 +- 8 files changed, 391 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index 2589105bd5..25ae67f438 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ grant all privileges on database circ to palace; Redis is used as the broker for Celery and the caching layer. You can run Redis with docker using the following command: ```sh -docker run -d --name redis -p 6379:6379 redis +docker run -d --name redis -p 6379:6379 redis/redis-stack-server ``` ### Environment variables diff --git a/docker-compose.yml b/docker-compose.yml index f9e801d4a2..ad6faefe4d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -110,7 +110,7 @@ services: retries: 5 redis: - image: "redis:7" + image: "redis/redis-stack-server:7.4.0-v0" healthcheck: test: ["CMD", "redis-cli", "ping"] interval: 30s diff --git a/src/palace/manager/service/redis/models/lock.py b/src/palace/manager/service/redis/models/lock.py index ef4c348aff..dfc2a74f60 100644 --- a/src/palace/manager/service/redis/models/lock.py +++ b/src/palace/manager/service/redis/models/lock.py @@ -1,11 +1,12 @@ +import json import random import time from abc import ABC, abstractmethod -from collections.abc import Generator, Sequence +from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager from datetime import timedelta from functools import cached_property -from typing import cast +from typing import Any, TypeVar, cast from uuid import uuid4 from palace.manager.celery.task import Task @@ -69,6 +70,18 @@ def key(self) -> str: :return: The key used to store the lock in Redis. """ + def _exception_exit(self) -> None: + """ + Clean up before exiting the context manager, if an exception occurs. + """ + self.release() + + def _normal_exit(self) -> None: + """ + Clean up before exiting the context manager, if no exception occurs. 
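+
+        By default this simply releases the lock. Subclasses can override this hook
+        (or _exception_exit) to customize cleanup behaviour.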
+ """ + self.release() + @contextmanager def lock( self, @@ -94,10 +107,10 @@ def lock( exception_occurred = True raise finally: - if (release_on_error and exception_occurred) or ( - release_on_exit and not exception_occurred - ): - self.release() + if release_on_error and exception_occurred: + self._exception_exit() + elif release_on_exit and not exception_occurred: + self._normal_exit() class RedisLock(BaseRedisLock): @@ -232,3 +245,182 @@ def __init__( else: name = [lock_name] super().__init__(redis_client, name, random_value, lock_timeout, retry_delay) + + +class RedisJsonLock(BaseRedisLock, ABC): + _ACQUIRE_SCRIPT = """ + -- If the locks json object doesn't exist, create it with the initial value + redis.call("json.set", KEYS[1], "$", ARGV[4], "nx") + + -- Get the current lock value + local lock_value = cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] + if not lock_value then + -- The lock isn't currently locked, so we lock it and set the timeout + redis.call("json.set", KEYS[1], ARGV[1], cjson.encode(ARGV[2])) + redis.call("pexpire", KEYS[1], ARGV[3]) + return 1 + elseif lock_value == ARGV[2] then + -- The lock is already held by us, so we extend the timeout + redis.call("pexpire", KEYS[1], ARGV[3]) + return 2 + else + -- The lock is held by someone else, we do nothing + return nil + end + """ + + _RELEASE_SCRIPT = """ + if cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] == ARGV[2] then + redis.call("json.del", KEYS[1], ARGV[1]) + return 1 + else + return nil + end + """ + + _EXTEND_SCRIPT = """ + if cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] == ARGV[2] then + redis.call("pexpire", KEYS[1], ARGV[3]) + return 1 + else + return nil + end + """ + + _DELETE_SCRIPT = """ + if cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] == ARGV[2] then + redis.call("del", KEYS[1]) + return 1 + else + return nil + end + """ + + def __init__( + self, + redis_client: Redis, + random_value: str | None = None, + ): + super().__init__(redis_client, random_value) + + # Register our scripts + self._acquire_script = self._redis_client.register_script(self._ACQUIRE_SCRIPT) + self._release_script = self._redis_client.register_script(self._RELEASE_SCRIPT) + self._extend_script = self._redis_client.register_script(self._EXTEND_SCRIPT) + self._delete_script = self._redis_client.register_script(self._DELETE_SCRIPT) + + @property + @abstractmethod + def _lock_timeout_ms(self) -> int: + """ + The lock timeout in milliseconds. + """ + ... + + @property + def _lock_json_key(self) -> str: + """ + The key to use for the lock value in the JSON object. + + This can be overridden if you need to store the lock value in a different key. It should + be a Redis JSONPath. + See: https://redis.io/docs/latest/develop/data-types/json/path/ + """ + return "$.lock" + + @property + def _initial_value(self) -> str: + """ + The initial value to use for the locks JSON object. 
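+
+        Defaults to an empty JSON object. Override this property if the JSON object
+        should be created with additional fields alongside the lock.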
+ """ + return json.dumps({}) + + T = TypeVar("T") + + @classmethod + def _parse_multi( + cls, value: Mapping[str, Sequence[T]] | None + ) -> dict[str, T | None]: + if value is None: + return {} + return {k: cls._parse_value(v) for k, v in value.items()} + + @staticmethod + def _parse_value(value: Sequence[T] | None) -> T | None: + if value is None: + return None + try: + return value[0] + except IndexError: + return None + + @classmethod + def _parse_value_or_raise(cls, value: Sequence[T] | None) -> T: + parsed_value = cls._parse_value(value) + if parsed_value is None: + raise LockError(f"Could not parse value ({json.dumps(value)})") + return parsed_value + + def _get_value(self, json_key: str) -> Any | None: + value = self._redis_client.json().get(self.key, json_key) + if value is None or len(value) != 1: + return None + return value[0] + + def acquire(self) -> bool: + return ( + self._acquire_script( + keys=(self.key,), + args=( + self._lock_json_key, + self._random_value, + self._lock_timeout_ms, + self._initial_value, + ), + ) + is not None + ) + + def release(self) -> bool: + """ + Release the lock. + + You must have the lock to release it. This will unset the lock value in the JSON object, but importantly + it will not delete the JSON object itself. If you want to delete the JSON object, use the delete method. + """ + return ( + self._release_script( + keys=(self.key,), + args=(self._lock_json_key, self._random_value), + ) + is not None + ) + + def locked(self, by_us: bool = False) -> bool: + lock_value: str | None = self._parse_value( + self._redis_client.json().get(self.key, self._lock_json_key) + ) + if by_us: + return lock_value == self._random_value + return lock_value is not None + + def extend_timeout(self) -> bool: + return ( + self._extend_script( + keys=(self.key,), + args=(self._lock_json_key, self._random_value, self._lock_timeout_ms), + ) + is not None + ) + + def delete(self) -> bool: + """ + Delete the whole json object, including the lock. Must have the lock to delete the object. 
+ """ + return ( + self._delete_script( + keys=(self.key,), + args=(self._lock_json_key, self._random_value), + ) + is not None + ) diff --git a/src/palace/manager/service/redis/redis.py b/src/palace/manager/service/redis/redis.py index c9cd41be9b..cd73c4edd6 100644 --- a/src/palace/manager/service/redis/redis.py +++ b/src/palace/manager/service/redis/redis.py @@ -79,19 +79,29 @@ def key_args(self, args: list[Any]) -> Sequence[str]: RedisCommandArgs("KEYS"), RedisCommandArgs("GET"), RedisCommandArgs("EXPIRE"), + RedisCommandArgs("PEXPIRE"), RedisCommandArgs("GETRANGE"), RedisCommandArgs("SET"), RedisCommandArgs("TTL"), RedisCommandArgs("PTTL"), + RedisCommandArgs("PTTL"), RedisCommandArgs("SADD"), RedisCommandArgs("SPOP"), RedisCommandArgs("SCARD"), + RedisCommandArgs("WATCH"), RedisCommandArgs("SRANDMEMBER"), RedisCommandArgs("SREM"), RedisCommandArgs("DEL", args_end=None), RedisCommandArgs("MGET", args_end=None), RedisCommandArgs("EXISTS", args_end=None), RedisCommandArgs("EXPIRETIME"), + RedisCommandArgs("JSON.SET"), + RedisCommandArgs("JSON.STRLEN"), + RedisCommandArgs("JSON.STRAPPEND"), + RedisCommandArgs("JSON.NUMINCRBY"), + RedisCommandArgs("JSON.GET"), + RedisCommandArgs("JSON.OBJKEYS"), + RedisCommandArgs("JSON.ARRAPPEND"), RedisVariableCommandArgs("EVALSHA", key_index=1), ] } @@ -161,3 +171,6 @@ def _prefix(self) -> str: def execute_command(self, *args: Any, **options: Any) -> Any: self._check_prefix(*args) return super().execute_command(*args, **options) + + def __enter__(self) -> Pipeline: + return self diff --git a/src/palace/manager/service/storage/s3.py b/src/palace/manager/service/storage/s3.py index 97704bfffa..fa8dc2e91a 100644 --- a/src/palace/manager/service/storage/s3.py +++ b/src/palace/manager/service/storage/s3.py @@ -110,6 +110,8 @@ def exception(self) -> BaseException | None: class S3Service(LoggerMixin): + MINIMUM_MULTIPART_UPLOAD_SIZE = 5 * 1024 * 1024 # 5MB + def __init__( self, client: S3Client, diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index 7b1b384c27..f8baad2054 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -6,7 +6,7 @@ import tempfile import time import uuid -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterable, Mapping from contextlib import contextmanager from functools import cached_property from textwrap import dedent @@ -36,10 +36,15 @@ from palace.manager.core.config import Configuration from palace.manager.core.exceptions import BasePalaceException, PalaceValueError from palace.manager.core.opds_import import OPDSAPI -from palace.manager.integration.base import HasIntegrationConfiguration +from palace.manager.integration.base import ( + HasIntegrationConfiguration, + HasLibraryIntegrationConfiguration, +) from palace.manager.integration.base import SettingsType as TIntegrationSettings from palace.manager.integration.configuration.library import LibrarySettings from palace.manager.integration.goals import Goals +from palace.manager.integration.settings import BaseSettings +from palace.manager.service.integration_registry.base import IntegrationRegistry from palace.manager.sqlalchemy.constants import MediaTypes from palace.manager.sqlalchemy.model.classification import ( Classification, @@ -921,6 +926,16 @@ def license( def isbn_take(self) -> str: return self._isbns.pop() + @cached_property + def _goal_registry_mapping(self) -> Mapping[Goals, IntegrationRegistry[Any]]: + return { + Goals.CATALOG_GOAL: 
self._services.services.integration_registry.catalog_services(), + Goals.DISCOVERY_GOAL: self._services.services.integration_registry.discovery(), + Goals.LICENSE_GOAL: self._services.services.integration_registry.license_providers(), + Goals.METADATA_GOAL: self._services.services.integration_registry.metadata(), + Goals.PATRON_AUTH_GOAL: self._services.services.integration_registry.patron_auth(), + } + def integration_configuration( self, protocol: type[HasIntegrationConfiguration[TIntegrationSettings]] | str, @@ -930,17 +945,10 @@ def integration_configuration( name: str | None = None, settings: TIntegrationSettings | None = None, ) -> IntegrationConfiguration: - registry_mapping = { - Goals.CATALOG_GOAL: self._services.services.integration_registry.catalog_services(), - Goals.DISCOVERY_GOAL: self._services.services.integration_registry.discovery(), - Goals.LICENSE_GOAL: self._services.services.integration_registry.license_providers(), - Goals.METADATA_GOAL: self._services.services.integration_registry.metadata(), - Goals.PATRON_AUTH_GOAL: self._services.services.integration_registry.patron_auth(), - } protocol_str = ( protocol if isinstance(protocol, str) - else registry_mapping[goal].get_protocol(protocol) + else self._goal_registry_mapping[goal].get_protocol(protocol) ) assert protocol_str is not None integration, ignore = get_one_or_create( @@ -972,6 +980,37 @@ def integration_configuration( return integration + def integration_library_configuration( + self, + parent: IntegrationConfiguration, + library: Library, + settings: BaseSettings | None = None, + ) -> IntegrationLibraryConfiguration: + assert parent.goal is not None + assert parent.protocol is not None + parent_cls = self._goal_registry_mapping[parent.goal][parent.protocol] + if not issubclass(parent_cls, HasLibraryIntegrationConfiguration): + raise TypeError( + f"{parent_cls.__name__} does not support library configuration" + ) + + integration, ignore = get_one_or_create( + self.session, + IntegrationLibraryConfiguration, + parent=parent, + library=library, + ) + + if settings is not None: + if not isinstance(settings, parent_cls.library_settings_class()): + raise TypeError( + f"settings must be an instance of {parent_cls.library_settings_class().__name__} " + f"not {settings.__class__.__name__}" + ) + parent_cls.library_settings_update(integration, settings) + + return integration + def discovery_service_integration( self, url: str | None = None ) -> IntegrationConfiguration: diff --git a/tests/manager/service/redis/models/test_lock.py b/tests/manager/service/redis/models/test_lock.py index ca7aa956d6..93cd874709 100644 --- a/tests/manager/service/redis/models/test_lock.py +++ b/tests/manager/service/redis/models/test_lock.py @@ -1,10 +1,17 @@ from datetime import timedelta +from typing import Any from unittest.mock import create_autospec import pytest from palace.manager.celery.task import Task -from palace.manager.service.redis.models.lock import LockError, RedisLock, TaskLock +from palace.manager.service.redis.models.lock import ( + LockError, + RedisJsonLock, + RedisLock, + TaskLock, +) +from palace.manager.service.redis.redis import Redis from tests.fixtures.redis import RedisFixture @@ -182,3 +189,121 @@ def test___init__(self, redis_fixture: RedisFixture): # If we provide a lock_name, we should use that instead task_lock = TaskLock(redis_fixture.client, mock_task, lock_name="test_lock") assert task_lock.key.endswith("::TaskLock::test_lock") + + +class MockJsonLock(RedisJsonLock): + def __init__( + self, + 
redis_client: Redis, + key: str = "test", + timeout: int = 1000, + random_value: str | None = None, + ): + self._key = redis_client.get_key(key) + self._timeout = timeout + super().__init__(redis_client, random_value) + + @property + def key(self) -> str: + return self._key + + @property + def _lock_timeout_ms(self) -> int: + return self._timeout + + +class JsonLockFixture: + def __init__(self, redis_fixture: RedisFixture) -> None: + self.client = redis_fixture.client + self.lock = MockJsonLock(redis_fixture.client) + self.other_lock = MockJsonLock(redis_fixture.client) + + def get_key(self, key: str, json_key: str) -> Any: + ret_val = self.client.json().get(key, json_key) + if ret_val is None or len(ret_val) != 1: + return None + return ret_val[0] + + def assert_locked(self, lock: RedisJsonLock) -> None: + assert self.get_key(lock.key, lock._lock_json_key) == lock._random_value + + +@pytest.fixture +def json_lock_fixture(redis_fixture: RedisFixture) -> JsonLockFixture: + return JsonLockFixture(redis_fixture) + + +class TestJsonLock: + def test_acquire(self, json_lock_fixture: JsonLockFixture): + # We can acquire the lock. And acquiring the lock sets a timeout on the key, so the lock + # will expire eventually if something goes wrong. + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.client.ttl(json_lock_fixture.lock.key) > 0 + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + # Acquiring the lock again with the same random value should return True + # and extend the timeout for the lock + json_lock_fixture.client.pexpire(json_lock_fixture.lock.key, 500) + timeout = json_lock_fixture.client.pttl(json_lock_fixture.lock.key) + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) > timeout + + # Acquiring the lock again with a different random value should return False + assert not json_lock_fixture.other_lock.acquire() + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + def test_release(self, json_lock_fixture: JsonLockFixture): + # If you acquire a lock another client cannot release it + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.other_lock.release() is False + + # Make sure the key is set in redis + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + # But the client that acquired the lock can release it + assert json_lock_fixture.lock.release() is True + + # And the key should still exist, but the lock key in the json is removed from redis + assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") == {} + + def test_delete(self, json_lock_fixture: JsonLockFixture): + # If you acquire a lock another client cannot delete it + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.other_lock.delete() is False + + # Make sure the key is set in redis + assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") is not None + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + # But the client that acquired the lock can delete it + assert json_lock_fixture.lock.delete() is True + + # And the key should still exist, but the lock key in the json is removed from redis + assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") is None + + def test_extend_timeout(self, json_lock_fixture: JsonLockFixture): + # If the lock has a timeout, the acquiring client can extend it, but another client cannot + assert json_lock_fixture.lock.acquire() + json_lock_fixture.client.pexpire(json_lock_fixture.lock.key, 500) + assert 
json_lock_fixture.other_lock.extend_timeout() is False + assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) <= 500 + + # The key should have a new timeout + assert json_lock_fixture.lock.extend_timeout() is True + assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) > 500 + + def test_locked(self, json_lock_fixture: JsonLockFixture): + # If the lock is not acquired, it should not be locked + assert json_lock_fixture.lock.locked() is False + + # If the lock is acquired, it should be locked + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.lock.locked() is True + assert json_lock_fixture.other_lock.locked() is True + assert json_lock_fixture.lock.locked(by_us=True) is True + assert json_lock_fixture.other_lock.locked(by_us=True) is False + + # If the lock is released, it should not be locked + assert json_lock_fixture.lock.release() is True + assert json_lock_fixture.lock.locked() is False + assert json_lock_fixture.other_lock.locked() is False diff --git a/tox.ini b/tox.ini index 51ab13f7ec..8aa20dcf75 100644 --- a/tox.ini +++ b/tox.ini @@ -76,7 +76,7 @@ host_var = PALACE_TEST_MINIO_URL_HOST [docker:redis-circ] -image = redis:7 +image = redis/redis-stack-server:7.4.0-v0 expose = PALACE_TEST_REDIS_URL_PORT=6379/tcp host_var = From 9e615a3879a27faf3a6e0cbd1d775de621ec5ebf Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Thu, 22 Aug 2024 09:23:54 -0300 Subject: [PATCH 2/7] Celery MarcFileExporter implementation --- .gitignore | 3 + bin/cache_marc_files | 6 - docker/services/cron/cron.d/circulation | 3 - .../api/admin/controller/catalog_services.py | 4 +- src/palace/manager/api/circulation_manager.py | 5 +- src/palace/manager/api/controller/marc.py | 16 +- src/palace/manager/celery/tasks/marc.py | 144 +++ src/palace/manager/marc/__init__.py | 0 .../{core/marc.py => marc/annotator.py} | 406 ++------ src/palace/manager/marc/exporter.py | 373 ++++++++ src/palace/manager/marc/settings.py | 73 ++ src/palace/manager/marc/uploader.py | 142 +++ src/palace/manager/scripts/marc.py | 223 ----- src/palace/manager/service/celery/celery.py | 6 + .../integration_registry/catalog_services.py | 14 +- .../manager/service/redis/models/lock.py | 69 +- .../manager/service/redis/models/marc.py | 271 ++++++ src/palace/manager/service/redis/redis.py | 2 + tests/conftest.py | 1 + tests/fixtures/marc.py | 99 ++ tests/fixtures/s3.py | 173 +++- .../admin/controller/test_catalog_services.py | 29 +- tests/manager/api/controller/test_marc.py | 21 +- tests/manager/celery/tasks/test_marc.py | 287 ++++++ tests/manager/core/test_marc.py | 900 ------------------ tests/manager/marc/__init__.py | 0 tests/manager/marc/test_annotator.py | 716 ++++++++++++++ tests/manager/marc/test_exporter.py | 425 +++++++++ tests/manager/marc/test_uploader.py | 314 ++++++ tests/manager/scripts/test_marc.py | 466 --------- .../manager/service/redis/models/test_lock.py | 26 + .../manager/service/redis/models/test_marc.py | 406 ++++++++ tests/manager/service/storage/test_s3.py | 97 +- 33 files changed, 3629 insertions(+), 2091 deletions(-) delete mode 100755 bin/cache_marc_files create mode 100644 src/palace/manager/celery/tasks/marc.py create mode 100644 src/palace/manager/marc/__init__.py rename src/palace/manager/{core/marc.py => marc/annotator.py} (60%) create mode 100644 src/palace/manager/marc/exporter.py create mode 100644 src/palace/manager/marc/settings.py create mode 100644 src/palace/manager/marc/uploader.py delete mode 100644 src/palace/manager/scripts/marc.py create mode 100644 
src/palace/manager/service/redis/models/marc.py create mode 100644 tests/fixtures/marc.py create mode 100644 tests/manager/celery/tasks/test_marc.py delete mode 100644 tests/manager/core/test_marc.py create mode 100644 tests/manager/marc/__init__.py create mode 100644 tests/manager/marc/test_annotator.py create mode 100644 tests/manager/marc/test_exporter.py create mode 100644 tests/manager/marc/test_uploader.py delete mode 100644 tests/manager/scripts/test_marc.py create mode 100644 tests/manager/service/redis/models/test_marc.py diff --git a/.gitignore b/.gitignore index 7ee0099696..50d20bcb5c 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,6 @@ docs/source/* .DS_Store src/palace/manager/core/_version.py + +# Celery beat schedule file +celerybeat-schedule.db diff --git a/bin/cache_marc_files b/bin/cache_marc_files deleted file mode 100755 index b42e34ce62..0000000000 --- a/bin/cache_marc_files +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -"""Refresh and store the MARC files for lanes.""" - -from palace.manager.scripts.marc import CacheMARCFiles - -CacheMARCFiles().run() diff --git a/docker/services/cron/cron.d/circulation b/docker/services/cron/cron.d/circulation index c82a52652b..10122136ed 100644 --- a/docker/services/cron/cron.d/circulation +++ b/docker/services/cron/cron.d/circulation @@ -36,9 +36,6 @@ HOME=/var/www/circulation # Sync a library's collection with NoveList 0 0 * * 0 root bin/run -d 60 novelist_update >> /var/log/cron.log 2>&1 -# Generate MARC files for libraries that have a MARC exporter configured. -0 3,11 * * * root bin/run cache_marc_files >> /var/log/cron.log 2>&1 - # The remaining scripts keep the circulation manager in sync with # specific types of collections. diff --git a/src/palace/manager/api/admin/controller/catalog_services.py b/src/palace/manager/api/admin/controller/catalog_services.py index 5245ec6ee1..b2dd6ddaef 100644 --- a/src/palace/manager/api/admin/controller/catalog_services.py +++ b/src/palace/manager/api/admin/controller/catalog_services.py @@ -8,9 +8,9 @@ ) from palace.manager.api.admin.form_data import ProcessFormData from palace.manager.api.admin.problem_details import MULTIPLE_SERVICES_FOR_LIBRARY -from palace.manager.core.marc import MARCExporter from palace.manager.integration.goals import Goals from palace.manager.integration.settings import BaseSettings +from palace.manager.marc.exporter import MarcExporter from palace.manager.sqlalchemy.listeners import site_configuration_has_changed from palace.manager.sqlalchemy.model.integration import ( IntegrationConfiguration, @@ -21,7 +21,7 @@ class CatalogServicesController( - IntegrationSettingsController[MARCExporter], + IntegrationSettingsController[MarcExporter], AdminPermissionsControllerMixin, ): def process_catalog_services(self) -> Response | ProblemDetail: diff --git a/src/palace/manager/api/circulation_manager.py b/src/palace/manager/api/circulation_manager.py index 36d569c7e3..d6d2d682b7 100644 --- a/src/palace/manager/api/circulation_manager.py +++ b/src/palace/manager/api/circulation_manager.py @@ -343,7 +343,10 @@ def setup_one_time_controllers(self): """ self.index_controller = IndexController(self) self.opds_feeds = OPDSFeedController(self) - self.marc_records = MARCRecordController(self.services.storage.public()) + self.marc_records = MARCRecordController( + self.services.storage.public(), + self.services.integration_registry.catalog_services(), + ) self.loans = LoanController(self) self.annotations = AnnotationController(self) self.urn_lookup = 
URNLookupController(self) diff --git a/src/palace/manager/api/controller/marc.py b/src/palace/manager/api/controller/marc.py index 802a576081..3114fbca40 100644 --- a/src/palace/manager/api/controller/marc.py +++ b/src/palace/manager/api/controller/marc.py @@ -9,8 +9,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session -from palace.manager.core.marc import MARCExporter from palace.manager.integration.goals import Goals +from palace.manager.marc.exporter import MarcExporter +from palace.manager.service.integration_registry.catalog_services import ( + CatalogServicesRegistry, +) from palace.manager.service.storage.s3 import S3Service from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.integration import ( @@ -49,21 +52,24 @@ class MARCRecordController: """ - def __init__(self, storage_service: S3Service | None) -> None: + def __init__( + self, storage_service: S3Service | None, registry: CatalogServicesRegistry + ) -> None: self.storage_service = storage_service + self.registry = registry @staticmethod def library() -> Library: return flask.request.library # type: ignore[no-any-return,attr-defined] - @staticmethod - def has_integration(session: Session, library: Library) -> bool: + def has_integration(self, session: Session, library: Library) -> bool: + protocols = self.registry.get_protocols(MarcExporter) integration_query = ( select(IntegrationLibraryConfiguration) .join(IntegrationConfiguration) .where( IntegrationConfiguration.goal == Goals.CATALOG_GOAL, - IntegrationConfiguration.protocol == MARCExporter.__name__, + IntegrationConfiguration.protocol.in_(protocols), IntegrationLibraryConfiguration.library == library, ) ) diff --git a/src/palace/manager/celery/tasks/marc.py b/src/palace/manager/celery/tasks/marc.py new file mode 100644 index 0000000000..920ce34a70 --- /dev/null +++ b/src/palace/manager/celery/tasks/marc.py @@ -0,0 +1,144 @@ +import datetime +from typing import Any + +from celery import shared_task + +from palace.manager.celery.task import Task +from palace.manager.marc.exporter import LibraryInfo, MarcExporter +from palace.manager.marc.uploader import MarcUploader +from palace.manager.service.celery.celery import QueueNames +from palace.manager.service.redis.models.marc import MarcFileUploads +from palace.manager.util.datetime_helpers import utc_now + + +@shared_task(queue=QueueNames.default, bind=True) +def marc_export(task: Task, force: bool = False) -> None: + """ + Export MARC records for all collections with the `export_marc_records` flag set to True, whose libraries + have a MARC exporter integration enabled. + """ + + with task.session() as session: + registry = task.services.integration_registry.catalog_services() + start_time = utc_now() + collections = MarcExporter.enabled_collections(session, registry) + for collection in collections: + # Collection.id should never be able to be None here, but mypy doesn't know that. + # So we assert it for mypy's benefit. + assert collection.id is not None + lock = MarcFileUploads(task.services.redis.client(), collection.id) + with lock.lock() as acquired: + if not acquired: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because another task holds its lock." 
+ ) + continue + + libraries_info = MarcExporter.enabled_libraries( + session, registry, collection.id + ) + needs_update = ( + any(info.needs_update for info in libraries_info) or force + ) + + if not needs_update: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it has been updated recently." + ) + continue + + works = MarcExporter.query_works( + session, + collection.id, + work_id_offset=0, + batch_size=1, + ) + if not works: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it has no works." + ) + continue + + task.log.info( + f"Generating MARC records for collection {collection.name} ({collection.id})." + ) + marc_export_collection.delay( + collection_id=collection.id, + start_time=start_time, + libraries=[l.dict() for l in libraries_info], + ) + + +@shared_task(queue=QueueNames.default, bind=True) +def marc_export_collection( + task: Task, + collection_id: int, + start_time: datetime.datetime, + libraries: list[dict[str, Any]], + batch_size: int = 500, + last_work_id: int | None = None, + update_number: int = 0, +) -> None: + """ + Export MARC records for a single collection. + + This task is designed to be re-queued until all works in the collection have been processed, + this can take some time, however each individual task should complete quickly, so that it + doesn't block other tasks from running. + """ + + base_url = task.services.config.sitewide.base_url() + storage_service = task.services.storage.public() + libraries_info = [LibraryInfo.parse_obj(l) for l in libraries] + uploader = MarcUploader( + storage_service, + MarcFileUploads(task.services.redis.client(), collection_id, update_number), + ) + with uploader.begin(): + if not uploader.locked: + task.log.info( + f"Skipping collection {collection_id} because another task is already processing it." + ) + return + + with task.session() as session: + works = MarcExporter.query_works( + session, + collection_id, + work_id_offset=last_work_id, + batch_size=batch_size, + ) + for work in works: + MarcExporter.process_work( + work, libraries_info, base_url, uploader=uploader + ) + + # Sync the uploader to ensure that all the data is written to storage. + uploader.sync() + + if len(works) == batch_size: + # This task is complete, but there are more works waiting to be exported. So we requeue ourselves + # to process the next batch. + raise task.replace( + marc_export_collection.s( + collection_id=collection_id, + start_time=start_time, + libraries=[l.dict() for l in libraries_info], + batch_size=batch_size, + last_work_id=works[-1].id, + update_number=uploader.update_number, + ) + ) + + # If we got here, we have finished generating MARC records. Cleanup and exit. + with task.transaction() as session: + collection = MarcExporter.collection(session, collection_id) + collection_name = collection.name if collection else "unknown" + completed_uploads = uploader.complete() + MarcExporter.create_marc_upload_records( + session, start_time, collection_id, libraries_info, completed_uploads + ) + uploader.delete() + task.log.info( + f"Finished generating MARC records for collection '{collection_name}' ({collection_id})." 
+ ) diff --git a/src/palace/manager/marc/__init__.py b/src/palace/manager/marc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/palace/manager/core/marc.py b/src/palace/manager/marc/annotator.py similarity index 60% rename from src/palace/manager/core/marc.py rename to src/palace/manager/marc/annotator.py index d6720d7f72..47446955f0 100644 --- a/src/palace/manager/core/marc.py +++ b/src/palace/manager/marc/annotator.py @@ -2,40 +2,20 @@ import re import urllib.parse -from collections.abc import Mapping -from datetime import datetime -from io import BytesIO -from uuid import UUID, uuid4 +from collections.abc import Mapping, Sequence -import pytz -from pydantic import NonNegativeInt from pymarc import Field, Indicators, Record, Subfield -from sqlalchemy import select -from sqlalchemy.engine import ScalarResult -from sqlalchemy.orm.session import Session +from sqlalchemy.orm import Session from palace.manager.core.classifier import Classifier -from palace.manager.integration.base import HasLibraryIntegrationConfiguration -from palace.manager.integration.settings import ( - BaseSettings, - ConfigurationFormItem, - ConfigurationFormItemType, - FormField, -) -from palace.manager.service.storage.s3 import S3Service -from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.edition import Edition from palace.manager.sqlalchemy.model.identifier import Identifier -from palace.manager.sqlalchemy.model.library import Library from palace.manager.sqlalchemy.model.licensing import DeliveryMechanism, LicensePool -from palace.manager.sqlalchemy.model.marcfile import MarcFile from palace.manager.sqlalchemy.model.resource import Representation from palace.manager.sqlalchemy.model.work import Work -from palace.manager.sqlalchemy.util import create from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.languages import LanguageCodes from palace.manager.util.log import LoggerMixin -from palace.manager.util.uuid import uuid_encode class Annotator(LoggerMixin): @@ -63,83 +43,90 @@ class Annotator(LoggerMixin): (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM): "Adobe PDF eBook", } - def __init__( - self, - cm_url: str, + @classmethod + def marc_record(cls, work: Work, license_pool: LicensePool) -> Record: + edition = license_pool.presentation_edition + identifier = license_pool.identifier + + record = cls._record() + cls.add_control_fields(record, identifier, license_pool, edition) + cls.add_isbn(record, identifier) + + # TODO: The 240 and 130 fields are for translated works, so they can be grouped even + # though they have different titles. We do not group editions of the same work in + # different languages, so we can't use those yet. 
+ + cls.add_title(record, edition) + cls.add_contributors(record, edition) + cls.add_publisher(record, edition) + cls.add_physical_description(record, edition) + cls.add_audience(record, work) + cls.add_series(record, edition) + cls.add_system_details(record) + cls.add_ebooks_subject(record) + cls.add_distributor(record, license_pool) + cls.add_formats(record, license_pool) + cls.add_summary(record, work) + cls.add_genres(record, work) + + return record + + @classmethod + def library_marc_record( + cls, + record: Record, + identifier: Identifier, + base_url: str, library_short_name: str, - web_client_urls: list[str], + web_client_urls: Sequence[str], organization_code: str | None, include_summary: bool, include_genres: bool, - ) -> None: - self.cm_url = cm_url - self.library_short_name = library_short_name - self.web_client_urls = web_client_urls - self.organization_code = organization_code - self.include_summary = include_summary - self.include_genres = include_genres - - def annotate_work_record( - self, - revised: bool, - work: Work, - active_license_pool: LicensePool, - edition: Edition, - identifier: Identifier, ) -> Record: - """Add metadata from this work to a MARC record. - - :param revised: Whether this record is being revised. - :param work: The Work whose record is being annotated. - :param active_license_pool: Of all the LicensePools associated with this - Work, the client has expressed interest in this one. - :param edition: The Edition to use when associating bibliographic - metadata with this entry. - :param identifier: Of all the Identifiers associated with this - Work, the client has expressed interest in this one. - - :return: A pymarc Record object. - """ - record = Record(leader=self.leader(revised), force_utf8=True) - self.add_control_fields(record, identifier, active_license_pool, edition) - self.add_isbn(record, identifier) + record = cls._copy_record(record) - # TODO: The 240 and 130 fields are for translated works, so they can be grouped even - # though they have different titles. We do not group editions of the same work in - # different languages, so we can't use those yet. 
+ if organization_code: + cls.add_marc_organization_code(record, organization_code) - self.add_title(record, edition) - self.add_contributors(record, edition) - self.add_publisher(record, edition) - self.add_physical_description(record, edition) - self.add_audience(record, work) - self.add_series(record, edition) - self.add_system_details(record) - self.add_ebooks_subject(record) - self.add_distributor(record, active_license_pool) - self.add_formats(record, active_license_pool) + fields_to_remove = [] - if self.organization_code: - self.add_marc_organization_code(record, self.organization_code) + if not include_summary: + fields_to_remove.append("520") - if self.include_summary: - self.add_summary(record, work) + if not include_genres: + fields_to_remove.append("650") - if self.include_genres: - self.add_genres(record, work) + if fields_to_remove: + record.remove_fields(*fields_to_remove) - self.add_web_client_urls( + cls.add_web_client_urls( record, identifier, - self.library_short_name, - self.cm_url, - self.web_client_urls, + library_short_name, + base_url, + web_client_urls, ) return record @classmethod - def leader(cls, revised: bool) -> str: + def _record(cls, leader: str | None = None) -> Record: + leader = leader or cls.leader() + return Record(leader=leader, force_utf8=True) + + @classmethod + def _copy_record(cls, record: Record) -> Record: + copied = cls._record(record.leader) + copied.add_field(*record.get_fields()) + return copied + + @classmethod + def set_revised(cls, record: Record, revised: bool = True) -> Record: + record.leader.record_status = "c" if revised else "n" + return record + + @classmethod + def leader(cls, revised: bool = False) -> str: # The record length is automatically updated once fields are added. initial_record_length = "00000" @@ -558,20 +545,20 @@ def add_web_client_urls( record: Record, identifier: Identifier, library_short_name: str, - cm_url: str, - web_client_urls: list[str], + base_url: str, + web_client_urls: Sequence[str], ) -> None: qualified_identifier = urllib.parse.quote( f"{identifier.type}/{identifier.identifier}", safe="" ) + link = "{}/{}/works/{}".format( + base_url, + library_short_name, + qualified_identifier, + ) + encoded_link = urllib.parse.quote(link, safe="") for web_client_base_url in web_client_urls: - link = "{}/{}/works/{}".format( - cm_url, - library_short_name, - qualified_identifier, - ) - encoded_link = urllib.parse.quote(link, safe="") url = f"{web_client_base_url}/book/{encoded_link}" record.add_field( Field( @@ -580,244 +567,3 @@ def add_web_client_urls( subfields=[Subfield(code="u", value=url)], ) ) - - -class MarcExporterSettings(BaseSettings): - # This setting (in days) controls how often MARC files should be - # automatically updated. Since the crontab in docker isn't easily - # configurable, we can run a script daily but check this to decide - # whether to do anything. 
- update_frequency: NonNegativeInt = FormField( - 30, - form=ConfigurationFormItem( - label="Update frequency (in days)", - type=ConfigurationFormItemType.NUMBER, - required=True, - ), - alias="marc_update_frequency", - ) - - -class MarcExporterLibrarySettings(BaseSettings): - # MARC organization codes are assigned by the - # Library of Congress and can be found here: - # http://www.loc.gov/marc/organizations/org-search.php - organization_code: str | None = FormField( - None, - form=ConfigurationFormItem( - label="The MARC organization code for this library (003 field).", - description="MARC organization codes are assigned by the Library of Congress.", - type=ConfigurationFormItemType.TEXT, - ), - alias="marc_organization_code", - ) - - web_client_url: str | None = FormField( - None, - form=ConfigurationFormItem( - label="The base URL for the web catalog for this library, for the 856 field.", - description="If using a library registry that provides a web catalog, this can be left blank.", - type=ConfigurationFormItemType.TEXT, - ), - alias="marc_web_client_url", - ) - - include_summary: bool = FormField( - False, - form=ConfigurationFormItem( - label="Include summaries in MARC records (520 field)", - type=ConfigurationFormItemType.SELECT, - options={"false": "Do not include summaries", "true": "Include summaries"}, - ), - ) - - include_genres: bool = FormField( - False, - form=ConfigurationFormItem( - label="Include Palace Collection Manager genres in MARC records (650 fields)", - type=ConfigurationFormItemType.SELECT, - options={ - "false": "Do not include Palace Collection Manager genres", - "true": "Include Palace Collection Manager genres", - }, - ), - alias="include_simplified_genres", - ) - - -class MARCExporter( - HasLibraryIntegrationConfiguration[ - MarcExporterSettings, MarcExporterLibrarySettings - ], - LoggerMixin, -): - """Turn a work into a record for a MARC file.""" - - # The minimum size each piece of a multipart upload should be - MINIMUM_UPLOAD_BATCH_SIZE_BYTES = 5 * 1024 * 1024 # 5MB - - def __init__( - self, - _db: Session, - storage_service: S3Service, - ): - self._db = _db - self.storage_service = storage_service - - @classmethod - def label(cls) -> str: - return "MARC Export" - - @classmethod - def description(cls) -> str: - return ( - "Export metadata into MARC files that can be imported into an ILS manually." 
- ) - - @classmethod - def settings_class(cls) -> type[MarcExporterSettings]: - return MarcExporterSettings - - @classmethod - def library_settings_class(cls) -> type[MarcExporterLibrarySettings]: - return MarcExporterLibrarySettings - - @classmethod - def create_record( - cls, - revised: bool, - work: Work, - annotator: Annotator, - ) -> Record | None: - """Build a complete MARC record for a given work.""" - pool = work.active_license_pool() - if not pool: - return None - - edition = pool.presentation_edition - identifier = pool.identifier - - return annotator.annotate_work_record(revised, work, pool, edition, identifier) - - @staticmethod - def _date_to_string(date: datetime) -> str: - return date.astimezone(pytz.UTC).strftime("%Y-%m-%d") - - def _file_key( - self, - uuid: UUID, - library: Library, - collection: Collection, - creation_time: datetime, - since_time: datetime | None = None, - ) -> str: - """The path to the hosted MARC file for the given library, collection, - and date range.""" - root = "marc" - short_name = str(library.short_name) - creation = self._date_to_string(creation_time) - - if since_time: - file_type = f"delta.{self._date_to_string(since_time)}.{creation}" - else: - file_type = f"full.{creation}" - - uuid_encoded = uuid_encode(uuid) - collection_name = collection.name.replace(" ", "_") - filename = f"{collection_name}.{file_type}.{uuid_encoded}.mrc" - parts = [root, short_name, filename] - return "/".join(parts) - - def query_works( - self, - collection: Collection, - since_time: datetime | None, - creation_time: datetime, - batch_size: int, - ) -> ScalarResult: - query = ( - select(Work) - .join(LicensePool) - .join(Collection) - .where( - Collection.id == collection.id, - Work.last_update_time <= creation_time, - ) - ) - - if since_time is not None: - query = query.where(Work.last_update_time >= since_time) - - return self._db.execute(query).unique().yield_per(batch_size).scalars() - - def records( - self, - library: Library, - collection: Collection, - annotator: Annotator, - *, - creation_time: datetime, - since_time: datetime | None = None, - batch_size: int = 500, - ) -> None: - """ - Create and export a MARC file for the books in a collection. - """ - uuid = uuid4() - key = self._file_key(uuid, library, collection, creation_time, since_time) - - with self.storage_service.multipart( - key, - content_type=Representation.MARC_MEDIA_TYPE, - ) as upload: - this_batch = BytesIO() - - works = self.query_works(collection, since_time, creation_time, batch_size) - for work in works: - # Create a record for each work and add it to the MARC file in progress. - record = self.create_record( - since_time is not None, - work, - annotator, - ) - if record: - record_bytes = record.as_marc() - this_batch.write(record_bytes) - if ( - this_batch.getbuffer().nbytes - >= self.MINIMUM_UPLOAD_BATCH_SIZE_BYTES - ): - # We've reached or exceeded the upload threshold. - # Upload one part of the multipart document. - upload.upload_part(this_batch.getvalue()) - this_batch.seek(0) - this_batch.truncate() - - # Upload the final part of the multi-document, if - # necessary. - if this_batch.getbuffer().nbytes > 0: - upload.upload_part(this_batch.getvalue()) - - if upload.complete: - create( - self._db, - MarcFile, - id=uuid, - library=library, - collection=collection, - created=creation_time, - since=since_time, - key=key, - ) - else: - if upload.exception: - # Log the exception and move on to the next file. We will try again next script run. 
- self.log.error( - f"Failed to upload MARC file for {library.short_name}/{collection.name}: {upload.exception}", - exc_info=upload.exception, - ) - else: - # There were no records to upload. This is not an error, but we should log it. - self.log.info( - f"No MARC records to upload for {library.short_name}/{collection.name}." - ) diff --git a/src/palace/manager/marc/exporter.py b/src/palace/manager/marc/exporter.py new file mode 100644 index 0000000000..7745ba79cb --- /dev/null +++ b/src/palace/manager/marc/exporter.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +import datetime +from collections.abc import Iterable, Sequence +from uuid import UUID, uuid4 + +import pytz +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.orm import Session, aliased + +from palace.manager.integration.base import HasLibraryIntegrationConfiguration +from palace.manager.integration.goals import Goals +from palace.manager.marc.annotator import Annotator +from palace.manager.marc.settings import ( + MarcExporterLibrarySettings, + MarcExporterSettings, +) +from palace.manager.marc.uploader import MarcUploader +from palace.manager.service.integration_registry.catalog_services import ( + CatalogServicesRegistry, +) +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.discovery_service_registration import ( + DiscoveryServiceRegistration, +) +from palace.manager.sqlalchemy.model.integration import ( + IntegrationConfiguration, + IntegrationLibraryConfiguration, +) +from palace.manager.sqlalchemy.model.library import Library +from palace.manager.sqlalchemy.model.licensing import LicensePool +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.sqlalchemy.util import create +from palace.manager.util.datetime_helpers import utc_now +from palace.manager.util.log import LoggerMixin +from palace.manager.util.uuid import uuid_encode + + +class LibraryInfo(BaseModel): + library_id: int + library_short_name: str + last_updated: datetime.datetime | None + needs_update: bool + organization_code: str | None + include_summary: bool + include_genres: bool + web_client_urls: tuple[str, ...] + + s3_key_full_uuid: str + s3_key_full: str + + s3_key_delta_uuid: str + s3_key_delta: str | None = None + + class Config: + frozen = True + + +class MarcExporter( + HasLibraryIntegrationConfiguration[ + MarcExporterSettings, MarcExporterLibrarySettings + ], + LoggerMixin, +): + """ + This class provides the logic for exporting MARC records for a collection to S3. + """ + + @classmethod + def label(cls) -> str: + return "MARC Export" + + @classmethod + def description(cls) -> str: + return ( + "Export metadata into MARC files that can be imported into an ILS manually." 
+ ) + + @classmethod + def settings_class(cls) -> type[MarcExporterSettings]: + return MarcExporterSettings + + @classmethod + def library_settings_class(cls) -> type[MarcExporterLibrarySettings]: + return MarcExporterLibrarySettings + + @staticmethod + def _s3_key( + library: Library, + collection: Collection, + creation_time: datetime.datetime, + uuid: UUID, + since_time: datetime.datetime | None = None, + ) -> str: + """The path to the hosted MARC file for the given library, collection, + and date range.""" + + def date_to_string(date: datetime.datetime) -> str: + return date.astimezone(pytz.UTC).strftime("%Y-%m-%d") + + root = "marc" + short_name = str(library.short_name) + creation = date_to_string(creation_time) + + if since_time: + file_type = f"delta.{date_to_string(since_time)}.{creation}" + else: + file_type = f"full.{creation}" + + uuid_encoded = uuid_encode(uuid) + collection_name = collection.name.replace(" ", "_") + filename = f"{collection_name}.{file_type}.{uuid_encoded}.mrc" + parts = [root, short_name, filename] + return "/".join(parts) + + @staticmethod + def _needs_update( + last_updated_time: datetime.datetime | None, update_frequency: int + ) -> bool: + return not last_updated_time or ( + last_updated_time.date() + <= (utc_now() - datetime.timedelta(days=update_frequency)).date() + ) + + @staticmethod + def _web_client_urls( + session: Session, library: Library, url: str | None = None + ) -> tuple[str, ...]: + """Find web client URLs configured by the registry for this library.""" + urls = [ + s.web_client + for s in session.execute( + select(DiscoveryServiceRegistration.web_client).where( + DiscoveryServiceRegistration.library == library, + DiscoveryServiceRegistration.web_client != None, + ) + ).all() + ] + + if url: + urls.append(url) + + return tuple(urls) + + @classmethod + def _enabled_collections_and_libraries( + cls, + session: Session, + registry: CatalogServicesRegistry, + collection_id: int | None = None, + ) -> set[tuple[Collection, IntegrationLibraryConfiguration]]: + collection_integration_configuration = aliased(IntegrationConfiguration) + collection_integration_library_configuration = aliased( + IntegrationLibraryConfiguration + ) + library_integration_library_configuration = aliased( + IntegrationLibraryConfiguration, + name="library_integration_library_configuration", + ) + library_integration_configuration = aliased(IntegrationConfiguration) + + protocols = registry.get_protocols(cls) + + collection_query = ( + select(Collection, library_integration_library_configuration) + .select_from(Collection) + .join(collection_integration_configuration) + .join(collection_integration_library_configuration) + .join(Library) + .join(library_integration_library_configuration) + .join(library_integration_configuration) + .where( + Collection.export_marc_records == True, + library_integration_configuration.goal == Goals.CATALOG_GOAL, + library_integration_configuration.protocol.in_(protocols), + ) + ) + if collection_id is not None: + collection_query = collection_query.where(Collection.id == collection_id) + return { + (r.Collection, r.library_integration_library_configuration) + for r in session.execute(collection_query) + } + + @staticmethod + def _last_updated( + session: Session, library: Library, collection: Collection + ) -> datetime.datetime | None: + """Find the most recent MarcFile creation time.""" + last_updated_file = session.execute( + select(MarcFile.created) + .where( + MarcFile.library == library, + MarcFile.collection == collection, + ) + 
.order_by(MarcFile.created.desc()) + ).first() + + return last_updated_file.created if last_updated_file else None + + @classmethod + def enabled_collections( + cls, session: Session, registry: CatalogServicesRegistry + ) -> set[Collection]: + return {c for c, _ in cls._enabled_collections_and_libraries(session, registry)} + + @classmethod + def enabled_libraries( + cls, session: Session, registry: CatalogServicesRegistry, collection_id: int + ) -> Sequence[LibraryInfo]: + library_info = [] + creation_time = utc_now() + for collection, library_integration in cls._enabled_collections_and_libraries( + session, registry, collection_id + ): + library = library_integration.library + library_id = library.id + library_short_name = library.short_name + if library_id is None or library_short_name is None: + cls.logger().warning( + f"Library {library} is missing an ID or short name." + ) + continue + last_updated_time = cls._last_updated(session, library, collection) + update_frequency = cls.settings_load( + library_integration.parent + ).update_frequency + library_settings = cls.library_settings_load(library_integration) + needs_update = cls._needs_update(last_updated_time, update_frequency) + web_client_urls = cls._web_client_urls( + session, library, library_settings.web_client_url + ) + s3_key_full_uuid = uuid4() + s3_key_full = cls._s3_key( + library, + collection, + creation_time, + s3_key_full_uuid, + ) + s3_key_delta_uuid = uuid4() + s3_key_delta = ( + cls._s3_key( + library, + collection, + creation_time, + s3_key_delta_uuid, + since_time=last_updated_time, + ) + if last_updated_time + else None + ) + library_info.append( + LibraryInfo( + library_id=library_id, + library_short_name=library_short_name, + last_updated=last_updated_time, + needs_update=needs_update, + organization_code=library_settings.organization_code, + include_summary=library_settings.include_summary, + include_genres=library_settings.include_genres, + web_client_urls=web_client_urls, + s3_key_full_uuid=str(s3_key_full_uuid), + s3_key_full=s3_key_full, + s3_key_delta_uuid=str(s3_key_delta_uuid), + s3_key_delta=s3_key_delta, + ) + ) + library_info.sort(key=lambda info: info.library_id) + return library_info + + @staticmethod + def query_works( + session: Session, + collection_id: int, + work_id_offset: int | None, + batch_size: int, + ) -> list[Work]: + query = ( + select(Work) + .join(LicensePool) + .where( + LicensePool.collection_id == collection_id, + ) + .limit(batch_size) + .order_by(Work.id.asc()) + ) + + if work_id_offset is not None: + query = query.where(Work.id > work_id_offset) + + return session.execute(query).scalars().unique().all() + + @staticmethod + def collection(session: Session, collection_id: int) -> Collection | None: + return session.execute( + select(Collection).where(Collection.id == collection_id) + ).scalar_one_or_none() + + @classmethod + def process_work( + cls, + work: Work, + libraries_info: Iterable[LibraryInfo], + base_url: str, + *, + uploader: MarcUploader, + annotator: type[Annotator] = Annotator, + ) -> None: + pool = work.active_license_pool() + if pool is None: + return + base_record = annotator.marc_record(work, pool) + + for library_info in libraries_info: + library_record = annotator.library_marc_record( + base_record, + pool.identifier, + base_url, + library_info.library_short_name, + library_info.web_client_urls, + library_info.organization_code, + library_info.include_summary, + library_info.include_genres, + ) + + uploader.add_record( + library_info.s3_key_full, + 
library_record.as_marc(), + ) + + if ( + library_info.last_updated + and library_info.s3_key_delta + and work.last_update_time + and work.last_update_time > library_info.last_updated + ): + uploader.add_record( + library_info.s3_key_delta, + annotator.set_revised(library_record).as_marc(), + ) + + @staticmethod + def create_marc_upload_records( + session: Session, + start_time: datetime.datetime, + collection_id: int, + libraries_info: Iterable[LibraryInfo], + uploaded_keys: set[str], + ) -> None: + for library_info in libraries_info: + if library_info.s3_key_full in uploaded_keys: + create( + session, + MarcFile, + id=library_info.s3_key_full_uuid, + library_id=library_info.library_id, + collection_id=collection_id, + created=start_time, + key=library_info.s3_key_full, + ) + if library_info.s3_key_delta and library_info.s3_key_delta in uploaded_keys: + create( + session, + MarcFile, + id=library_info.s3_key_delta_uuid, + library_id=library_info.library_id, + collection_id=collection_id, + created=start_time, + since=library_info.last_updated, + key=library_info.s3_key_delta, + ) diff --git a/src/palace/manager/marc/settings.py b/src/palace/manager/marc/settings.py new file mode 100644 index 0000000000..a6517fb73f --- /dev/null +++ b/src/palace/manager/marc/settings.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from pydantic import NonNegativeInt + +from palace.manager.integration.settings import ( + BaseSettings, + ConfigurationFormItem, + ConfigurationFormItemType, + FormField, +) + + +class MarcExporterSettings(BaseSettings): + # This setting (in days) controls how often MARC files should be + # automatically updated. Since the crontab in docker isn't easily + # configurable, we can run a script daily but check this to decide + # whether to do anything. 
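+    # (With this change the check is performed by the `marc_export` Celery task
+    # rather than by a cron-invoked script.)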
+ update_frequency: NonNegativeInt = FormField( + 30, + form=ConfigurationFormItem( + label="Update frequency (in days)", + type=ConfigurationFormItemType.NUMBER, + required=True, + ), + alias="marc_update_frequency", + ) + + +class MarcExporterLibrarySettings(BaseSettings): + # MARC organization codes are assigned by the + # Library of Congress and can be found here: + # http://www.loc.gov/marc/organizations/org-search.php + organization_code: str | None = FormField( + None, + form=ConfigurationFormItem( + label="The MARC organization code for this library (003 field).", + description="MARC organization codes are assigned by the Library of Congress.", + type=ConfigurationFormItemType.TEXT, + ), + alias="marc_organization_code", + ) + + web_client_url: str | None = FormField( + None, + form=ConfigurationFormItem( + label="The base URL for the web catalog for this library, for the 856 field.", + description="If using a library registry that provides a web catalog, this can be left blank.", + type=ConfigurationFormItemType.TEXT, + ), + alias="marc_web_client_url", + ) + + include_summary: bool = FormField( + False, + form=ConfigurationFormItem( + label="Include summaries in MARC records (520 field)", + type=ConfigurationFormItemType.SELECT, + options={"false": "Do not include summaries", "true": "Include summaries"}, + ), + ) + + include_genres: bool = FormField( + False, + form=ConfigurationFormItem( + label="Include Palace Collection Manager genres in MARC records (650 fields)", + type=ConfigurationFormItemType.SELECT, + options={ + "false": "Do not include Palace Collection Manager genres", + "true": "Include Palace Collection Manager genres", + }, + ), + alias="include_simplified_genres", + ) diff --git a/src/palace/manager/marc/uploader.py b/src/palace/manager/marc/uploader.py new file mode 100644 index 0000000000..976e5be0f4 --- /dev/null +++ b/src/palace/manager/marc/uploader.py @@ -0,0 +1,142 @@ +from collections import defaultdict +from collections.abc import Generator, Sequence +from contextlib import contextmanager + +from celery.exceptions import Ignore, Retry +from typing_extensions import Self + +from palace.manager.service.redis.models.marc import MarcFileUploads +from palace.manager.service.storage.s3 import S3Service +from palace.manager.sqlalchemy.model.resource import Representation +from palace.manager.util.log import LoggerMixin + + +class MarcUploader(LoggerMixin): + """ + This class is used to manage the upload of MARC files to S3. The upload is done in multiple + parts, so that the Celery task can be broken up into multiple steps, saving the progress + between steps to redis, and flushing them to S3 when the buffer is large enough. + + This class orchestrates the upload process, delegating the redis operation to the + `MarcFileUploads` class, and the S3 upload to the `S3Service` class. 
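+
+    A simplified sketch of typical usage (see the `marc_export_collection` task):
+
+        uploader = MarcUploader(storage_service, MarcFileUploads(redis_client, collection_id))
+        with uploader.begin():
+            if uploader.locked:
+                uploader.add_record(key, record.as_marc())
+                uploader.sync()        # flush buffered records to redis/S3
+                uploader.complete()    # finalize the uploads and return the uploaded keys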
+ """ + + def __init__(self, storage_service: S3Service, marc_uploads: MarcFileUploads): + self.storage_service = storage_service + self.marc_uploads = marc_uploads + self._buffers: defaultdict[str, str] = defaultdict(str) + self._locked = False + + @property + def locked(self) -> bool: + return self._locked + + @property + def update_number(self) -> int: + return self.marc_uploads.update_number + + def add_record(self, key: str, record: bytes) -> None: + self._buffers[key] += record.decode() + + def _s3_sync(self, needs_upload: Sequence[str]) -> None: + upload_ids = self.marc_uploads.get_upload_ids(needs_upload) + for key in needs_upload: + if upload_ids.get(key) is None: + upload_id = self.storage_service.multipart_create( + key, content_type=Representation.MARC_MEDIA_TYPE + ) + self.marc_uploads.set_upload_id(key, upload_id) + upload_ids[key] = upload_id + + part_number, data = self.marc_uploads.get_part_num_and_buffer(key) + upload_part = self.storage_service.multipart_upload( + key, upload_ids[key], part_number, data.encode() + ) + self.marc_uploads.add_part_and_clear_buffer(key, upload_part) + + def sync(self) -> None: + # First sync our buffers to redis + buffer_lengths = self.marc_uploads.append_buffers(self._buffers) + self._buffers.clear() + + # Then, if any of our redis buffers are large enough, upload them to S3 + needs_upload = [ + key + for key, length in buffer_lengths.items() + if length > self.storage_service.MINIMUM_MULTIPART_UPLOAD_SIZE + ] + + if not needs_upload: + return + + self._s3_sync(needs_upload) + + def _abort(self) -> None: + in_progress = self.marc_uploads.get() + for key, upload in in_progress.items(): + if upload.upload_id is None: + # This upload has not started, so there is nothing to abort. + continue + try: + self.storage_service.multipart_abort(key, upload.upload_id) + except Exception as e: + # We log and keep going, since we want to abort as many uploads as possible + # even if some fail, this is likely already being called in an exception handler. + # So we want to do as much cleanup as possible. + self.log.exception( + f"Failed to abort upload {key} (UploadID: {upload.upload_id}) due to exception ({e})." + ) + + # Delete our in-progress uploads from redis as well + self.delete() + + def complete(self) -> set[str]: + # Make sure any local data we have is synced + self.sync() + + in_progress = self.marc_uploads.get() + for key, upload in in_progress.items(): + if upload.upload_id is None: + # We haven't started the upload. At this point there is no reason to start a + # multipart upload, just upload the file directly and continue. + self.storage_service.store( + key, upload.buffer, Representation.MARC_MEDIA_TYPE + ) + else: + if upload.buffer != "": + # Upload the last chunk if the buffer is not empty, the final part has no + # minimum size requirement. 
+ upload_part = self.storage_service.multipart_upload( + key, upload.upload_id, len(upload.parts), upload.buffer.encode() + ) + upload.parts.append(upload_part) + + # Complete the multipart upload + self.storage_service.multipart_complete( + key, upload.upload_id, upload.parts + ) + + # Delete our in-progress uploads data from redis + if in_progress: + self.marc_uploads.clear_uploads() + + # Return the keys that were uploaded + return set(in_progress.keys()) + + def delete(self) -> None: + self.marc_uploads.delete() + + @contextmanager + def begin(self) -> Generator[Self, None, None]: + self._locked = self.marc_uploads.acquire() + try: + yield self + except Exception as e: + # We want to ignore any celery exceptions that are expected, but + # handle cleanup for any other cases. + if not isinstance(e, (Retry, Ignore)): + self._abort() + raise + finally: + self.marc_uploads.release() + self._locked = False diff --git a/src/palace/manager/scripts/marc.py b/src/palace/manager/scripts/marc.py deleted file mode 100644 index 572c6fe761..0000000000 --- a/src/palace/manager/scripts/marc.py +++ /dev/null @@ -1,223 +0,0 @@ -from __future__ import annotations - -import argparse -import datetime -from collections.abc import Sequence -from datetime import timedelta -from typing import Any - -from sqlalchemy import select -from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import Session - -from palace.manager.core.config import CannotLoadConfiguration -from palace.manager.core.marc import Annotator as MarcAnnotator -from palace.manager.core.marc import ( - MARCExporter, - MarcExporterLibrarySettings, - MarcExporterSettings, -) -from palace.manager.integration.goals import Goals -from palace.manager.scripts.input import LibraryInputScript -from palace.manager.sqlalchemy.model.collection import Collection -from palace.manager.sqlalchemy.model.discovery_service_registration import ( - DiscoveryServiceRegistration, -) -from palace.manager.sqlalchemy.model.integration import ( - IntegrationConfiguration, - IntegrationLibraryConfiguration, -) -from palace.manager.sqlalchemy.model.library import Library -from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.util.datetime_helpers import utc_now - - -class CacheMARCFiles(LibraryInputScript): - """Generate and cache MARC files for each input library.""" - - name = "Cache MARC files" - - @classmethod - def arg_parser(cls, _db: Session) -> argparse.ArgumentParser: # type: ignore[override] - parser = super().arg_parser(_db) - parser.add_argument( - "--force", - help="Generate new MARC files even if MARC files have already been generated recently enough", - dest="force", - action="store_true", - ) - return parser - - def __init__( - self, - _db: Session | None = None, - cmd_args: Sequence[str] | None = None, - exporter: MARCExporter | None = None, - *args: Any, - **kwargs: Any, - ) -> None: - super().__init__(_db, *args, **kwargs) - self.force = False - self.parse_args(cmd_args) - self.storage_service = self.services.storage.public() - self.base_url = self.services.config.sitewide.base_url() - if self.base_url is None: - raise CannotLoadConfiguration( - f"Missing required environment variable: PALACE_BASE_URL." 
- ) - - self.exporter = exporter or MARCExporter(self._db, self.storage_service) - - def parse_args(self, cmd_args: Sequence[str] | None = None) -> argparse.Namespace: - parser = self.arg_parser(self._db) - parsed = parser.parse_args(cmd_args) - self.force = parsed.force - return parsed - - def settings( - self, library: Library - ) -> tuple[MarcExporterSettings, MarcExporterLibrarySettings]: - integration_query = ( - select(IntegrationLibraryConfiguration) - .join(IntegrationConfiguration) - .where( - IntegrationConfiguration.goal == Goals.CATALOG_GOAL, - IntegrationConfiguration.protocol == MARCExporter.__name__, - IntegrationLibraryConfiguration.library == library, - ) - ) - integration = self._db.execute(integration_query).scalar_one() - - library_settings = MARCExporter.library_settings_load(integration) - settings = MARCExporter.settings_load(integration.parent) - - return settings, library_settings - - def process_libraries(self, libraries: Sequence[Library]) -> None: - if not self.storage_service: - self.log.info("No storage service was found.") - return - - super().process_libraries(libraries) - - def get_collections(self, library: Library) -> Sequence[Collection]: - return self._db.scalars( - select(Collection).where( - Collection.libraries.contains(library), - Collection.export_marc_records == True, - ) - ).all() - - def get_web_client_urls( - self, library: Library, url: str | None = None - ) -> list[str]: - """Find web client URLs configured by the registry for this library.""" - urls = [ - s.web_client - for s in self._db.execute( - select(DiscoveryServiceRegistration.web_client).where( - DiscoveryServiceRegistration.library == library, - DiscoveryServiceRegistration.web_client != None, - ) - ).all() - ] - - if url: - urls.append(url) - - return urls - - def process_library( - self, library: Library, annotator_cls: type[MarcAnnotator] = MarcAnnotator - ) -> None: - try: - settings, library_settings = self.settings(library) - except NoResultFound: - return - - self.log.info("Processing library %s" % library.name) - - update_frequency = int(settings.update_frequency) - - # Find the collections for this library. - collections = self.get_collections(library) - - # Find web client URLs configured by the registry for this library. - web_client_urls = self.get_web_client_urls( - library, library_settings.web_client_url - ) - - annotator = annotator_cls( - self.base_url, - library.short_name or "", - web_client_urls, - library_settings.organization_code, - library_settings.include_summary, - library_settings.include_genres, - ) - - # We set the creation time to be the start of the batch. Any updates that happen during the batch will be - # included in the next batch. 
- creation_time = utc_now() - - for collection in collections: - self.process_collection( - library, - collection, - annotator, - update_frequency, - creation_time, - ) - - def last_updated( - self, library: Library, collection: Collection - ) -> datetime.datetime | None: - """Find the most recent MarcFile creation time.""" - last_updated_file = self._db.execute( - select(MarcFile.created) - .where( - MarcFile.library == library, - MarcFile.collection == collection, - ) - .order_by(MarcFile.created.desc()) - ).first() - - return last_updated_file.created if last_updated_file else None - - def process_collection( - self, - library: Library, - collection: Collection, - annotator: MarcAnnotator, - update_frequency: int, - creation_time: datetime.datetime, - ) -> None: - last_update = self.last_updated(library, collection) - - if ( - not self.force - and last_update - and (last_update > creation_time - timedelta(days=update_frequency)) - ): - self.log.info( - f"Skipping collection {collection.name} because last update was less than {update_frequency} days ago" - ) - return - - # First update the file with ALL the records. - self.exporter.records( - library, collection, annotator, creation_time=creation_time - ) - - # Then create a new file with changes since the last update. - if last_update: - self.exporter.records( - library, - collection, - annotator, - creation_time=creation_time, - since_time=last_update, - ) - - self._db.commit() - self.log.info("Processed collection %s" % collection.name) diff --git a/src/palace/manager/service/celery/celery.py b/src/palace/manager/service/celery/celery.py index adb4486d61..cb73dba995 100644 --- a/src/palace/manager/service/celery/celery.py +++ b/src/palace/manager/service/celery/celery.py @@ -37,6 +37,12 @@ def beat_schedule() -> dict[str, Any]: "task": "search.search_indexing", "schedule": crontab(minute="*"), # Run every minute }, + "marc_export": { + "task": "marc.marc_export", + "schedule": crontab( + hour="3,11", minute="0" + ), # Run twice a day at 3:00 AM and 11:00 AM + }, } diff --git a/src/palace/manager/service/integration_registry/catalog_services.py b/src/palace/manager/service/integration_registry/catalog_services.py index 913b8c4f1f..1a6593d62f 100644 --- a/src/palace/manager/service/integration_registry/catalog_services.py +++ b/src/palace/manager/service/integration_registry/catalog_services.py @@ -1,9 +1,17 @@ -from palace.manager.core.marc import MARCExporter +from __future__ import annotations + +from typing import TYPE_CHECKING + from palace.manager.integration.goals import Goals from palace.manager.service.integration_registry.base import IntegrationRegistry +if TYPE_CHECKING: + from palace.manager.marc.exporter import MarcExporter # noqa: autoflake -class CatalogServicesRegistry(IntegrationRegistry[MARCExporter]): + +class CatalogServicesRegistry(IntegrationRegistry["MarcExporter"]): def __init__(self) -> None: + from palace.manager.marc.exporter import MarcExporter + super().__init__(Goals.CATALOG_GOAL) - self.register(MARCExporter) + self.register(MarcExporter, aliases=["MARCExporter"]) diff --git a/src/palace/manager/service/redis/models/lock.py b/src/palace/manager/service/redis/models/lock.py index dfc2a74f60..1fa232ef27 100644 --- a/src/palace/manager/service/redis/models/lock.py +++ b/src/palace/manager/service/redis/models/lock.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from datetime import timedelta from functools import cached_property -from typing import Any, TypeVar, cast +from typing import TypeVar, cast 
from uuid import uuid4 from palace.manager.celery.task import Task @@ -70,18 +70,6 @@ def key(self) -> str: :return: The key used to store the lock in Redis. """ - def _exception_exit(self) -> None: - """ - Clean up before exiting the context manager, if an exception occurs. - """ - self.release() - - def _normal_exit(self) -> None: - """ - Clean up before exiting the context manager, if no exception occurs. - """ - self.release() - @contextmanager def lock( self, @@ -107,10 +95,10 @@ def lock( exception_occurred = True raise finally: - if release_on_error and exception_occurred: - self._exception_exit() - elif release_on_exit and not exception_occurred: - self._normal_exit() + if (release_on_error and exception_occurred) or ( + release_on_exit and not exception_occurred + ): + self.release() class RedisLock(BaseRedisLock): @@ -248,12 +236,23 @@ def __init__( class RedisJsonLock(BaseRedisLock, ABC): - _ACQUIRE_SCRIPT = """ + _GET_LOCK_FUNCTION = """ + local function get_lock_value(key, json_key) + local value = redis.call("json.get", key, json_key) + if not value then + return nil + end + return cjson.decode(value)[1] + end + """ + + _ACQUIRE_SCRIPT = f""" + {_GET_LOCK_FUNCTION} -- If the locks json object doesn't exist, create it with the initial value redis.call("json.set", KEYS[1], "$", ARGV[4], "nx") -- Get the current lock value - local lock_value = cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] + local lock_value = get_lock_value(KEYS[1], ARGV[1]) if not lock_value then -- The lock isn't currently locked, so we lock it and set the timeout redis.call("json.set", KEYS[1], ARGV[1], cjson.encode(ARGV[2])) @@ -269,8 +268,9 @@ class RedisJsonLock(BaseRedisLock, ABC): end """ - _RELEASE_SCRIPT = """ - if cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] == ARGV[2] then + _RELEASE_SCRIPT = f""" + {_GET_LOCK_FUNCTION} + if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then redis.call("json.del", KEYS[1], ARGV[1]) return 1 else @@ -278,8 +278,9 @@ class RedisJsonLock(BaseRedisLock, ABC): end """ - _EXTEND_SCRIPT = """ - if cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] == ARGV[2] then + _EXTEND_SCRIPT = f""" + {_GET_LOCK_FUNCTION} + if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then redis.call("pexpire", KEYS[1], ARGV[3]) return 1 else @@ -287,8 +288,9 @@ class RedisJsonLock(BaseRedisLock, ABC): end """ - _DELETE_SCRIPT = """ - if cjson.decode(redis.call("json.get", KEYS[1], ARGV[1]))[1] == ARGV[2] then + _DELETE_SCRIPT = f""" + {_GET_LOCK_FUNCTION} + if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then redis.call("del", KEYS[1]) return 1 else @@ -341,12 +343,20 @@ def _initial_value(self) -> str: def _parse_multi( cls, value: Mapping[str, Sequence[T]] | None ) -> dict[str, T | None]: + """ + Helper function that makes it easier to work with the results of a JSON GET command, + where you request multiple keys. + """ if value is None: return {} return {k: cls._parse_value(v) for k, v in value.items()} @staticmethod def _parse_value(value: Sequence[T] | None) -> T | None: + """ + Helper function to parse the value from the results of a JSON GET command, where you + expect the JSONPath to return a single value. + """ if value is None: return None try: @@ -356,17 +366,14 @@ def _parse_value(value: Sequence[T] | None) -> T | None: @classmethod def _parse_value_or_raise(cls, value: Sequence[T] | None) -> T: + """ + Wrapper around _parse_value that raises an exception if the value is None. 
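+
+        For example, a JSONPath query such as "$.lock" comes back from RedisJSON
+        as a one-element list like ["value"]; this unwraps it to "value" and
+        raises LockError if nothing could be parsed.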
+ """ parsed_value = cls._parse_value(value) if parsed_value is None: raise LockError(f"Could not parse value ({json.dumps(value)})") return parsed_value - def _get_value(self, json_key: str) -> Any | None: - value = self._redis_client.json().get(self.key, json_key) - if value is None or len(value) != 1: - return None - return value[0] - def acquire(self) -> bool: return ( self._acquire_script( diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py new file mode 100644 index 0000000000..7612f05f27 --- /dev/null +++ b/src/palace/manager/service/redis/models/marc.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import json +from collections.abc import Callable, Generator, Mapping, Sequence +from contextlib import contextmanager +from functools import cached_property +from typing import Any + +from pydantic import BaseModel +from redis import ResponseError, WatchError + +from palace.manager.service.redis.models.lock import LockError, RedisJsonLock +from palace.manager.service.redis.redis import Pipeline, Redis +from palace.manager.service.storage.s3 import MultipartS3UploadPart +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.util.log import LoggerMixin + + +class RedisMarcError(LockError): + pass + + +class MarcFileUpload(BaseModel): + buffer: str = "" + upload_id: str | None = None + parts: list[MultipartS3UploadPart] = [] + + +class MarcFileUploads(RedisJsonLock, LoggerMixin): + """ + This class is used as a lock for the Celery MARC export task, to ensure that only one + task can upload MARC files for a given collection at a time. It increments an update + number each time an update is made, to guard against corruption if a task gets run + twice. + + It stores the intermediate results of the MARC file generation process, so that the task + can complete in multiple steps, saving the progress between steps to redis, and flushing + them to S3 when the buffer is full. + + This object is focused on the redis part of this operation, the actual s3 upload orchestration + is handled by the `MarcUploader` class. + """ + + def __init__( + self, + redis_client: Redis, + collection_id: int, + update_number: int = 0, + ): + super().__init__(redis_client) + self._collection_id = collection_id + self._update_number = update_number + + @cached_property + def key(self) -> str: + return self._redis_client.get_key( + self.__class__.__name__, + Collection.redis_key_from_id(self._collection_id), + ) + + @property + def _lock_timeout_ms(self) -> int: + return 20 * 60 * 1000 # 20 minutes + + @property + def update_number(self) -> int: + return self._update_number + + @property + def _initial_value(self) -> str: + """ + The initial value to use for the locks JSON object. 
+ """ + return json.dumps({"uploads": {}, "update_number": 0}) + + @property + def _update_number_json_key(self) -> str: + return "$.update_number" + + @property + def _uploads_json_key(self) -> str: + return "$.uploads" + + @staticmethod + def _upload_initial_value(buffer_data: str) -> dict[str, Any]: + return MarcFileUpload(buffer=buffer_data).dict(exclude_none=True) + + def _upload_path(self, upload_key: str) -> str: + return f"{self._uploads_json_key}['{upload_key}']" + + def _buffer_path(self, upload_key: str) -> str: + upload_path = self._upload_path(upload_key) + return f"{upload_path}.buffer" + + def _upload_id_path(self, upload_key: str) -> str: + upload_path = self._upload_path(upload_key) + return f"{upload_path}.upload_id" + + def _parts_path(self, upload_key: str) -> str: + upload_path = self._upload_path(upload_key) + return f"{upload_path}.parts" + + @contextmanager + def _pipeline( + self, begin_transaction: bool = True + ) -> Generator[Pipeline, None, None]: + with self._redis_client.pipeline() as pipe: + pipe.watch(self.key) + fetched_data = self._parse_multi( + pipe.json().get( + self.key, self._lock_json_key, self._update_number_json_key + ) + ) + # Check that we hold the lock + if ( + remote_random := fetched_data.get(self._lock_json_key) + ) != self._random_value: + raise RedisMarcError( + f"Must hold lock to append to buffer. " + f"Expected: {self._random_value}, got: {remote_random}" + ) + # Check that the update number is correct + if ( + remote_update_number := fetched_data.get(self._update_number_json_key) + ) != self._update_number: + raise RedisMarcError( + f"Update number mismatch. " + f"Expected: {self._update_number}, got: {remote_update_number}" + ) + if begin_transaction: + pipe.multi() + yield pipe + + def _execute_pipeline(self, pipe: Pipeline, updates: int) -> list[Any]: + if not pipe.explicit_transaction: + raise RedisMarcError( + "Pipeline should be in explicit transaction mode before executing." + ) + pipe.json().numincrby(self.key, self._update_number_json_key, updates) + pipe.pexpire(self.key, self._lock_timeout_ms) + try: + pipe_results = pipe.execute() + except WatchError as e: + raise RedisMarcError( + "Failed to update buffers. Another process is modifying the buffers." 
+ ) from e + self._update_number = self._parse_value_or_raise(pipe_results[-2]) + return pipe_results[:-2] + + def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]: + if not data: + return {} + + set_results = {} + with self._pipeline(begin_transaction=False) as pipe: + existing_uploads: list[str] = self._parse_value_or_raise( + pipe.json().objkeys(self.key, self._uploads_json_key) + ) + pipe.multi() + for key, value in data.items(): + if value == "": + continue + if key in existing_uploads: + pipe.json().strappend( + self.key, path=self._buffer_path(key), value=value + ) + else: + pipe.json().set( + self.key, + path=self._upload_path(key), + obj=self._upload_initial_value(value), + ) + set_results[key] = len(value) + + pipe_results = self._execute_pipeline(pipe, len(data)) + + if not all(pipe_results): + raise RedisMarcError("Failed to append buffers.") + + return { + k: set_results[k] if v is True else self._parse_value_or_raise(v) + for k, v in zip(data.keys(), pipe_results) + } + + def add_part_and_clear_buffer(self, key: str, part: MultipartS3UploadPart) -> None: + with self._pipeline() as pipe: + pipe.json().arrappend( + self.key, + self._parts_path(key), + part.dict(), + ) + pipe.json().set( + self.key, + path=self._buffer_path(key), + obj="", + ) + pipe_results = self._execute_pipeline(pipe, 1) + + if not all(pipe_results): + raise RedisMarcError("Failed to add part and clear buffer.") + + def set_upload_id(self, key: str, upload_id: str) -> None: + with self._pipeline() as pipe: + pipe.json().set( + self.key, + path=self._upload_id_path(key), + obj=upload_id, + nx=True, + ) + pipe_results = self._execute_pipeline(pipe, 1) + + if not all(pipe_results): + raise RedisMarcError("Failed to set upload ID.") + + def clear_uploads(self) -> None: + with self._pipeline() as pipe: + pipe.json().clear(self.key, self._uploads_json_key) + pipe_results = self._execute_pipeline(pipe, 1) + + if not all(pipe_results): + raise RedisMarcError("Failed to clear uploads.") + + def _get_specific( + self, + keys: str | Sequence[str], + get_path: Callable[[str], str], + ) -> dict[str, Any]: + if isinstance(keys, str): + keys = [keys] + paths = {get_path(k): k for k in keys} + results = self._redis_client.json().get(self.key, *paths.keys()) + if len(keys) == 1: + return {keys[0]: self._parse_value(results)} + else: + return {paths[k]: v for k, v in self._parse_multi(results).items()} + + def _get_all(self, key: str) -> dict[str, Any]: + get_results = self._redis_client.json().get(self.key, key) + results: dict[str, Any] | None = self._parse_value(get_results) + + if results is None: + return {} + + return results + + def get(self, keys: str | Sequence[str] | None = None) -> dict[str, MarcFileUpload]: + if keys is None: + uploads = self._get_all(self._uploads_json_key) + else: + uploads = self._get_specific(keys, self._upload_path) + + return { + k: MarcFileUpload.parse_obj(v) for k, v in uploads.items() if v is not None + } + + def get_upload_ids(self, keys: str | Sequence[str]) -> dict[str, str]: + return self._get_specific(keys, self._upload_id_path) + + def get_part_num_and_buffer(self, key: str) -> tuple[int, str]: + try: + with self._redis_client.pipeline() as pipe: + pipe.json().get(self.key, self._buffer_path(key)) + pipe.json().arrlen(self.key, self._parts_path(key)) + results = pipe.execute() + except ResponseError as e: + raise RedisMarcError("Failed to get part number and buffer data.") from e + + buffer_data: str = self._parse_value_or_raise(results[0]) + part_number: int = 
self._parse_value_or_raise(results[1]) + + return part_number, buffer_data diff --git a/src/palace/manager/service/redis/redis.py b/src/palace/manager/service/redis/redis.py index cd73c4edd6..25b06f91b5 100644 --- a/src/palace/manager/service/redis/redis.py +++ b/src/palace/manager/service/redis/redis.py @@ -95,6 +95,7 @@ def key_args(self, args: list[Any]) -> Sequence[str]: RedisCommandArgs("MGET", args_end=None), RedisCommandArgs("EXISTS", args_end=None), RedisCommandArgs("EXPIRETIME"), + RedisCommandArgs("JSON.CLEAR"), RedisCommandArgs("JSON.SET"), RedisCommandArgs("JSON.STRLEN"), RedisCommandArgs("JSON.STRAPPEND"), @@ -102,6 +103,7 @@ def key_args(self, args: list[Any]) -> Sequence[str]: RedisCommandArgs("JSON.GET"), RedisCommandArgs("JSON.OBJKEYS"), RedisCommandArgs("JSON.ARRAPPEND"), + RedisCommandArgs("JSON.ARRLEN"), RedisVariableCommandArgs("EVALSHA", key_index=1), ] } diff --git a/tests/conftest.py b/tests/conftest.py index 395e05d264..387bbb70b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ "tests.fixtures.files", "tests.fixtures.flask", "tests.fixtures.library", + "tests.fixtures.marc", "tests.fixtures.odl", "tests.fixtures.redis", "tests.fixtures.s3", diff --git a/tests/fixtures/marc.py b/tests/fixtures/marc.py new file mode 100644 index 0000000000..06ed89d44e --- /dev/null +++ b/tests/fixtures/marc.py @@ -0,0 +1,99 @@ +import datetime +from collections.abc import Sequence + +import pytest + +from palace.manager.integration.goals import Goals +from palace.manager.marc.exporter import LibraryInfo, MarcExporter +from palace.manager.marc.settings import MarcExporterLibrarySettings +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.sqlalchemy.util import create +from palace.manager.util.datetime_helpers import utc_now +from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.services import ServicesFixture + + +class MarcExporterFixture: + def __init__( + self, db: DatabaseTransactionFixture, services_fixture: ServicesFixture + ): + self._db = db + self._services_fixture = services_fixture + + self.registry = ( + services_fixture.services.integration_registry.catalog_services() + ) + self.session = db.session + + self.library1 = db.default_library() + self.library1.short_name = "library1" + self.library2 = db.library(short_name="library2") + + self.collection1 = db.collection(name="collection1") + self.collection2 = db.collection() + self.collection3 = db.collection() + + self.collection1.libraries = [self.library1, self.library2] + self.collection2.libraries = [self.library1] + self.collection3.libraries = [self.library2] + + self.test_marc_file_key = "test-file-1.mrc" + + def integration(self) -> IntegrationConfiguration: + return self._db.integration_configuration(MarcExporter, Goals.CATALOG_GOAL) + + def work(self, collection: Collection | None = None) -> Work: + collection = collection or self.collection1 + edition = self._db.edition() + self._db.licensepool(edition, collection=collection) + work = self._db.work(presentation_edition=edition) + work.last_update_time = utc_now() + return work + + def works(self, collection: Collection | None = None) -> list[Work]: + return [self.work(collection) for _ in range(5)] + + def configure_export(self) -> None: + marc_integration = self.integration() 
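+        # Attach MARC library settings to both libraries, turn on MARC export
+        # for all three collections, and record a week-old full export for
+        # library1/collection1 so there is an existing file on record.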
+ self._db.integration_library_configuration( + marc_integration, + self.library1, + MarcExporterLibrarySettings(organization_code="library1-org"), + ) + self._db.integration_library_configuration( + marc_integration, + self.library2, + MarcExporterLibrarySettings(organization_code="library2-org"), + ) + + self.collection1.export_marc_records = True + self.collection2.export_marc_records = True + self.collection3.export_marc_records = True + + create( + self.session, + MarcFile, + library=self.library1, + collection=self.collection1, + key=self.test_marc_file_key, + created=utc_now() - datetime.timedelta(days=7), + ) + + def enabled_libraries( + self, collection: Collection | None = None + ) -> Sequence[LibraryInfo]: + collection = collection or self.collection1 + assert collection.id is not None + return MarcExporter.enabled_libraries( + self.session, self.registry, collection_id=collection.id + ) + + +@pytest.fixture +def marc_exporter_fixture( + db: DatabaseTransactionFixture, services_fixture: ServicesFixture +) -> MarcExporterFixture: + return MarcExporterFixture(db, services_fixture) diff --git a/tests/fixtures/s3.py b/tests/fixtures/s3.py index 19ec790b60..5dc1de11c0 100644 --- a/tests/fixtures/s3.py +++ b/tests/fixtures/s3.py @@ -2,12 +2,27 @@ import functools import sys +import uuid +from collections.abc import Generator +from dataclasses import dataclass, field from typing import TYPE_CHECKING, BinaryIO, NamedTuple, Protocol from unittest.mock import MagicMock +from uuid import uuid4 import pytest +from mypy_boto3_s3 import S3Client +from pydantic import AnyHttpUrl -from palace.manager.service.storage.s3 import MultipartS3ContextManager, S3Service +from palace.manager.service.configuration.service_configuration import ( + ServiceConfiguration, +) +from palace.manager.service.storage.container import Storage +from palace.manager.service.storage.s3 import ( + MultipartS3ContextManager, + MultipartS3UploadPart, + S3Service, +) +from tests.fixtures.config import FixtureTestUrlConfiguration if sys.version_info >= (3, 11): from typing import Self @@ -54,14 +69,28 @@ def upload_part(self, content: bytes) -> None: def _upload_complete(self) -> None: if self.content: self._complete = True - self.parent.uploads.append( - MockS3ServiceUpload(self.key, self.content, self.media_type) + self.parent.uploads[self.key] = MockS3ServiceUpload( + self.key, self.content, self.media_type ) def _upload_abort(self) -> None: ... 
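+
+# Minimal records of multipart upload state, used by MockS3Service to track
+# multipart_create / multipart_upload / multipart_complete / multipart_abort
+# calls without talking to real S3.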
+@dataclass +class MockMultipartUploadPart: + part_data: MultipartS3UploadPart + content: bytes + + +@dataclass +class MockMultipartUpload: + key: str + upload_id: str + parts: list[MockMultipartUploadPart] = field(default_factory=list) + content_type: str | None = None + + class MockS3Service(S3Service): def __init__( self, @@ -71,16 +100,19 @@ def __init__( url_template: str, ) -> None: super().__init__(client, region, bucket, url_template) - self.uploads: list[MockS3ServiceUpload] = [] + self.uploads: dict[str, MockS3ServiceUpload] = {} self.mocked_multipart_upload: MockMultipartS3ContextManager | None = None + self.upload_in_progress: dict[str, MockMultipartUpload] = {} + self.aborted: list[str] = [] + def store_stream( self, key: str, stream: BinaryIO, content_type: str | None = None, ) -> str | None: - self.uploads.append(MockS3ServiceUpload(key, stream.read(), content_type)) + self.uploads[key] = MockS3ServiceUpload(key, stream.read(), content_type) return self.generate_url(key) def multipart( @@ -91,6 +123,45 @@ def multipart( ) return self.mocked_multipart_upload + def multipart_create(self, key: str, content_type: str | None = None) -> str: + upload_id = str(uuid4()) + self.upload_in_progress[key] = MockMultipartUpload( + key, upload_id, content_type=content_type + ) + return upload_id + + def multipart_upload( + self, key: str, upload_id: str, part_number: int, content: bytes + ) -> MultipartS3UploadPart: + etag = str(uuid4()) + part = MultipartS3UploadPart(etag=etag, part_number=part_number) + assert key in self.upload_in_progress + assert self.upload_in_progress[key].upload_id == upload_id + self.upload_in_progress[key].parts.append( + MockMultipartUploadPart(part, content) + ) + return part + + def multipart_complete( + self, key: str, upload_id: str, parts: list[MultipartS3UploadPart] + ) -> None: + assert key in self.upload_in_progress + assert self.upload_in_progress[key].upload_id == upload_id + complete_upload = self.upload_in_progress.pop(key) + for part_stored, part_passed_in in zip(complete_upload.parts, parts): + assert part_stored.part_data == part_passed_in + self.uploads[key] = MockS3ServiceUpload( + key, + b"".join(part_stored.content for part_stored in complete_upload.parts), + complete_upload.content_type, + ) + + def multipart_abort(self, key: str, upload_id: str) -> None: + assert key in self.upload_in_progress + assert self.upload_in_progress[key].upload_id == upload_id + self.upload_in_progress.pop(key) + self.aborted.append(key) + class S3ServiceProtocol(Protocol): def __call__( @@ -133,3 +204,95 @@ def mock_service(self) -> MockS3Service: @pytest.fixture def s3_service_fixture() -> S3ServiceFixture: return S3ServiceFixture() + + +class S3UploaderIntegrationConfiguration(FixtureTestUrlConfiguration): + url: AnyHttpUrl + user: str + password: str + + class Config(ServiceConfiguration.Config): + env_prefix = "PALACE_TEST_MINIO_" + + +class S3ServiceIntegrationFixture: + def __init__(self): + self.container = Storage() + self.configuration = S3UploaderIntegrationConfiguration.from_env() + self.analytics_bucket = self.random_name("analytics") + self.public_access_bucket = self.random_name("public") + self.container.config.from_dict( + { + "access_key": self.configuration.user, + "secret_key": self.configuration.password, + "endpoint_url": self.configuration.url, + "region": "us-east-1", + "analytics_bucket": self.analytics_bucket, + "public_access_bucket": self.public_access_bucket, + "url_template": self.configuration.url + "/{bucket}/{key}", + } + ) + 
self.buckets = [] + self.create_buckets() + + @classmethod + def random_name(cls, prefix: str = "test"): + return f"{prefix}-{uuid.uuid4()}" + + @property + def s3_client(self) -> S3Client: + return self.container.s3_client() + + @property + def public(self) -> S3Service: + return self.container.public() + + @property + def analytics(self) -> S3Service: + return self.container.analytics() + + def create_bucket(self, bucket_name: str) -> None: + client = self.s3_client + client.create_bucket(Bucket=bucket_name) + self.buckets.append(bucket_name) + + def get_bucket(self, bucket_name: str) -> str: + if bucket_name == "public": + return self.public_access_bucket + elif bucket_name == "analytics": + return self.analytics_bucket + else: + raise ValueError(f"Unknown bucket name: {bucket_name}") + + def create_buckets(self) -> None: + for bucket in [self.analytics_bucket, self.public_access_bucket]: + self.create_bucket(bucket) + + def list_objects(self, bucket_name: str) -> list[str]: + bucket = self.get_bucket(bucket_name) + response = self.s3_client.list_objects(Bucket=bucket) + return [object["Key"] for object in response.get("Contents", [])] + + def get_object(self, bucket_name: str, key: str) -> bytes: + bucket = self.get_bucket(bucket_name) + response = self.s3_client.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + + def close(self): + for bucket in self.buckets: + response = self.s3_client.list_objects(Bucket=bucket) + + for object in response.get("Contents", []): + object_key = object["Key"] + self.s3_client.delete_object(Bucket=bucket, Key=object_key) + + self.s3_client.delete_bucket(Bucket=bucket) + + +@pytest.fixture +def s3_service_integration_fixture() -> ( + Generator[S3ServiceIntegrationFixture, None, None] +): + fixture = S3ServiceIntegrationFixture() + yield fixture + fixture.close() diff --git a/tests/manager/api/admin/controller/test_catalog_services.py b/tests/manager/api/admin/controller/test_catalog_services.py index 60119e5a6d..81ae168430 100644 --- a/tests/manager/api/admin/controller/test_catalog_services.py +++ b/tests/manager/api/admin/controller/test_catalog_services.py @@ -19,8 +19,9 @@ NO_PROTOCOL_FOR_NEW_SERVICE, UNKNOWN_PROTOCOL, ) -from palace.manager.core.marc import MARCExporter, MarcExporterLibrarySettings from palace.manager.integration.goals import Goals +from palace.manager.marc.exporter import MarcExporter +from palace.manager.marc.settings import MarcExporterLibrarySettings from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration from palace.manager.sqlalchemy.util import get_one from palace.manager.util.problem_detail import ProblemDetail @@ -60,7 +61,7 @@ def test_catalog_services_get_with_no_services( assert 1 == len(protocols) assert protocols[0].get("name") == controller.registry.get_protocol( - MARCExporter + MarcExporter ) assert "settings" in protocols[0] assert "library_settings" in protocols[0] @@ -76,7 +77,7 @@ def test_catalog_services_get_with_marc_exporter( ) integration = db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, name="name", libraries=[db.default_library()], @@ -84,7 +85,7 @@ def test_catalog_services_get_with_marc_exporter( library_settings_integration = integration.for_library(db.default_library()) assert library_settings_integration is not None - MARCExporter.library_settings_update( + MarcExporter.library_settings_update( library_settings_integration, library_settings ) @@ -120,28 +121,28 @@ def test_catalog_services_get_with_marc_exporter( id="unknown 
protocol", ), pytest.param( - {"protocol": "MARCExporter", "id": "123"}, + {"protocol": "MarcExporter", "id": "123"}, MISSING_SERVICE, True, None, id="unknown id", ), pytest.param( - {"protocol": "MARCExporter", "id": ""}, + {"protocol": "MarcExporter", "id": ""}, CANNOT_CHANGE_PROTOCOL, True, None, id="cannot change protocol", ), pytest.param( - {"protocol": "MARCExporter"}, + {"protocol": "MarcExporter"}, MISSING_SERVICE_NAME, True, None, id="no name", ), pytest.param( - {"protocol": "MARCExporter", "name": "existing integration"}, + {"protocol": "MarcExporter", "name": "existing integration"}, INTEGRATION_NAME_ALREADY_IN_USE, True, None, @@ -149,7 +150,7 @@ def test_catalog_services_get_with_marc_exporter( ), pytest.param( { - "protocol": "MARCExporter", + "protocol": "MarcExporter", "name": "new name", "libraries": json.dumps([{"short_name": "default"}]), }, @@ -203,7 +204,7 @@ def test_catalog_services_post_create( controller: CatalogServicesController, db: DatabaseTransactionFixture, ): - protocol = controller.registry.get_protocol(MARCExporter) + protocol = controller.registry.get_protocol(MarcExporter) assert protocol is not None with flask_app_fixture.test_request_context_system_admin("/", method="POST"): @@ -241,7 +242,7 @@ def test_catalog_services_post_create( assert service.name == "exporter name" assert service.libraries == [db.default_library()] - settings = MARCExporter.library_settings_load(service.library_configurations[0]) + settings = MarcExporter.library_settings_load(service.library_configurations[0]) assert settings.include_summary is False assert settings.include_genres is True @@ -252,7 +253,7 @@ def test_catalog_services_post_edit( db: DatabaseTransactionFixture, ): service = db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, name="name", ) @@ -287,7 +288,7 @@ def test_catalog_services_post_edit( assert service.name == "exporter name" assert service.libraries == [db.default_library()] - settings = MARCExporter.library_settings_load(service.library_configurations[0]) + settings = MarcExporter.library_settings_load(service.library_configurations[0]) assert settings.include_summary is True assert settings.include_genres is False @@ -298,7 +299,7 @@ def test_catalog_services_delete( db: DatabaseTransactionFixture, ): service = db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, ) diff --git a/tests/manager/api/controller/test_marc.py b/tests/manager/api/controller/test_marc.py index cadc2584ee..d4ebec061d 100644 --- a/tests/manager/api/controller/test_marc.py +++ b/tests/manager/api/controller/test_marc.py @@ -7,8 +7,11 @@ from flask import Response from palace.manager.api.controller.marc import MARCRecordController -from palace.manager.core.marc import MARCExporter from palace.manager.integration.goals import Goals +from palace.manager.marc.exporter import MarcExporter +from palace.manager.service.integration_registry.catalog_services import ( + CatalogServicesRegistry, +) from palace.manager.service.storage.s3 import S3Service from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.library import Library @@ -16,14 +19,18 @@ from palace.manager.sqlalchemy.util import create from palace.manager.util.datetime_helpers import utc_now from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.services import ServicesFixture class MARCRecordControllerFixture: - def __init__(self, db: DatabaseTransactionFixture): + def __init__( + self, db: 
DatabaseTransactionFixture, registry: CatalogServicesRegistry + ): self.db = db + self.registry = registry self.mock_s3_service = MagicMock(spec=S3Service) self.mock_s3_service.generate_url = lambda x: "http://s3.url/" + x - self.controller = MARCRecordController(self.mock_s3_service) + self.controller = MARCRecordController(self.mock_s3_service, self.registry) self.library = db.default_library() self.collection = db.default_collection() self.collection.export_marc_records = True @@ -35,7 +42,7 @@ def __init__(self, db: DatabaseTransactionFixture): def integration(self, library: Library | None = None): library = library or self.library return self.db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, libraries=[library], ) @@ -73,9 +80,11 @@ def get_response_html(self, response: Response) -> str: @pytest.fixture def marc_record_controller_fixture( - db: DatabaseTransactionFixture, + db: DatabaseTransactionFixture, services_fixture: ServicesFixture ) -> MARCRecordControllerFixture: - return MARCRecordControllerFixture(db) + return MARCRecordControllerFixture( + db, services_fixture.services.integration_registry.catalog_services() + ) class TestMARCRecordController: diff --git a/tests/manager/celery/tasks/test_marc.py b/tests/manager/celery/tasks/test_marc.py new file mode 100644 index 0000000000..3f38c4d6b9 --- /dev/null +++ b/tests/manager/celery/tasks/test_marc.py @@ -0,0 +1,287 @@ +from typing import Any +from unittest.mock import ANY, call, patch + +import pytest +from pymarc import MARCReader +from sqlalchemy import select + +from palace.manager.celery.tasks import marc +from palace.manager.marc.exporter import MarcExporter +from palace.manager.marc.uploader import MarcUploader +from palace.manager.service.logging.configuration import LogLevel +from palace.manager.service.redis.models.marc import MarcFileUploads, RedisMarcError +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.sqlalchemy.util import create +from palace.manager.util.datetime_helpers import utc_now +from tests.fixtures.celery import CeleryFixture +from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.marc import MarcExporterFixture +from tests.fixtures.redis import RedisFixture +from tests.fixtures.s3 import S3ServiceFixture, S3ServiceIntegrationFixture +from tests.fixtures.services import ServicesFixture + + +def test_marc_export( + db: DatabaseTransactionFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + celery_fixture: CeleryFixture, +): + marc_exporter_fixture.configure_export() + with (patch.object(marc, "marc_export_collection") as marc_export_collection,): + # Because none of the collections have works, we should skip all of them. 
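+            # marc_export itself only schedules marc_export_collection subtasks,
+            # which is why patching that task is enough to observe its behavior.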
+ marc.marc_export.delay().wait() + marc_export_collection.delay.assert_not_called() + + # Runs against all the expected collections + collections = [ + marc_exporter_fixture.collection1, + marc_exporter_fixture.collection2, + marc_exporter_fixture.collection3, + ] + for collection in collections: + marc_exporter_fixture.work(collection) + marc.marc_export.delay().wait() + marc_export_collection.delay.assert_has_calls( + [ + call(collection_id=collection.id, start_time=ANY, libraries=ANY) + for collection in collections + ], + any_order=True, + ) + + marc_export_collection.reset_mock() + + # Collection 1 should be skipped because it is locked + assert marc_exporter_fixture.collection1.id is not None + MarcFileUploads( + redis_fixture.client, marc_exporter_fixture.collection1.id + ).acquire() + + # Collection 2 should be skipped because it was updated recently + create( + db.session, + MarcFile, + library=marc_exporter_fixture.library1, + collection=marc_exporter_fixture.collection2, + created=utc_now(), + key="test-file-2.mrc", + ) + + marc.marc_export.delay().wait() + marc_export_collection.delay.assert_called_once_with( + collection_id=marc_exporter_fixture.collection3.id, + start_time=ANY, + libraries=ANY, + ) + + +class MarcExportCollectionFixture: + def __init__( + self, + db: DatabaseTransactionFixture, + celery_fixture: CeleryFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + s3_service_fixture: S3ServiceFixture, + services_fixture: ServicesFixture, + ): + self.db = db + self.celery_fixture = celery_fixture + self.redis_fixture = redis_fixture + self.marc_exporter_fixture = marc_exporter_fixture + self.s3_service_integration_fixture = s3_service_integration_fixture + self.s3_service_fixture = s3_service_fixture + self.services_fixture = services_fixture + + self.mock_s3 = self.s3_service_fixture.mock_service() + self.mock_s3.MINIMUM_MULTIPART_UPLOAD_SIZE = 10 + marc_exporter_fixture.configure_export() + + self.start_time = utc_now() + + def marc_files(self) -> list[MarcFile]: + # We need to ignore the test-file-1.mrc file, which is created by our call to configure_export. 
+ return [ + f + for f in self.db.session.execute(select(MarcFile)).scalars().all() + if f.key != self.marc_exporter_fixture.test_marc_file_key + ] + + def redis_data(self, collection: Collection) -> dict[str, Any] | None: + assert collection.id is not None + uploads = MarcFileUploads(self.redis_fixture.client, collection.id) + return self.redis_fixture.client.json().get(uploads.key) + + def setup_minio_storage(self) -> None: + self.services_fixture.services.storage.override( + self.s3_service_integration_fixture.container + ) + + def setup_mock_storage(self) -> None: + self.services_fixture.services.storage.public.override(self.mock_s3) + + def works(self, collection: Collection) -> list[Work]: + return [self.marc_exporter_fixture.work(collection) for _ in range(15)] + + def export_collection(self, collection: Collection) -> None: + service = self.services_fixture.services.integration_registry.catalog_services() + assert collection.id is not None + info = MarcExporter.enabled_libraries(self.db.session, service, collection.id) + libraries = [l.dict() for l in info] + marc.marc_export_collection.delay( + collection.id, batch_size=5, start_time=self.start_time, libraries=libraries + ).wait() + + +@pytest.fixture +def marc_export_collection_fixture( + db: DatabaseTransactionFixture, + celery_fixture: CeleryFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + s3_service_fixture: S3ServiceFixture, + services_fixture: ServicesFixture, +) -> MarcExportCollectionFixture: + return MarcExportCollectionFixture( + db, + celery_fixture, + redis_fixture, + marc_exporter_fixture, + s3_service_integration_fixture, + s3_service_fixture, + services_fixture, + ) + + +class TestMarcExportCollection: + def test_normal_run( + self, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + ): + marc_export_collection_fixture.setup_minio_storage() + collection = marc_exporter_fixture.collection1 + work_uris = [ + work.license_pools[0].identifier.urn + for work in marc_export_collection_fixture.works(collection) + ] + + # Run the full end-to-end process for exporting a collection, this should generate + # 3 batches of 5 works each, putting the results into minio. + marc_export_collection_fixture.export_collection(collection) + + # Verify that we didn't leave anything in the redis cache. + assert marc_export_collection_fixture.redis_data(collection) is None + + # Verify that the expected number of files were uploaded to minio. + uploaded_files = s3_service_integration_fixture.list_objects("public") + assert len(uploaded_files) == 3 + + # Verify that the expected number of marc files were created in the database. + marc_files = marc_export_collection_fixture.marc_files() + assert len(marc_files) == 3 + filenames = [marc_file.key for marc_file in marc_files] + + # Verify that the uploaded files are the expected ones. + assert set(uploaded_files) == set(filenames) + + # Verify that the marc files contain the expected works. + for file in uploaded_files: + data = s3_service_integration_fixture.get_object("public", file) + records = list(MARCReader(data)) + assert len(records) == len(work_uris) + marc_uris = [record["001"].data for record in records] + assert set(marc_uris) == set(work_uris) + + # Make sure the records have the correct organization code. 
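+            # The S3 key embeds the library's short name, so the file name tells
+            # us which organization code (and which record status) to expect.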
+ expected_org = "library1-org" if "library1" in file else "library2-org" + assert all(record["003"].data == expected_org for record in records) + + # Make sure records have the correct status + expected_status = "c" if "delta" in file else "n" + assert all( + record.leader.record_status == expected_status for record in records + ) + + def test_collection_no_works( + self, + marc_exporter_fixture: MarcExporterFixture, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + ): + marc_export_collection_fixture.setup_minio_storage() + collection = marc_exporter_fixture.collection2 + marc_export_collection_fixture.export_collection(collection) + + assert marc_export_collection_fixture.marc_files() == [] + assert s3_service_integration_fixture.list_objects("public") == [] + assert marc_export_collection_fixture.redis_data(collection) is None + + def test_exception_handled( + self, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + ): + marc_export_collection_fixture.setup_mock_storage() + collection = marc_exporter_fixture.collection1 + marc_export_collection_fixture.works(collection) + + with patch.object(MarcUploader, "complete") as complete: + complete.side_effect = Exception("Test Exception") + with pytest.raises(Exception, match="Test Exception"): + marc_export_collection_fixture.export_collection(collection) + + # After the exception, we should have aborted the multipart uploads and deleted the redis data. + assert marc_export_collection_fixture.marc_files() == [] + assert marc_export_collection_fixture.redis_data(collection) is None + assert len(marc_export_collection_fixture.mock_s3.aborted) == 3 + + def test_locked( + self, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + caplog: pytest.LogCaptureFixture, + ): + caplog.set_level(LogLevel.info) + collection = marc_exporter_fixture.collection1 + assert collection.id is not None + MarcFileUploads(redis_fixture.client, collection.id).acquire() + marc_export_collection_fixture.setup_mock_storage() + with patch.object(MarcExporter, "query_works") as query: + marc_export_collection_fixture.export_collection(collection) + query.assert_not_called() + assert "another task is already processing it" in caplog.text + + def test_outdated_task_run( + self, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + caplog: pytest.LogCaptureFixture, + ): + # In the case that an old task is run again for some reason, it should + # detect that its update number is incorrect and exit. + caplog.set_level(LogLevel.info) + collection = marc_exporter_fixture.collection1 + marc_export_collection_fixture.setup_mock_storage() + assert collection.id is not None + + # Acquire the lock and start an upload, this simulates another task having done work + # that the current task doesn't know about. 
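+        # append_buffers below bumps the update number stored in redis, so the
+        # task (still expecting update number 0) should fail its consistency
+        # check with an "Update number mismatch" error.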
+ uploads = MarcFileUploads(redis_fixture.client, collection.id) + with uploads.lock() as locked: + assert locked + uploads.append_buffers({"test": "data"}) + + with pytest.raises(RedisMarcError, match="Update number mismatch"): + marc_export_collection_fixture.export_collection(collection) + + assert marc_export_collection_fixture.marc_files() == [] + assert marc_export_collection_fixture.redis_data(collection) is None diff --git a/tests/manager/core/test_marc.py b/tests/manager/core/test_marc.py deleted file mode 100644 index 12671e3f2d..0000000000 --- a/tests/manager/core/test_marc.py +++ /dev/null @@ -1,900 +0,0 @@ -from __future__ import annotations - -import datetime -import functools -import logging -import urllib -from typing import TYPE_CHECKING -from unittest.mock import MagicMock, create_autospec, patch - -import pytest -from pymarc import Indicators, MARCReader, Record -from pytest import LogCaptureFixture - -from palace.manager.core.marc import Annotator, MARCExporter -from palace.manager.sqlalchemy.model.classification import Genre -from palace.manager.sqlalchemy.model.contributor import Contributor -from palace.manager.sqlalchemy.model.datasource import DataSource -from palace.manager.sqlalchemy.model.edition import Edition -from palace.manager.sqlalchemy.model.identifier import Identifier -from palace.manager.sqlalchemy.model.licensing import ( - DeliveryMechanism, - LicensePoolDeliveryMechanism, - RightsStatus, -) -from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.sqlalchemy.model.resource import Representation -from palace.manager.util.datetime_helpers import datetime_utc, utc_now -from palace.manager.util.uuid import uuid_encode - -if TYPE_CHECKING: - from tests.fixtures.database import DatabaseTransactionFixture - from tests.fixtures.s3 import MockS3Service, S3ServiceFixture - - -class AnnotateWorkRecordFixture: - def __init__(self): - self.cm_url = "http://cm.url" - self.short_name = "short_name" - self.web_client_urls = ["http://webclient.url"] - self.organization_name = "org" - self.include_summary = True - self.include_genres = True - - self.annotator = Annotator( - self.cm_url, - self.short_name, - self.web_client_urls, - self.organization_name, - self.include_summary, - self.include_genres, - ) - - self.revised = MagicMock() - self.work = MagicMock() - self.pool = MagicMock() - self.edition = MagicMock() - self.identifier = MagicMock() - - self.mock_leader = create_autospec(self.annotator.leader, return_value=" " * 24) - self.mock_add_control_fields = create_autospec( - self.annotator.add_control_fields - ) - self.mock_add_marc_organization_code = create_autospec( - self.annotator.add_marc_organization_code - ) - self.mock_add_isbn = create_autospec(self.annotator.add_isbn) - self.mock_add_title = create_autospec(self.annotator.add_title) - self.mock_add_contributors = create_autospec(self.annotator.add_contributors) - self.mock_add_publisher = create_autospec(self.annotator.add_publisher) - self.mock_add_distributor = create_autospec(self.annotator.add_distributor) - self.mock_add_physical_description = create_autospec( - self.annotator.add_physical_description - ) - self.mock_add_audience = create_autospec(self.annotator.add_audience) - self.mock_add_series = create_autospec(self.annotator.add_series) - self.mock_add_system_details = create_autospec( - self.annotator.add_system_details - ) - self.mock_add_formats = create_autospec(self.annotator.add_formats) - self.mock_add_summary = create_autospec(self.annotator.add_summary) - 
self.mock_add_genres = create_autospec(self.annotator.add_genres) - self.mock_add_ebooks_subject = create_autospec( - self.annotator.add_ebooks_subject - ) - self.mock_add_web_client_urls = create_autospec( - self.annotator.add_web_client_urls - ) - - self.annotator.leader = self.mock_leader - self.annotator.add_control_fields = self.mock_add_control_fields - self.annotator.add_marc_organization_code = self.mock_add_marc_organization_code - self.annotator.add_isbn = self.mock_add_isbn - self.annotator.add_title = self.mock_add_title - self.annotator.add_contributors = self.mock_add_contributors - self.annotator.add_publisher = self.mock_add_publisher - self.annotator.add_distributor = self.mock_add_distributor - self.annotator.add_physical_description = self.mock_add_physical_description - self.annotator.add_audience = self.mock_add_audience - self.annotator.add_series = self.mock_add_series - self.annotator.add_system_details = self.mock_add_system_details - self.annotator.add_formats = self.mock_add_formats - self.annotator.add_summary = self.mock_add_summary - self.annotator.add_genres = self.mock_add_genres - self.annotator.add_ebooks_subject = self.mock_add_ebooks_subject - self.annotator.add_web_client_urls = self.mock_add_web_client_urls - - self.annotate_work_record = functools.partial( - self.annotator.annotate_work_record, - self.revised, - self.work, - self.pool, - self.edition, - self.identifier, - ) - - -@pytest.fixture -def annotate_work_record_fixture() -> AnnotateWorkRecordFixture: - return AnnotateWorkRecordFixture() - - -class TestAnnotator: - def test_annotate_work_record( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture - ) -> None: - fixture = annotate_work_record_fixture - with patch("palace.manager.core.marc.Record") as mock_record: - fixture.annotate_work_record() - - mock_record.assert_called_once_with( - force_utf8=True, leader=fixture.mock_leader.return_value - ) - fixture.mock_leader.assert_called_once_with(fixture.revised) - record = mock_record() - fixture.mock_add_control_fields.assert_called_once_with( - record, fixture.identifier, fixture.pool, fixture.edition - ) - fixture.mock_add_marc_organization_code.assert_called_once_with( - record, fixture.organization_name - ) - fixture.mock_add_isbn.assert_called_once_with(record, fixture.identifier) - fixture.mock_add_title.assert_called_once_with(record, fixture.edition) - fixture.mock_add_contributors.assert_called_once_with(record, fixture.edition) - fixture.mock_add_publisher.assert_called_once_with(record, fixture.edition) - fixture.mock_add_distributor.assert_called_once_with(record, fixture.pool) - fixture.mock_add_physical_description.assert_called_once_with( - record, fixture.edition - ) - fixture.mock_add_audience.assert_called_once_with(record, fixture.work) - fixture.mock_add_series.assert_called_once_with(record, fixture.edition) - fixture.mock_add_system_details.assert_called_once_with(record) - fixture.mock_add_formats.assert_called_once_with(record, fixture.pool) - fixture.mock_add_summary.assert_called_once_with(record, fixture.work) - fixture.mock_add_genres.assert_called_once_with(record, fixture.work) - fixture.mock_add_ebooks_subject.assert_called_once_with(record) - fixture.mock_add_web_client_urls.assert_called_once_with( - record, - fixture.identifier, - fixture.short_name, - fixture.cm_url, - fixture.web_client_urls, - ) - - def test_annotate_work_record_no_summary( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture - ) -> None: - fixture = 
annotate_work_record_fixture - fixture.annotator.include_summary = False - fixture.annotate_work_record() - - assert fixture.mock_add_summary.call_count == 0 - - def test_annotate_work_record_no_genres( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture - ) -> None: - fixture = annotate_work_record_fixture - fixture.annotator.include_genres = False - fixture.annotate_work_record() - - assert fixture.mock_add_genres.call_count == 0 - - def test_annotate_work_record_no_organization_code( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture - ) -> None: - fixture = annotate_work_record_fixture - fixture.annotator.organization_code = None - fixture.annotate_work_record() - - assert fixture.mock_add_marc_organization_code.call_count == 0 - - def test_leader(self): - leader = Annotator.leader(False) - assert leader == "00000nam 2200000 4500" - - # If the record is revised, the leader is different. - leader = Annotator.leader(True) - assert leader == "00000cam 2200000 4500" - - @staticmethod - def _check_control_field(record, tag, expected): - [field] = record.get_fields(tag) - assert field.value() == expected - - @staticmethod - def _check_field( - record, tag, expected_subfields, expected_indicators: Indicators | None = None - ): - if not expected_indicators: - expected_indicators = Indicators(" ", " ") - [field] = record.get_fields(tag) - assert field.indicators == expected_indicators - for subfield, value in expected_subfields.items(): - assert field.get_subfields(subfield)[0] == value - - def test_add_control_fields(self, db: DatabaseTransactionFixture): - # This edition has one format and was published before 1900. - edition, pool = db.edition(with_license_pool=True) - identifier = pool.identifier - edition.issued = datetime_utc(956, 1, 1) - - now = utc_now() - record = Record() - - Annotator.add_control_fields(record, identifier, pool, edition) - self._check_control_field(record, "001", identifier.urn) - assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() - self._check_control_field(record, "006", "m d ") - self._check_control_field(record, "007", "cr cn ---anuuu") - self._check_control_field( - record, "008", now.strftime("%y%m%d") + "s0956 xxu eng " - ) - - # This French edition has two formats and was published in 2018. 
- edition2, pool2 = db.edition(with_license_pool=True) - identifier2 = pool2.identifier - edition2.issued = datetime_utc(2018, 2, 3) - edition2.language = "fre" - LicensePoolDeliveryMechanism.set( - pool2.data_source, - identifier2, - Representation.PDF_MEDIA_TYPE, - DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, - ) - - record = Record() - Annotator.add_control_fields(record, identifier2, pool2, edition2) - self._check_control_field(record, "001", identifier2.urn) - assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() - self._check_control_field(record, "006", "m d ") - self._check_control_field(record, "007", "cr cn ---mnuuu") - self._check_control_field( - record, "008", now.strftime("%y%m%d") + "s2018 xxu fre " - ) - - def test_add_marc_organization_code(self): - record = Record() - Annotator.add_marc_organization_code(record, "US-MaBoDPL") - self._check_control_field(record, "003", "US-MaBoDPL") - - def test_add_isbn(self, db: DatabaseTransactionFixture): - isbn = db.identifier(identifier_type=Identifier.ISBN) - record = Record() - Annotator.add_isbn(record, isbn) - self._check_field(record, "020", {"a": isbn.identifier}) - - # If the identifier isn't an ISBN, but has an equivalent that is, it still - # works. - equivalent = db.identifier() - data_source = DataSource.lookup(db.session, DataSource.OCLC) - equivalent.equivalent_to(data_source, isbn, 1) - record = Record() - Annotator.add_isbn(record, equivalent) - self._check_field(record, "020", {"a": isbn.identifier}) - - # If there is no ISBN, the field is left out. - non_isbn = db.identifier() - record = Record() - Annotator.add_isbn(record, non_isbn) - assert [] == record.get_fields("020") - - def test_add_title(self, db: DatabaseTransactionFixture): - edition = db.edition() - edition.title = "The Good Soldier" - edition.sort_title = "Good Soldier, The" - edition.subtitle = "A Tale of Passion" - - record = Record() - Annotator.add_title(record, edition) - [field] = record.get_fields("245") - self._check_field( - record, - "245", - { - "a": edition.title, - "b": edition.subtitle, - "c": edition.author, - }, - Indicators("0", "4"), - ) - - # If there's no subtitle or no author, those subfields are left out. - edition.subtitle = None - edition.author = None - - record = Record() - Annotator.add_title(record, edition) - [field] = record.get_fields("245") - self._check_field( - record, - "245", - { - "a": edition.title, - }, - Indicators("0", "4"), - ) - assert [] == field.get_subfields("b") - assert [] == field.get_subfields("c") - - def test_add_contributors(self, db: DatabaseTransactionFixture): - author = "a" - author2 = "b" - translator = "c" - - # Edition with one author gets a 100 field and no 700 fields. - edition = db.edition(authors=[author]) - edition.sort_author = "sorted" - - record = Record() - Annotator.add_contributors(record, edition) - assert [] == record.get_fields("700") - self._check_field( - record, "100", {"a": edition.sort_author}, Indicators("1", " ") - ) - - # Edition with two authors and a translator gets three 700 fields and no 100 fields. 
- edition = db.edition(authors=[author, author2]) - edition.add_contributor(translator, Contributor.Role.TRANSLATOR) - - record = Record() - Annotator.add_contributors(record, edition) - assert [] == record.get_fields("100") - fields = record.get_fields("700") - for field in fields: - assert Indicators("1", " ") == field.indicators - [author_field, author2_field, translator_field] = sorted( - fields, key=lambda x: x.get_subfields("a")[0] - ) - assert author == author_field.get_subfields("a")[0] - assert Contributor.Role.PRIMARY_AUTHOR == author_field.get_subfields("e")[0] - assert author2 == author2_field.get_subfields("a")[0] - assert Contributor.Role.AUTHOR == author2_field.get_subfields("e")[0] - assert translator == translator_field.get_subfields("a")[0] - assert Contributor.Role.TRANSLATOR == translator_field.get_subfields("e")[0] - - def test_add_publisher(self, db: DatabaseTransactionFixture): - edition = db.edition() - edition.publisher = db.fresh_str() - edition.issued = datetime_utc(1894, 4, 5) - - record = Record() - Annotator.add_publisher(record, edition) - self._check_field( - record, - "264", - { - "a": "[Place of publication not identified]", - "b": edition.publisher, - "c": "1894", - }, - Indicators(" ", "1"), - ) - - # If there's no publisher, the field is left out. - record = Record() - edition.publisher = None - Annotator.add_publisher(record, edition) - assert [] == record.get_fields("264") - - def test_add_distributor(self, db: DatabaseTransactionFixture): - edition, pool = db.edition(with_license_pool=True) - record = Record() - Annotator.add_distributor(record, pool) - self._check_field( - record, "264", {"b": pool.data_source.name}, Indicators(" ", "2") - ) - - def test_add_physical_description(self, db: DatabaseTransactionFixture): - book = db.edition() - book.medium = Edition.BOOK_MEDIUM - audio = db.edition() - audio.medium = Edition.AUDIO_MEDIUM - - record = Record() - Annotator.add_physical_description(record, book) - self._check_field(record, "300", {"a": "1 online resource"}) - self._check_field( - record, - "336", - { - "a": "text", - "b": "txt", - "2": "rdacontent", - }, - ) - self._check_field( - record, - "337", - { - "a": "computer", - "b": "c", - "2": "rdamedia", - }, - ) - self._check_field( - record, - "338", - { - "a": "online resource", - "b": "cr", - "2": "rdacarrier", - }, - ) - self._check_field( - record, - "347", - { - "a": "text file", - "2": "rda", - }, - ) - self._check_field( - record, - "380", - { - "a": "eBook", - "2": "tlcgt", - }, - ) - - record = Record() - Annotator.add_physical_description(record, audio) - self._check_field( - record, - "300", - { - "a": "1 sound file", - "b": "digital", - }, - ) - self._check_field( - record, - "336", - { - "a": "spoken word", - "b": "spw", - "2": "rdacontent", - }, - ) - self._check_field( - record, - "337", - { - "a": "computer", - "b": "c", - "2": "rdamedia", - }, - ) - self._check_field( - record, - "338", - { - "a": "online resource", - "b": "cr", - "2": "rdacarrier", - }, - ) - self._check_field( - record, - "347", - { - "a": "audio file", - "2": "rda", - }, - ) - assert [] == record.get_fields("380") - - def test_add_audience(self, db: DatabaseTransactionFixture): - for audience, term in list(Annotator.AUDIENCE_TERMS.items()): - work = db.work(audience=audience) - record = Record() - Annotator.add_audience(record, work) - self._check_field( - record, - "385", - { - "a": term, - "2": "tlctarget", - }, - ) - - def test_add_series(self, db: DatabaseTransactionFixture): - edition = db.edition() 
- edition.series = db.fresh_str() - edition.series_position = 5 - record = Record() - Annotator.add_series(record, edition) - self._check_field( - record, - "490", - { - "a": edition.series, - "v": str(edition.series_position), - }, - Indicators("0", " "), - ) - - # If there's no series position, the same field is used without - # the v subfield. - edition.series_position = None - record = Record() - Annotator.add_series(record, edition) - self._check_field( - record, - "490", - { - "a": edition.series, - }, - Indicators("0", " "), - ) - [field] = record.get_fields("490") - assert [] == field.get_subfields("v") - - # If there's no series, the field is left out. - edition.series = None - record = Record() - Annotator.add_series(record, edition) - assert [] == record.get_fields("490") - - def test_add_system_details(self): - record = Record() - Annotator.add_system_details(record) - self._check_field(record, "538", {"a": "Mode of access: World Wide Web."}) - - def test_add_formats(self, db: DatabaseTransactionFixture): - edition, pool = db.edition(with_license_pool=True) - epub_no_drm, ignore = DeliveryMechanism.lookup( - db.session, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM - ) - pool.delivery_mechanisms[0].delivery_mechanism = epub_no_drm - LicensePoolDeliveryMechanism.set( - pool.data_source, - pool.identifier, - Representation.PDF_MEDIA_TYPE, - DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, - ) - - record = Record() - Annotator.add_formats(record, pool) - fields = record.get_fields("538") - assert 2 == len(fields) - [pdf, epub] = sorted(fields, key=lambda x: x.get_subfields("a")[0]) - assert "Adobe PDF eBook" == pdf.get_subfields("a")[0] - assert Indicators(" ", " ") == pdf.indicators - assert "EPUB eBook" == epub.get_subfields("a")[0] - assert Indicators(" ", " ") == epub.indicators - - def test_add_summary(self, db: DatabaseTransactionFixture): - work = db.work(with_license_pool=True) - work.summary_text = "

<p>Summary</p>
" - - # Build and validate a record with a `520|a` summary. - record = Record() - Annotator.add_summary(record, work) - self._check_field(record, "520", {"a": " Summary "}) - exported_record = record.as_marc() - - # Round trip the exported record to validate it. - marc_reader = MARCReader(exported_record) - round_tripped_record = next(marc_reader) - self._check_field(round_tripped_record, "520", {"a": " Summary "}) - - def test_add_simplified_genres(self, db: DatabaseTransactionFixture): - work = db.work(with_license_pool=True) - fantasy, ignore = Genre.lookup(db.session, "Fantasy", autocreate=True) - romance, ignore = Genre.lookup(db.session, "Romance", autocreate=True) - work.genres = [fantasy, romance] - - record = Record() - Annotator.add_genres(record, work) - fields = record.get_fields("650") - [fantasy_field, romance_field] = sorted( - fields, key=lambda x: x.get_subfields("a")[0] - ) - assert Indicators("0", "7") == fantasy_field.indicators - assert "Fantasy" == fantasy_field.get_subfields("a")[0] - assert "Library Simplified" == fantasy_field.get_subfields("2")[0] - assert Indicators("0", "7") == romance_field.indicators - assert "Romance" == romance_field.get_subfields("a")[0] - assert "Library Simplified" == romance_field.get_subfields("2")[0] - - def test_add_ebooks_subject(self): - record = Record() - Annotator.add_ebooks_subject(record) - self._check_field( - record, "655", {"a": "Electronic books."}, Indicators(" ", "0") - ) - - def test_add_web_client_urls_empty(self): - record = MagicMock(spec=Record) - identifier = MagicMock() - Annotator.add_web_client_urls(record, identifier, "", "", []) - assert record.add_field.call_count == 0 - - def test_add_web_client_urls(self, db: DatabaseTransactionFixture): - record = Record() - identifier = db.identifier() - short_name = "short_name" - cm_url = "http://cm.url" - web_client_urls = ["http://webclient1.url", "http://webclient2.url"] - Annotator.add_web_client_urls( - record, identifier, short_name, cm_url, web_client_urls - ) - fields = record.get_fields("856") - assert len(fields) == 2 - [field1, field2] = fields - assert field1.indicators == Indicators("4", "0") - assert field2.indicators == Indicators("4", "0") - - # The URL for a work is constructed as: - # - //works/ - work_link_template = "{cm_base}/{lib}/works/{qid}" - # It is then encoded and the web client URL is constructed in this form: - # - /book/ - client_url_template = "{client_base}/book/{work_link}" - - qualified_identifier = urllib.parse.quote( - identifier.type + "/" + identifier.identifier, safe="" - ) - - expected_work_link = work_link_template.format( - cm_base=cm_url, lib=short_name, qid=qualified_identifier - ) - encoded_work_link = urllib.parse.quote(expected_work_link, safe="") - - expected_client_url_1 = client_url_template.format( - client_base=web_client_urls[0], work_link=encoded_work_link - ) - expected_client_url_2 = client_url_template.format( - client_base=web_client_urls[1], work_link=encoded_work_link - ) - - # A few checks to ensure that our setup is useful. 
- assert web_client_urls[0] != web_client_urls[1] - assert expected_client_url_1 != expected_client_url_2 - assert expected_client_url_1.startswith(web_client_urls[0]) - assert expected_client_url_2.startswith(web_client_urls[1]) - - assert field1.get_subfields("u")[0] == expected_client_url_1 - assert field2.get_subfields("u")[0] == expected_client_url_2 - - -class MarcExporterFixture: - def __init__(self, db: DatabaseTransactionFixture, s3: MockS3Service): - self.db = db - - self.now = utc_now() - self.library = db.default_library() - self.s3_service = s3 - self.exporter = MARCExporter(self.db.session, s3) - self.mock_annotator = MagicMock(spec=Annotator) - assert self.library.short_name is not None - self.annotator = Annotator( - "http://cm.url", - self.library.short_name, - ["http://webclient.url"], - "org", - True, - True, - ) - - self.library = db.library() - self.collection = db.collection() - self.collection.libraries.append(self.library) - - self.now = utc_now() - self.yesterday = self.now - datetime.timedelta(days=1) - self.last_week = self.now - datetime.timedelta(days=7) - - self.w1 = db.work( - genre="Mystery", with_open_access_download=True, collection=self.collection - ) - self.w1.last_update_time = self.yesterday - self.w2 = db.work( - genre="Mystery", with_open_access_download=True, collection=self.collection - ) - self.w2.last_update_time = self.last_week - - self.records = functools.partial( - self.exporter.records, - self.library, - self.collection, - annotator=self.annotator, - creation_time=self.now, - ) - - -@pytest.fixture -def marc_exporter_fixture( - db: DatabaseTransactionFixture, - s3_service_fixture: S3ServiceFixture, -) -> MarcExporterFixture: - return MarcExporterFixture(db, s3_service_fixture.mock_service()) - - -class TestMARCExporter: - def test_create_record( - self, db: DatabaseTransactionFixture, marc_exporter_fixture: MarcExporterFixture - ): - work = db.work( - with_license_pool=True, - title="old title", - authors=["old author"], - data_source_name=DataSource.OVERDRIVE, - ) - - mock_revised = MagicMock() - - create_record = functools.partial( - MARCExporter.create_record, - revised=mock_revised, - work=work, - annotator=marc_exporter_fixture.mock_annotator, - ) - - record = create_record() - assert record is not None - - # Make sure we pass the expected arguments to Annotator.annotate_work_record - marc_exporter_fixture.mock_annotator.annotate_work_record.assert_called_once_with( - mock_revised, - work, - work.license_pools[0], - work.license_pools[0].presentation_edition, - work.license_pools[0].identifier, - ) - - def test_records( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - ): - storage_service = marc_exporter_fixture.s3_service - creation_time = marc_exporter_fixture.now - - marc_exporter_fixture.records() - - # The file was mirrored and a MarcFile was created to track the mirrored file. 
- assert len(storage_service.uploads) == 1 - [cache] = db.session.query(MarcFile).all() - assert cache.library == marc_exporter_fixture.library - assert cache.collection == marc_exporter_fixture.collection - - short_name = marc_exporter_fixture.library.short_name - collection_name = marc_exporter_fixture.collection.name - date_str = creation_time.strftime("%Y-%m-%d") - uuid_str = uuid_encode(cache.id) - - assert ( - cache.key - == f"marc/{short_name}/{collection_name}.full.{date_str}.{uuid_str}.mrc" - ) - assert cache.created == creation_time - assert cache.since is None - - records = list(MARCReader(storage_service.uploads[0].content)) - assert len(records) == 2 - - title_fields = [record.get_fields("245") for record in records] - titles = {fields[0].get_subfields("a")[0] for fields in title_fields} - assert titles == { - marc_exporter_fixture.w1.title, - marc_exporter_fixture.w2.title, - } - - def test_records_since_time( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - ): - # If the `since` parameter is set, only works updated since that time - # are included in the export and the filename reflects that we created - # a partial export. - since = marc_exporter_fixture.now - datetime.timedelta(days=3) - storage_service = marc_exporter_fixture.s3_service - creation_time = marc_exporter_fixture.now - - marc_exporter_fixture.records( - since_time=since, - ) - [cache] = db.session.query(MarcFile).all() - assert cache.library == marc_exporter_fixture.library - assert cache.collection == marc_exporter_fixture.collection - - short_name = marc_exporter_fixture.library.short_name - collection_name = marc_exporter_fixture.collection.name - from_date = since.strftime("%Y-%m-%d") - to_date = creation_time.strftime("%Y-%m-%d") - uuid_str = uuid_encode(cache.id) - - assert ( - cache.key - == f"marc/{short_name}/{collection_name}.delta.{from_date}.{to_date}.{uuid_str}.mrc" - ) - assert cache.created == creation_time - assert cache.since == since - - # Only the work updated since the `since` time is included in the export. - [record] = list(MARCReader(storage_service.uploads[0].content)) - [title_field] = record.get_fields("245") - assert title_field.get_subfields("a")[0] == marc_exporter_fixture.w1.title - - def test_records_none( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - caplog: LogCaptureFixture, - ): - # If there are no works to export, no file is created and a log message is generated. - caplog.set_level(logging.INFO) - - storage_service = marc_exporter_fixture.s3_service - - # Remove the works from the database. - db.session.delete(marc_exporter_fixture.w1) - db.session.delete(marc_exporter_fixture.w2) - - marc_exporter_fixture.records() - - assert [] == storage_service.uploads - assert db.session.query(MarcFile).count() == 0 - assert len(caplog.records) == 1 - assert "No MARC records to upload" in caplog.text - - def test_records_exception( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - caplog: LogCaptureFixture, - ): - # If an exception occurs while exporting, no file is created and a log message is generated. - caplog.set_level(logging.ERROR) - - exporter = marc_exporter_fixture.exporter - storage_service = marc_exporter_fixture.s3_service - - # Mock our query function to raise an exception. 
- exporter.query_works = MagicMock(side_effect=Exception("Boom!")) - - marc_exporter_fixture.records() - - assert [] == storage_service.uploads - assert db.session.query(MarcFile).count() == 0 - assert len(caplog.records) == 1 - assert "Failed to upload MARC file" in caplog.text - assert "Boom!" in caplog.text - - def test_records_minimum_size( - self, - marc_exporter_fixture: MarcExporterFixture, - ): - exporter = marc_exporter_fixture.exporter - storage_service = marc_exporter_fixture.s3_service - - exporter.MINIMUM_UPLOAD_BATCH_SIZE_BYTES = 100 - - # Mock the "records" generated, and force the response to be of certain sizes - created_record_mock = MagicMock() - created_record_mock.as_marc = MagicMock( - side_effect=[b"1" * 600, b"2" * 20, b"3" * 500, b"4" * 10] - ) - exporter.create_record = lambda *args: created_record_mock - - # Mock the query_works to return 4 works - exporter.query_works = MagicMock( - return_value=[MagicMock(), MagicMock(), MagicMock(), MagicMock()] - ) - - marc_exporter_fixture.records() - - assert storage_service.mocked_multipart_upload is not None - # Even though there are 4 parts, we upload in 3 batches due to minimum size limitations - # The "4"th part gets uploaded due it being the tail piece - assert len(storage_service.mocked_multipart_upload.content_parts) == 3 - assert storage_service.mocked_multipart_upload.content_parts == [ - b"1" * 600, - b"2" * 20 + b"3" * 500, - b"4" * 10, - ] diff --git a/tests/manager/marc/__init__.py b/tests/manager/marc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/manager/marc/test_annotator.py b/tests/manager/marc/test_annotator.py new file mode 100644 index 0000000000..41d59c1254 --- /dev/null +++ b/tests/manager/marc/test_annotator.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +import functools +import urllib +from unittest.mock import MagicMock + +import pytest +from freezegun import freeze_time +from pymarc import Indicators, MARCReader, Record + +from palace.manager.marc.annotator import Annotator +from palace.manager.sqlalchemy.model.classification import Genre +from palace.manager.sqlalchemy.model.contributor import Contributor +from palace.manager.sqlalchemy.model.datasource import DataSource +from palace.manager.sqlalchemy.model.edition import Edition +from palace.manager.sqlalchemy.model.identifier import Identifier +from palace.manager.sqlalchemy.model.licensing import ( + DeliveryMechanism, + LicensePool, + LicensePoolDeliveryMechanism, + RightsStatus, +) +from palace.manager.sqlalchemy.model.resource import Representation +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.util.datetime_helpers import datetime_utc, utc_now +from tests.fixtures.database import DatabaseTransactionFixture + + +class AnnotatorFixture: + def __init__(self, db: DatabaseTransactionFixture): + self._db = db + self.cm_url = "http://cm.url" + self.short_name = "short_name" + self.web_client_urls = ["http://webclient.url"] + self.organization_name = "org" + self.include_summary = True + self.include_genres = True + + self.annotator = Annotator() + + @staticmethod + def assert_control_field(record: Record, tag: str, expected: str) -> None: + [field] = record.get_fields(tag) + assert field.value() == expected + + @staticmethod + def assert_field( + record: Record, + tag: str, + expected_subfields: dict[str, str], + expected_indicators: Indicators | None = None, + ) -> None: + if not expected_indicators: + expected_indicators = Indicators(" ", " ") + [field] = 
record.get_fields(tag) + assert field.indicators == expected_indicators + for subfield, value in expected_subfields.items(): + assert field.get_subfields(subfield)[0] == value + + @staticmethod + def record_tags(record: Record) -> set[int]: + return {int(f.tag) for f in record.fields} + + def assert_record_tags( + self, + record: Record, + includes: set[int] | None = None, + excludes: set[int] | None = None, + ) -> None: + tags = self.record_tags(record) + assert includes or excludes + if includes: + assert includes.issubset(tags) + if excludes: + assert excludes.isdisjoint(tags) + + def record(self) -> Record: + return self.annotator._record() + + def test_work(self) -> tuple[Work, LicensePool]: + edition, pool = self._db.edition( + with_license_pool=True, identifier_type=Identifier.ISBN + ) + work = self._db.work(presentation_edition=edition) + work.summary_text = "Summary" + fantasy, ignore = Genre.lookup(self._db.session, "Fantasy", autocreate=True) + romance, ignore = Genre.lookup(self._db.session, "Romance", autocreate=True) + work.genres = [fantasy, romance] + edition.issued = datetime_utc(956, 1, 1) + edition.series = self._db.fresh_str() + edition.series_position = 5 + return work, pool + + +@pytest.fixture +def annotator_fixture( + db: DatabaseTransactionFixture, +) -> AnnotatorFixture: + return AnnotatorFixture(db) + + +class TestAnnotator: + def test_marc_record( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ) -> None: + work, pool = annotator_fixture.test_work() + annotator = annotator_fixture.annotator + + record = annotator.marc_record(work, pool) + assert annotator_fixture.record_tags(record) == { + 1, + 5, + 6, + 7, + 8, + 20, + 245, + 100, + 264, + 300, + 336, + 385, + 490, + 538, + 655, + 520, + 650, + 337, + 338, + 347, + 380, + } + + def test__copy_record(self, annotator_fixture: AnnotatorFixture): + work, pool = annotator_fixture.test_work() + annotator = annotator_fixture.annotator + record = annotator.marc_record(work, pool) + copied = annotator_fixture.annotator._copy_record(record) + assert copied is not record + assert copied.as_marc() == record.as_marc() + + def test_library_marc_record(self, annotator_fixture: AnnotatorFixture): + work, pool = annotator_fixture.test_work() + annotator = annotator_fixture.annotator + generic_record = annotator.marc_record(work, pool) + + library_marc_record = functools.partial( + annotator.library_marc_record, + record=generic_record, + identifier=pool.identifier, + base_url="http://cm.url", + library_short_name="short_name", + web_client_urls=["http://webclient.url"], + organization_code="xyz", + include_summary=True, + include_genres=True, + ) + + library_record = library_marc_record() + annotator_fixture.assert_record_tags( + library_record, includes={3, 520, 650, 856} + ) + + # Make sure the generic record did not get modified. + assert generic_record != library_record + assert generic_record.as_marc() != library_record.as_marc() + annotator_fixture.assert_record_tags(generic_record, excludes={3, 856}) + + # If the summary is not included, the 520 field is left out. + library_record = library_marc_record(include_summary=False) + annotator_fixture.assert_record_tags( + library_record, includes={3, 650, 856}, excludes={520} + ) + + # If the genres are not included, the 650 field is left out. 
+ library_record = library_marc_record(include_genres=False) + annotator_fixture.assert_record_tags( + library_record, includes={3, 520, 856}, excludes={650} + ) + + # If the genres and summary are not included, the 520 and 650 fields are left out. + library_record = library_marc_record( + include_summary=False, include_genres=False + ) + annotator_fixture.assert_record_tags( + library_record, includes={3, 856}, excludes={520, 650} + ) + + # If the organization code is not provided, the 003 field is left out. + library_record = library_marc_record(organization_code=None) + annotator_fixture.assert_record_tags( + library_record, includes={520, 650, 856}, excludes={3} + ) + + # If the web client URLs are not provided, the 856 fields are left out. + library_record = library_marc_record(web_client_urls=[]) + annotator_fixture.assert_record_tags( + library_record, includes={3, 520, 650}, excludes={856} + ) + + def test_leader(self, annotator_fixture: AnnotatorFixture): + leader = annotator_fixture.annotator.leader(False) + assert leader == "00000nam 2200000 4500" + + # If the record is revised, the leader is different. + leader = Annotator.leader(True) + assert leader == "00000cam 2200000 4500" + + @freeze_time() + def test_add_control_fields( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + # This edition has one format and was published before 1900. + edition, pool = db.edition(with_license_pool=True) + identifier = pool.identifier + edition.issued = datetime_utc(956, 1, 1) + + now = utc_now() + record = annotator_fixture.record() + + annotator_fixture.annotator.add_control_fields( + record, identifier, pool, edition + ) + annotator_fixture.assert_control_field(record, "001", identifier.urn) + assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() + annotator_fixture.assert_control_field(record, "006", "m d ") + annotator_fixture.assert_control_field(record, "007", "cr cn ---anuuu") + annotator_fixture.assert_control_field( + record, "008", now.strftime("%y%m%d") + "s0956 xxu eng " + ) + + # This French edition has two formats and was published in 2018. 
+ edition2, pool2 = db.edition(with_license_pool=True) + identifier2 = pool2.identifier + edition2.issued = datetime_utc(2018, 2, 3) + edition2.language = "fre" + LicensePoolDeliveryMechanism.set( + pool2.data_source, + identifier2, + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + ) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_control_fields( + record, identifier2, pool2, edition2 + ) + annotator_fixture.assert_control_field(record, "001", identifier2.urn) + assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() + annotator_fixture.assert_control_field(record, "006", "m d ") + annotator_fixture.assert_control_field(record, "007", "cr cn ---mnuuu") + annotator_fixture.assert_control_field( + record, "008", now.strftime("%y%m%d") + "s2018 xxu fre " + ) + + def test_add_marc_organization_code(self, annotator_fixture: AnnotatorFixture): + record = annotator_fixture.record() + annotator_fixture.annotator.add_marc_organization_code(record, "US-MaBoDPL") + annotator_fixture.assert_control_field(record, "003", "US-MaBoDPL") + + def test_add_isbn( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + isbn = db.identifier(identifier_type=Identifier.ISBN) + record = annotator_fixture.record() + annotator_fixture.annotator.add_isbn(record, isbn) + annotator_fixture.assert_field(record, "020", {"a": isbn.identifier}) + + # If the identifier isn't an ISBN, but has an equivalent that is, it still + # works. + equivalent = db.identifier() + data_source = DataSource.lookup(db.session, DataSource.OCLC) + equivalent.equivalent_to(data_source, isbn, 1) + record = annotator_fixture.record() + annotator_fixture.annotator.add_isbn(record, equivalent) + annotator_fixture.assert_field(record, "020", {"a": isbn.identifier}) + + # If there is no ISBN, the field is left out. + non_isbn = db.identifier() + record = annotator_fixture.record() + annotator_fixture.annotator.add_isbn(record, non_isbn) + assert [] == record.get_fields("020") + + def test_add_title( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition = db.edition() + edition.title = "The Good Soldier" + edition.sort_title = "Good Soldier, The" + edition.subtitle = "A Tale of Passion" + + record = annotator_fixture.record() + annotator_fixture.annotator.add_title(record, edition) + assert len(record.get_fields("245")) == 1 + annotator_fixture.assert_field( + record, + "245", + { + "a": edition.title, + "b": edition.subtitle, + "c": edition.author, + }, + Indicators("0", "4"), + ) + + # If there's no subtitle or no author, those subfields are left out. + edition.subtitle = None + edition.author = None + + record = annotator_fixture.record() + annotator_fixture.annotator.add_title(record, edition) + [field] = record.get_fields("245") + annotator_fixture.assert_field( + record, + "245", + { + "a": edition.title, + }, + Indicators("0", "4"), + ) + assert [] == field.get_subfields("b") + assert [] == field.get_subfields("c") + + def test_add_contributors( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + author = "a" + author2 = "b" + translator = "c" + + # Edition with one author gets a 100 field and no 700 fields. 
+ edition = db.edition(authors=[author]) + edition.sort_author = "sorted" + + record = annotator_fixture.record() + annotator_fixture.annotator.add_contributors(record, edition) + assert [] == record.get_fields("700") + annotator_fixture.assert_field( + record, "100", {"a": edition.sort_author}, Indicators("1", " ") + ) + + # Edition with two authors and a translator gets three 700 fields and no 100 fields. + edition = db.edition(authors=[author, author2]) + edition.add_contributor(translator, Contributor.Role.TRANSLATOR) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_contributors(record, edition) + assert [] == record.get_fields("100") + fields = record.get_fields("700") + for field in fields: + assert Indicators("1", " ") == field.indicators + [author_field, author2_field, translator_field] = sorted( + fields, key=lambda x: x.get_subfields("a")[0] + ) + assert author == author_field.get_subfields("a")[0] + assert Contributor.Role.PRIMARY_AUTHOR == author_field.get_subfields("e")[0] + assert author2 == author2_field.get_subfields("a")[0] + assert Contributor.Role.AUTHOR == author2_field.get_subfields("e")[0] + assert translator == translator_field.get_subfields("a")[0] + assert Contributor.Role.TRANSLATOR == translator_field.get_subfields("e")[0] + + def test_add_publisher( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition = db.edition() + edition.publisher = db.fresh_str() + edition.issued = datetime_utc(1894, 4, 5) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_publisher(record, edition) + annotator_fixture.assert_field( + record, + "264", + { + "a": "[Place of publication not identified]", + "b": edition.publisher, + "c": "1894", + }, + Indicators(" ", "1"), + ) + + # If there's no publisher, the field is left out. 
+ record = annotator_fixture.record() + edition.publisher = None + annotator_fixture.annotator.add_publisher(record, edition) + assert [] == record.get_fields("264") + + def test_add_distributor( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition, pool = db.edition(with_license_pool=True) + record = annotator_fixture.record() + annotator_fixture.annotator.add_distributor(record, pool) + annotator_fixture.assert_field( + record, "264", {"b": pool.data_source.name}, Indicators(" ", "2") + ) + + def test_add_physical_description( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + book = db.edition() + book.medium = Edition.BOOK_MEDIUM + audio = db.edition() + audio.medium = Edition.AUDIO_MEDIUM + + record = annotator_fixture.record() + annotator_fixture.annotator.add_physical_description(record, book) + annotator_fixture.assert_field(record, "300", {"a": "1 online resource"}) + annotator_fixture.assert_field( + record, + "336", + { + "a": "text", + "b": "txt", + "2": "rdacontent", + }, + ) + annotator_fixture.assert_field( + record, + "337", + { + "a": "computer", + "b": "c", + "2": "rdamedia", + }, + ) + annotator_fixture.assert_field( + record, + "338", + { + "a": "online resource", + "b": "cr", + "2": "rdacarrier", + }, + ) + annotator_fixture.assert_field( + record, + "347", + { + "a": "text file", + "2": "rda", + }, + ) + annotator_fixture.assert_field( + record, + "380", + { + "a": "eBook", + "2": "tlcgt", + }, + ) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_physical_description(record, audio) + annotator_fixture.assert_field( + record, + "300", + { + "a": "1 sound file", + "b": "digital", + }, + ) + annotator_fixture.assert_field( + record, + "336", + { + "a": "spoken word", + "b": "spw", + "2": "rdacontent", + }, + ) + annotator_fixture.assert_field( + record, + "337", + { + "a": "computer", + "b": "c", + "2": "rdamedia", + }, + ) + annotator_fixture.assert_field( + record, + "338", + { + "a": "online resource", + "b": "cr", + "2": "rdacarrier", + }, + ) + annotator_fixture.assert_field( + record, + "347", + { + "a": "audio file", + "2": "rda", + }, + ) + assert [] == record.get_fields("380") + + def test_add_audience( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + for audience, term in list(annotator_fixture.annotator.AUDIENCE_TERMS.items()): + work = db.work(audience=audience) + record = annotator_fixture.record() + annotator_fixture.annotator.add_audience(record, work) + annotator_fixture.assert_field( + record, + "385", + { + "a": term, + "2": "tlctarget", + }, + ) + + def test_add_series( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition = db.edition() + edition.series = db.fresh_str() + edition.series_position = 5 + record = annotator_fixture.record() + annotator_fixture.annotator.add_series(record, edition) + annotator_fixture.assert_field( + record, + "490", + { + "a": edition.series, + "v": str(edition.series_position), + }, + Indicators("0", " "), + ) + + # If there's no series position, the same field is used without + # the v subfield. 
+ edition.series_position = None + record = annotator_fixture.record() + annotator_fixture.annotator.add_series(record, edition) + annotator_fixture.assert_field( + record, + "490", + { + "a": edition.series, + }, + Indicators("0", " "), + ) + [field] = record.get_fields("490") + assert [] == field.get_subfields("v") + + # If there's no series, the field is left out. + edition.series = None + record = annotator_fixture.record() + annotator_fixture.annotator.add_series(record, edition) + assert [] == record.get_fields("490") + + def test_add_system_details(self, annotator_fixture: AnnotatorFixture): + record = annotator_fixture.record() + annotator_fixture.annotator.add_system_details(record) + annotator_fixture.assert_field( + record, "538", {"a": "Mode of access: World Wide Web."} + ) + + def test_add_formats( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition, pool = db.edition(with_license_pool=True) + epub_no_drm, ignore = DeliveryMechanism.lookup( + db.session, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM + ) + pool.delivery_mechanisms[0].delivery_mechanism = epub_no_drm + LicensePoolDeliveryMechanism.set( + pool.data_source, + pool.identifier, + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + ) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_formats(record, pool) + fields = record.get_fields("538") + assert 2 == len(fields) + [pdf, epub] = sorted(fields, key=lambda x: x.get_subfields("a")[0]) + assert "Adobe PDF eBook" == pdf.get_subfields("a")[0] + assert Indicators(" ", " ") == pdf.indicators + assert "EPUB eBook" == epub.get_subfields("a")[0] + assert Indicators(" ", " ") == epub.indicators + + def test_add_summary( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + work = db.work(with_license_pool=True) + work.summary_text = "

<p>Summary</p>
" + + # Build and validate a record with a `520|a` summary. + record = annotator_fixture.record() + annotator_fixture.annotator.add_summary(record, work) + annotator_fixture.assert_field(record, "520", {"a": " Summary "}) + exported_record = record.as_marc() + + # Round trip the exported record to validate it. + marc_reader = MARCReader(exported_record) + round_tripped_record = next(marc_reader) + annotator_fixture.assert_field(round_tripped_record, "520", {"a": " Summary "}) + + def test_add_simplified_genres( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + work = db.work(with_license_pool=True) + fantasy, ignore = Genre.lookup(db.session, "Fantasy", autocreate=True) + romance, ignore = Genre.lookup(db.session, "Romance", autocreate=True) + work.genres = [fantasy, romance] + + record = annotator_fixture.record() + annotator_fixture.annotator.add_genres(record, work) + fields = record.get_fields("650") + [fantasy_field, romance_field] = sorted( + fields, key=lambda x: x.get_subfields("a")[0] + ) + assert Indicators("0", "7") == fantasy_field.indicators + assert "Fantasy" == fantasy_field.get_subfields("a")[0] + assert "Library Simplified" == fantasy_field.get_subfields("2")[0] + assert Indicators("0", "7") == romance_field.indicators + assert "Romance" == romance_field.get_subfields("a")[0] + assert "Library Simplified" == romance_field.get_subfields("2")[0] + + def test_add_ebooks_subject(self, annotator_fixture: AnnotatorFixture): + record = annotator_fixture.record() + annotator_fixture.annotator.add_ebooks_subject(record) + annotator_fixture.assert_field( + record, "655", {"a": "Electronic books."}, Indicators(" ", "0") + ) + + def test_add_web_client_urls_empty(self, annotator_fixture: AnnotatorFixture): + record = MagicMock(spec=Record) + identifier = MagicMock() + annotator_fixture.annotator.add_web_client_urls(record, identifier, "", "", []) + assert record.add_field.call_count == 0 + + def test_add_web_client_urls( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + record = annotator_fixture.record() + identifier = db.identifier() + short_name = "short_name" + cm_url = "http://cm.url" + web_client_urls = ["http://webclient1.url", "http://webclient2.url"] + annotator_fixture.annotator.add_web_client_urls( + record, identifier, short_name, cm_url, web_client_urls + ) + fields = record.get_fields("856") + assert len(fields) == 2 + [field1, field2] = fields + assert field1.indicators == Indicators("4", "0") + assert field2.indicators == Indicators("4", "0") + + # The URL for a work is constructed as: + # - //works/ + work_link_template = "{cm_base}/{lib}/works/{qid}" + # It is then encoded and the web client URL is constructed in this form: + # - /book/ + client_url_template = "{client_base}/book/{work_link}" + + qualified_identifier = urllib.parse.quote( + identifier.type + "/" + identifier.identifier, safe="" + ) + + expected_work_link = work_link_template.format( + cm_base=cm_url, lib=short_name, qid=qualified_identifier + ) + encoded_work_link = urllib.parse.quote(expected_work_link, safe="") + + expected_client_url_1 = client_url_template.format( + client_base=web_client_urls[0], work_link=encoded_work_link + ) + expected_client_url_2 = client_url_template.format( + client_base=web_client_urls[1], work_link=encoded_work_link + ) + + # A few checks to ensure that our setup is useful. 
+ assert web_client_urls[0] != web_client_urls[1] + assert expected_client_url_1 != expected_client_url_2 + assert expected_client_url_1.startswith(web_client_urls[0]) + assert expected_client_url_2.startswith(web_client_urls[1]) + + assert field1.get_subfields("u")[0] == expected_client_url_1 + assert field2.get_subfields("u")[0] == expected_client_url_2 diff --git a/tests/manager/marc/test_exporter.py b/tests/manager/marc/test_exporter.py new file mode 100644 index 0000000000..4d40b5f2c0 --- /dev/null +++ b/tests/manager/marc/test_exporter.py @@ -0,0 +1,425 @@ +import datetime +from functools import partial +from unittest.mock import ANY, call, create_autospec +from uuid import UUID + +import pytest +from freezegun import freeze_time + +from palace.manager.marc.exporter import LibraryInfo, MarcExporter +from palace.manager.marc.settings import MarcExporterLibrarySettings +from palace.manager.marc.uploader import MarcUploader +from palace.manager.sqlalchemy.model.discovery_service_registration import ( + DiscoveryServiceRegistration, +) +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.util import create, get_one +from palace.manager.util.datetime_helpers import datetime_utc, utc_now +from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.marc import MarcExporterFixture + + +class TestMarcExporter: + def test__s3_key(self, marc_exporter_fixture: MarcExporterFixture) -> None: + library = marc_exporter_fixture.library1 + collection = marc_exporter_fixture.collection1 + + uuid = UUID("c2370bf2-28e1-40ff-9f04-4864306bd11c") + now = datetime_utc(2024, 8, 27) + since = datetime_utc(2024, 8, 20) + + s3_key = partial(MarcExporter._s3_key, library, collection, now, uuid) + + assert ( + s3_key() + == f"marc/{library.short_name}/{collection.name}.full.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" + ) + + assert ( + s3_key(since_time=since) + == f"marc/{library.short_name}/{collection.name}.delta.2024-08-20.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" + ) + + @freeze_time("2020-02-20T10:00:00Z") + @pytest.mark.parametrize( + "last_updated_time, update_frequency, expected", + [ + (None, 60, True), + (None, 1, True), + (datetime.datetime.fromisoformat("2020-02-20T09:00:00"), 1, False), + (datetime.datetime.fromisoformat("2020-02-19T10:02:00"), 1, True), + (datetime.datetime.fromisoformat("2020-01-31T10:02:00"), 20, True), + (datetime.datetime.fromisoformat("2020-02-01T10:00:00"), 20, False), + ], + ) + def test__needs_update( + self, + last_updated_time: datetime.datetime, + update_frequency: int, + expected: bool, + ): + assert ( + MarcExporter._needs_update(last_updated_time, update_frequency) == expected + ) + + def test__web_client_urls( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ): + library = marc_exporter_fixture.library1 + web_client_urls = partial(MarcExporter._web_client_urls, db.session, library) + + # No web client URLs are returned if there are no discovery service registrations. + assert web_client_urls() == () + + # If we pass in a configured web client URL, that URL is returned. + assert web_client_urls(url="http://web-client") == ("http://web-client",) + + # Add a URL from a library registry. 
+ registry = db.discovery_service_integration() + create( + db.session, + DiscoveryServiceRegistration, + library=library, + integration=registry, + web_client="http://web-client/registry", + ) + assert web_client_urls() == ("http://web-client/registry",) + + # URL from library registry and configured URL are both returned. + assert web_client_urls(url="http://web-client") == ( + "http://web-client/registry", + "http://web-client", + ) + + def test__enabled_collections_and_libraries( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ) -> None: + enabled_collections_and_libraries = partial( + MarcExporter._enabled_collections_and_libraries, + db.session, + marc_exporter_fixture.registry, + ) + + assert enabled_collections_and_libraries() == set() + + # Marc export is enabled on the collections, but since the libraries don't have a marc exporter, they are + # not included. + marc_exporter_fixture.collection1.export_marc_records = True + marc_exporter_fixture.collection2.export_marc_records = True + assert enabled_collections_and_libraries() == set() + + # Marc export is enabled, but no libraries are added to it + marc_integration = marc_exporter_fixture.integration() + assert enabled_collections_and_libraries() == set() + + # Add a marc exporter to library1 + marc_l1_config = db.integration_library_configuration( + marc_integration, marc_exporter_fixture.library1 + ) + assert enabled_collections_and_libraries() == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection2, marc_l1_config), + } + + # Add a marc exporter to library2 + marc_l2_config = db.integration_library_configuration( + marc_integration, marc_exporter_fixture.library2 + ) + assert enabled_collections_and_libraries() == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection1, marc_l2_config), + (marc_exporter_fixture.collection2, marc_l1_config), + } + + # Enable marc export on collection3 + marc_exporter_fixture.collection3.export_marc_records = True + assert enabled_collections_and_libraries() == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection1, marc_l2_config), + (marc_exporter_fixture.collection2, marc_l1_config), + (marc_exporter_fixture.collection3, marc_l2_config), + } + + # We can also filter by a collection id + assert enabled_collections_and_libraries( + collection_id=marc_exporter_fixture.collection1.id + ) == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection1, marc_l2_config), + } + + def test__last_updated(self, marc_exporter_fixture: MarcExporterFixture) -> None: + library = marc_exporter_fixture.library1 + collection = marc_exporter_fixture.collection1 + + last_updated = partial( + MarcExporter._last_updated, + marc_exporter_fixture.session, + library, + collection, + ) + + # If there is no cached file, we return None. + assert last_updated() is None + + # If there is a cached file, we return the time it was created. + file1 = MarcFile( + library=library, + collection=collection, + created=datetime_utc(1984, 5, 8), + key="file1", + ) + marc_exporter_fixture.session.add(file1) + assert last_updated() == file1.created + + # If there are multiple cached files, we return the time of the most recent one. 
+ file2 = MarcFile( + library=library, + collection=collection, + created=utc_now(), + key="file2", + ) + marc_exporter_fixture.session.add(file2) + assert last_updated() == file2.created + + def test_enabled_collections( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ): + enabled_collections = partial( + MarcExporter.enabled_collections, + db.session, + marc_exporter_fixture.registry, + ) + + assert enabled_collections() == set() + + # Marc export is enabled on the collections, but since the libraries don't have a marc exporter, they are + # not included. + marc_exporter_fixture.collection1.export_marc_records = True + marc_exporter_fixture.collection2.export_marc_records = True + assert enabled_collections() == set() + + # Marc export is enabled, but no libraries are added to it + marc_integration = marc_exporter_fixture.integration() + assert enabled_collections() == set() + + # Add a marc exporter to library2 + db.integration_library_configuration( + marc_integration, marc_exporter_fixture.library2 + ) + assert enabled_collections() == {marc_exporter_fixture.collection1} + + # Enable marc export on collection3 + marc_exporter_fixture.collection3.export_marc_records = True + assert enabled_collections() == { + marc_exporter_fixture.collection1, + marc_exporter_fixture.collection3, + } + + def test_enabled_libraries( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ): + assert marc_exporter_fixture.collection1.id is not None + enabled_libraries = partial( + MarcExporter.enabled_libraries, + db.session, + marc_exporter_fixture.registry, + collection_id=marc_exporter_fixture.collection1.id, + ) + + assert enabled_libraries() == [] + + # Collections have marc export enabled, and the marc exporter integration is setup, but + # no libraries are configured to use it. 
+ marc_exporter_fixture.collection1.export_marc_records = True + marc_exporter_fixture.collection2.export_marc_records = True + marc_integration = marc_exporter_fixture.integration() + assert enabled_libraries() == [] + + # Add a marc exporter to library2 + db.integration_library_configuration( + marc_integration, + marc_exporter_fixture.library2, + MarcExporterLibrarySettings( + organization_code="org", web_client_url="http://web-client" + ), + ) + [library_2_info] = enabled_libraries() + + def assert_library_2(library_info: LibraryInfo) -> None: + assert library_info.library_id == marc_exporter_fixture.library2.id + assert ( + library_info.library_short_name + == marc_exporter_fixture.library2.short_name + ) + assert library_info.last_updated is None + assert library_info.needs_update + assert library_info.organization_code == "org" + assert library_info.include_summary is False + assert library_info.include_genres is False + assert library_info.web_client_urls == ("http://web-client",) + assert library_info.s3_key_full.startswith("marc/library2/collection1.full") + assert library_info.s3_key_delta is None + + assert_library_2(library_2_info) + + # Add a marc exporter to library1 + db.integration_library_configuration( + marc_integration, + marc_exporter_fixture.library1, + MarcExporterLibrarySettings( + organization_code="org2", include_summary=True, include_genres=True + ), + ) + [library_1_info, library_2_info] = enabled_libraries() + assert_library_2(library_2_info) + + assert library_1_info.library_id == marc_exporter_fixture.library1.id + assert ( + library_1_info.library_short_name + == marc_exporter_fixture.library1.short_name + ) + assert library_1_info.last_updated is None + assert library_1_info.needs_update + assert library_1_info.organization_code == "org2" + assert library_1_info.include_summary is True + assert library_1_info.include_genres is True + assert library_1_info.web_client_urls == () + assert library_1_info.s3_key_full.startswith("marc/library1/collection1.full") + assert library_1_info.s3_key_delta is None + + def test_query_works(self, marc_exporter_fixture: MarcExporterFixture) -> None: + assert marc_exporter_fixture.collection1.id is not None + query_works = partial( + MarcExporter.query_works, + marc_exporter_fixture.session, + collection_id=marc_exporter_fixture.collection1.id, + work_id_offset=None, + batch_size=3, + ) + + assert query_works() == [] + + works = marc_exporter_fixture.works() + + assert query_works() == works[:3] + assert query_works(work_id_offset=works[3].id) == works[4:] + + def test_collection(self, marc_exporter_fixture: MarcExporterFixture) -> None: + collection_id = marc_exporter_fixture.collection1.id + assert collection_id is not None + collection = MarcExporter.collection( + marc_exporter_fixture.session, collection_id + ) + assert collection == marc_exporter_fixture.collection1 + + marc_exporter_fixture.session.delete(collection) + collection = MarcExporter.collection( + marc_exporter_fixture.session, collection_id + ) + assert collection is None + + def test_process_work(self, marc_exporter_fixture: MarcExporterFixture) -> None: + marc_exporter_fixture.configure_export() + + collection = marc_exporter_fixture.collection1 + work = marc_exporter_fixture.work(collection) + enabled_libraries = marc_exporter_fixture.enabled_libraries(collection) + + mock_uploader = create_autospec(MarcUploader) + + process_work = partial( + MarcExporter.process_work, + work, + enabled_libraries, + "http://base.url", + uploader=mock_uploader, + ) + + 
process_work() + mock_uploader.add_record.assert_has_calls( + [ + call(enabled_libraries[0].s3_key_full, ANY), + call(enabled_libraries[0].s3_key_delta, ANY), + call(enabled_libraries[1].s3_key_full, ANY), + ] + ) + + # If the work has no license pools, it is skipped. + mock_uploader.reset_mock() + work.license_pools = [] + process_work() + mock_uploader.add_record.assert_not_called() + + def test_create_marc_upload_records( + self, marc_exporter_fixture: MarcExporterFixture + ) -> None: + marc_exporter_fixture.configure_export() + + collection = marc_exporter_fixture.collection1 + assert collection.id is not None + enabled_libraries = marc_exporter_fixture.enabled_libraries(collection) + + marc_exporter_fixture.session.query(MarcFile).delete() + + start_time = utc_now() + + # If there are no uploads, then no records are created. + MarcExporter.create_marc_upload_records( + marc_exporter_fixture.session, + start_time, + collection.id, + enabled_libraries, + set(), + ) + + assert len(marc_exporter_fixture.session.query(MarcFile).all()) == 0 + + # If there are uploads, then records are created. + assert enabled_libraries[0].s3_key_delta is not None + MarcExporter.create_marc_upload_records( + marc_exporter_fixture.session, + start_time, + collection.id, + enabled_libraries, + { + enabled_libraries[0].s3_key_full, + enabled_libraries[1].s3_key_full, + enabled_libraries[0].s3_key_delta, + }, + ) + + assert len(marc_exporter_fixture.session.query(MarcFile).all()) == 3 + + assert get_one( + marc_exporter_fixture.session, + MarcFile, + collection=collection, + library_id=enabled_libraries[0].library_id, + key=enabled_libraries[0].s3_key_full, + ) + + assert get_one( + marc_exporter_fixture.session, + MarcFile, + collection=collection, + library_id=enabled_libraries[1].library_id, + key=enabled_libraries[1].s3_key_full, + ) + + assert get_one( + marc_exporter_fixture.session, + MarcFile, + collection=collection, + library_id=enabled_libraries[0].library_id, + key=enabled_libraries[0].s3_key_delta, + since=enabled_libraries[0].last_updated, + ) diff --git a/tests/manager/marc/test_uploader.py b/tests/manager/marc/test_uploader.py new file mode 100644 index 0000000000..90fd32623e --- /dev/null +++ b/tests/manager/marc/test_uploader.py @@ -0,0 +1,314 @@ +from unittest.mock import MagicMock, call + +import pytest +from celery.exceptions import Ignore, Retry + +from palace.manager.marc.uploader import MarcUploader +from palace.manager.service.redis.models.marc import MarcFileUpload, MarcFileUploads +from palace.manager.sqlalchemy.model.resource import Representation +from tests.fixtures.redis import RedisFixture +from tests.fixtures.s3 import S3ServiceFixture + + +class MarcUploaderFixture: + def __init__( + self, redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture + ): + self._redis_fixture = redis_fixture + self._s3_service_fixture = s3_service_fixture + + self.test_key1 = "test.123" + self.test_record1 = b"test_record_123" + self.test_key2 = "test*456" + self.test_record2 = b"test_record_456" + self.test_key3 = "test--?789" + self.test_record3 = b"test_record_789" + + self.mock_s3_service = s3_service_fixture.mock_service() + # Reduce the minimum upload size to make testing easier + self.mock_s3_service.MINIMUM_MULTIPART_UPLOAD_SIZE = len(self.test_record1) * 4 + self.redis_client = redis_fixture.client + + self.mock_collection_id = 52 + + self.uploads = MarcFileUploads(self.redis_client, self.mock_collection_id) + self.uploader = MarcUploader(self.mock_s3_service, self.uploads) + + 
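+# The tests below exercise the MarcUploader flow step by step: begin() takes the
+# Redis lock, add_record() buffers MARC bytes locally, sync() pushes the buffers
+# to Redis (and to S3 multipart parts once they are large enough), and complete()
+# finalizes the S3 uploads and clears the Redis state. A minimal usage sketch,
+# assuming the same objects this fixture builds (the surrounding Celery task
+# wiring and the example key are illustrative assumptions, not part of this change):
+#
+#   uploads = MarcFileUploads(redis_client, collection_id)
+#   uploader = MarcUploader(s3_service, uploads)
+#   with uploader.begin():
+#       uploader.add_record("marc/library/collection.full.mrc", record_bytes)
+#       uploader.sync()      # flush local buffers to Redis / S3
+#       uploader.complete()  # finish multipart uploads, clear Redis keys
+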
+@pytest.fixture
+def marc_uploader_fixture(
+    redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture
+):
+    return MarcUploaderFixture(redis_fixture, s3_service_fixture)
+
+
+class TestMarcUploader:
+    def test_begin(
+        self, marc_uploader_fixture: MarcUploaderFixture, redis_fixture: RedisFixture
+    ):
+        uploader = marc_uploader_fixture.uploader
+
+        assert uploader.locked is False
+        assert marc_uploader_fixture.uploads.locked(by_us=True) is False
+
+        with uploader.begin() as u:
+            # The context manager returns the uploader object
+            assert u is uploader
+
+            # It directly tells us the lock status
+            assert uploader.locked is True
+
+            # The lock is also reflected in the uploads object
+            assert marc_uploader_fixture.uploads.locked(by_us=True) is True  # type: ignore[unreachable]
+
+        # The lock is released after the context manager exits
+        assert uploader.locked is False  # type: ignore[unreachable]
+        assert marc_uploader_fixture.uploads.locked(by_us=True) is False
+
+        # If an exception occurs, the _abort method is called, the lock is deleted,
+        # and the exception is re-raised
+        mock_abort = MagicMock(wraps=uploader._abort)
+        uploader._abort = mock_abort
+        with pytest.raises(Exception):
+            with uploader.begin():
+                assert uploader.locked is True
+                raise Exception()
+        assert (
+            redis_fixture.client.json().get(marc_uploader_fixture.uploads.key) is None
+        )
+        mock_abort.assert_called_once()
+
+        # If an expected celery exception occurs, the lock is released but not deleted,
+        # and the abort method isn't called
+        mock_abort.reset_mock()
+        for exception in Retry, Ignore:
+            with pytest.raises(exception):
+                with uploader.begin():
+                    assert uploader.locked is True
+                    raise exception()
+            assert marc_uploader_fixture.uploads.locked(by_us=True) is False
+            assert (
+                redis_fixture.client.json().get(marc_uploader_fixture.uploads.key)
+                is not None
+            )
+            mock_abort.assert_not_called()
+
+    def test_add_record(self, marc_uploader_fixture: MarcUploaderFixture):
+        uploader = marc_uploader_fixture.uploader
+
+        uploader.add_record(
+            marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1
+        )
+        assert (
+            uploader._buffers[marc_uploader_fixture.test_key1]
+            == marc_uploader_fixture.test_record1.decode()
+        )
+
+        uploader.add_record(
+            marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1
+        )
+        assert (
+            uploader._buffers[marc_uploader_fixture.test_key1]
+            == marc_uploader_fixture.test_record1.decode() * 2
+        )
+
+    def test_sync(self, marc_uploader_fixture: MarcUploaderFixture):
+        uploader = marc_uploader_fixture.uploader
+
+        uploader.add_record(
+            marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1
+        )
+        uploader.add_record(
+            marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 * 2
+        )
+        with uploader.begin():
+            uploader.sync()
+
+        # Sync clears the local buffer
+        assert uploader._buffers == {}
+
+        # And pushes the local records to redis
+        assert marc_uploader_fixture.uploads.get() == {
+            marc_uploader_fixture.test_key1: MarcFileUpload(
+                buffer=marc_uploader_fixture.test_record1
+            ),
+            marc_uploader_fixture.test_key2: MarcFileUpload(
+                buffer=marc_uploader_fixture.test_record2 * 2
+            ),
+        }
+
+        # Because the buffer did not contain enough data, it was not uploaded to S3
+        assert marc_uploader_fixture.mock_s3_service.upload_in_progress == {}
+
+        # Add enough data for test_key1 to be uploaded to S3
+        uploader.add_record(
+            marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 2
+        )
+        uploader.add_record(
+            marc_uploader_fixture.test_key1, 
marc_uploader_fixture.test_record1 * 2
+        )
+        uploader.add_record(
+            marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2
+        )
+
+        with uploader.begin():
+            uploader.sync()
+
+        # The buffer is cleared
+        assert uploader._buffers == {}
+
+        # Because the data for test_key1 was large enough, it was uploaded to S3, and its redis data structure was
+        # updated to reflect this. test_key2 was not large enough to upload, so it remains in redis and not in S3.
+        redis_data = marc_uploader_fixture.uploads.get()
+        assert redis_data[marc_uploader_fixture.test_key2] == MarcFileUpload(
+            buffer=marc_uploader_fixture.test_record2 * 3
+        )
+        redis_data_test1 = redis_data[marc_uploader_fixture.test_key1]
+        assert redis_data_test1.buffer == ""
+
+        assert len(marc_uploader_fixture.mock_s3_service.upload_in_progress) == 1
+        assert (
+            marc_uploader_fixture.test_key1
+            in marc_uploader_fixture.mock_s3_service.upload_in_progress
+        )
+        upload = marc_uploader_fixture.mock_s3_service.upload_in_progress[
+            marc_uploader_fixture.test_key1
+        ]
+        assert upload.upload_id is not None
+        assert upload.content_type is Representation.MARC_MEDIA_TYPE
+        [part] = upload.parts
+        assert part.content == marc_uploader_fixture.test_record1 * 5
+
+        # And the S3 part data and upload_id are synced to redis
+        assert redis_data_test1.parts == [part.part_data]
+        assert redis_data_test1.upload_id == upload.upload_id
+
+    def test_complete(self, marc_uploader_fixture: MarcUploaderFixture):
+        uploader = marc_uploader_fixture.uploader
+
+        # Wrap the clear method so we can check if it was called
+        mock_clear_uploads = MagicMock(
+            wraps=marc_uploader_fixture.uploads.clear_uploads
+        )
+        marc_uploader_fixture.uploads.clear_uploads = mock_clear_uploads
+
+        # Set up the records for the test
+        uploader.add_record(
+            marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 5
+        )
+        uploader.add_record(
+            marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 * 5
+        )
+        with uploader.begin():
+            uploader.sync()
+
+        uploader.add_record(
+            marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 5
+        )
+        with uploader.begin():
+            uploader.sync()
+
+        uploader.add_record(
+            marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2
+        )
+
+        uploader.add_record(
+            marc_uploader_fixture.test_key3, marc_uploader_fixture.test_record3
+        )
+
+        # Complete the uploads
+        with uploader.begin():
+            completed = uploader.complete()
+
+        # The complete method should return the keys that were completed
+        assert completed == {
+            marc_uploader_fixture.test_key1,
+            marc_uploader_fixture.test_key2,
+            marc_uploader_fixture.test_key3,
+        }
+
+        # The local buffers should be empty
+        assert uploader._buffers == {}
+
+        # The redis record should have the completed uploads cleared
+        mock_clear_uploads.assert_called_once()
+
+        # The S3 service should have the completed uploads
+        assert len(marc_uploader_fixture.mock_s3_service.uploads) == 3
+        assert len(marc_uploader_fixture.mock_s3_service.upload_in_progress) == 0
+
+        test_key1_upload = marc_uploader_fixture.mock_s3_service.uploads[
+            marc_uploader_fixture.test_key1
+        ]
+        assert test_key1_upload.key == marc_uploader_fixture.test_key1
+        assert test_key1_upload.content == marc_uploader_fixture.test_record1 * 10
+        assert test_key1_upload.media_type == Representation.MARC_MEDIA_TYPE
+
+        test_key2_upload = marc_uploader_fixture.mock_s3_service.uploads[
+            marc_uploader_fixture.test_key2
+        ]
+        assert test_key2_upload.key == marc_uploader_fixture.test_key2
+        assert test_key2_upload.content == 
marc_uploader_fixture.test_record2 * 6 + assert test_key2_upload.media_type == Representation.MARC_MEDIA_TYPE + + test_key3_upload = marc_uploader_fixture.mock_s3_service.uploads[ + marc_uploader_fixture.test_key3 + ] + assert test_key3_upload.key == marc_uploader_fixture.test_key3 + assert test_key3_upload.content == marc_uploader_fixture.test_record3 + assert test_key3_upload.media_type == Representation.MARC_MEDIA_TYPE + + def test__abort( + self, + marc_uploader_fixture: MarcUploaderFixture, + caplog: pytest.LogCaptureFixture, + ): + uploader = marc_uploader_fixture.uploader + + # Set up the records for the test + uploader.add_record( + marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 10 + ) + uploader.add_record( + marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 * 10 + ) + with uploader.begin(): + uploader.sync() + + # Mock the multipart_abort method so we can check if it was called and have it + # raise an exception on the first call + mock_abort = MagicMock(side_effect=[Exception("Boom"), None]) + marc_uploader_fixture.mock_s3_service.multipart_abort = mock_abort + + # Wrap the delete method so we can check if it was called + mock_delete = MagicMock(wraps=marc_uploader_fixture.uploads.delete) + marc_uploader_fixture.uploads.delete = mock_delete + + upload_id_1 = marc_uploader_fixture.mock_s3_service.upload_in_progress[ + marc_uploader_fixture.test_key1 + ].upload_id + upload_id_2 = marc_uploader_fixture.mock_s3_service.upload_in_progress[ + marc_uploader_fixture.test_key2 + ].upload_id + + # Abort the uploads, the original exception should propagate, and the exception + # thrown by the first call to abort should be logged + with pytest.raises(Exception) as exc_info: + with uploader.begin(): + raise Exception("Bang") + assert str(exc_info.value) == "Bang" + + assert ( + f"Failed to abort upload {marc_uploader_fixture.test_key1} (UploadID: {upload_id_1}) due to exception (Boom)" + in caplog.text + ) + + mock_abort.assert_has_calls( + [ + call(marc_uploader_fixture.test_key1, upload_id_1), + call(marc_uploader_fixture.test_key2, upload_id_2), + ] + ) + + # The redis record should have been deleted + mock_delete.assert_called_once() diff --git a/tests/manager/scripts/test_marc.py b/tests/manager/scripts/test_marc.py deleted file mode 100644 index 3b83d359fb..0000000000 --- a/tests/manager/scripts/test_marc.py +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations - -import datetime -import logging -from unittest.mock import MagicMock, call, create_autospec - -import pytest -from _pytest.logging import LogCaptureFixture -from sqlalchemy.exc import NoResultFound - -from palace.manager.core.config import CannotLoadConfiguration -from palace.manager.core.marc import ( - MARCExporter, - MarcExporterLibrarySettings, - MarcExporterSettings, -) -from palace.manager.integration.goals import Goals -from palace.manager.scripts.marc import CacheMARCFiles -from palace.manager.sqlalchemy.model.discovery_service_registration import ( - DiscoveryServiceRegistration, -) -from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration -from palace.manager.sqlalchemy.model.library import Library -from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.sqlalchemy.util import create -from palace.manager.util.datetime_helpers import datetime_utc, utc_now -from tests.fixtures.database import DatabaseTransactionFixture -from tests.fixtures.services import ServicesFixture - - -class CacheMARCFilesFixture: - def 
__init__( - self, db: DatabaseTransactionFixture, services_fixture: ServicesFixture - ): - self.db = db - self.services_fixture = services_fixture - self.base_url = "http://test-circulation-manager" - services_fixture.set_base_url(self.base_url) - self.exporter = MagicMock(spec=MARCExporter) - self.library = self.db.default_library() - self.collection = self.db.collection() - self.collection.export_marc_records = True - self.collection.libraries += [self.library] - - def integration(self, library: Library | None = None) -> IntegrationConfiguration: - if library is None: - library = self.library - - return self.db.integration_configuration( - protocol=MARCExporter, - goal=Goals.CATALOG_GOAL, - libraries=[library], - ) - - def script(self, cmd_args: list[str] | None = None) -> CacheMARCFiles: - cmd_args = cmd_args or [] - return CacheMARCFiles( - self.db.session, - exporter=self.exporter, - services=self.services_fixture.services, - cmd_args=cmd_args, - ) - - -@pytest.fixture -def cache_marc_files( - db: DatabaseTransactionFixture, services_fixture: ServicesFixture -) -> CacheMARCFilesFixture: - return CacheMARCFilesFixture(db, services_fixture) - - -class TestCacheMARCFiles: - def test_constructor(self, cache_marc_files: CacheMARCFilesFixture): - cache_marc_files.services_fixture.set_base_url(None) - with pytest.raises(CannotLoadConfiguration): - cache_marc_files.script() - - cache_marc_files.services_fixture.set_base_url("http://test.com") - script = cache_marc_files.script() - assert script.base_url == "http://test.com" - - def test_settings(self, cache_marc_files: CacheMARCFilesFixture): - # Test that the script gets the correct settings. - test_library = cache_marc_files.library - other_library = cache_marc_files.db.library() - - expected_settings = MarcExporterSettings(update_frequency=3) - expected_library_settings = MarcExporterLibrarySettings( - organization_code="test", - include_summary=True, - include_genres=True, - ) - - other_library_settings = MarcExporterLibrarySettings( - organization_code="other", - ) - - integration = cache_marc_files.integration(test_library) - integration.libraries += [other_library] - - test_library_integration = integration.for_library(test_library) - assert test_library_integration is not None - other_library_integration = integration.for_library(other_library) - assert other_library_integration is not None - MARCExporter.settings_update(integration, expected_settings) - MARCExporter.library_settings_update( - test_library_integration, expected_library_settings - ) - MARCExporter.library_settings_update( - other_library_integration, other_library_settings - ) - - script = cache_marc_files.script() - actual_settings, actual_library_settings = script.settings(test_library) - - assert actual_settings == expected_settings - assert actual_library_settings == expected_library_settings - - def test_settings_none(self, cache_marc_files: CacheMARCFilesFixture): - # If there are no settings, the setting function raises an exception. - test_library = cache_marc_files.library - script = cache_marc_files.script() - with pytest.raises(NoResultFound): - script.settings(test_library) - - def test_process_libraries_no_storage( - self, cache_marc_files: CacheMARCFilesFixture, caplog: LogCaptureFixture - ): - # If there is no storage integration, the script logs an error and returns. 
- script = cache_marc_files.script() - script.storage_service = None - caplog.set_level(logging.INFO) - script.process_libraries([MagicMock(), MagicMock()]) - assert "No storage service was found" in caplog.text - - def test_get_collections(self, cache_marc_files: CacheMARCFilesFixture): - # Test that the script gets the correct collections. - test_library = cache_marc_files.library - collection1 = cache_marc_files.collection - - # Second collection is configured to export MARC records. - collection2 = cache_marc_files.db.collection() - collection2.export_marc_records = True - collection2.libraries += [test_library] - - # Third collection is not configured to export MARC records. - collection3 = cache_marc_files.db.collection() - collection3.export_marc_records = False - collection3.libraries += [test_library] - - # Fourth collection is configured to export MARC records, but is - # configured to export only to a different library. - other_library = cache_marc_files.db.library() - other_collection = cache_marc_files.db.collection() - other_collection.export_marc_records = True - other_collection.libraries += [other_library] - - script = cache_marc_files.script() - - # We should get back the two collections that are configured to export - # MARC records to this library. - collections = script.get_collections(test_library) - assert set(collections) == {collection1, collection2} - - # Set collection3 to export MARC records to this library. - collection3.export_marc_records = True - - # We should get back all three collections that are configured to export - # MARC records to this library. - collections = script.get_collections(test_library) - assert set(collections) == {collection1, collection2, collection3} - - def test_get_web_client_urls( - self, - db: DatabaseTransactionFixture, - cache_marc_files: CacheMARCFilesFixture, - ): - # No web client URLs are returned if there are no discovery service registrations. - script = cache_marc_files.script() - assert script.get_web_client_urls(cache_marc_files.library) == [] - - # If we pass in a configured web client URL, that URL is returned. - assert script.get_web_client_urls( - cache_marc_files.library, "http://web-client" - ) == ["http://web-client"] - - # Add a URL from a library registry. - registry = db.discovery_service_integration() - create( - db.session, - DiscoveryServiceRegistration, - library=cache_marc_files.library, - integration=registry, - web_client="http://web-client-url/", - ) - assert script.get_web_client_urls(cache_marc_files.library) == [ - "http://web-client-url/" - ] - - # URL from library registry and configured URL are both returned. - assert script.get_web_client_urls( - cache_marc_files.library, "http://web-client" - ) == [ - "http://web-client-url/", - "http://web-client", - ] - - def test_process_library_not_configured( - self, - cache_marc_files: CacheMARCFilesFixture, - ): - script = cache_marc_files.script() - mock_process_collection = create_autospec(script.process_collection) - script.process_collection = mock_process_collection - mock_settings = create_autospec(script.settings) - script.settings = mock_settings - mock_settings.side_effect = NoResultFound - - # If there is no integration configuration for the library, the script - # does nothing. 
- script.process_library(cache_marc_files.library) - mock_process_collection.assert_not_called() - - def test_process_library(self, cache_marc_files: CacheMARCFilesFixture): - script = cache_marc_files.script() - mock_annotator_cls = MagicMock() - mock_process_collection = create_autospec(script.process_collection) - script.process_collection = mock_process_collection - mock_settings = create_autospec(script.settings) - script.settings = mock_settings - settings = MarcExporterSettings(update_frequency=3) - library_settings = MarcExporterLibrarySettings( - organization_code="test", - web_client_url="http://web-client-url/", - include_summary=True, - include_genres=False, - ) - mock_settings.return_value = ( - settings, - library_settings, - ) - - before_call_time = utc_now() - - # If there is an integration configuration for the library, the script - # processes all the collections for that library. - script.process_library( - cache_marc_files.library, annotator_cls=mock_annotator_cls - ) - - after_call_time = utc_now() - - mock_annotator_cls.assert_called_once_with( - cache_marc_files.base_url, - cache_marc_files.library.short_name, - [library_settings.web_client_url], - library_settings.organization_code, - library_settings.include_summary, - library_settings.include_genres, - ) - - assert mock_process_collection.call_count == 1 - ( - library, - collection, - annotator, - update_frequency, - creation_time, - ) = mock_process_collection.call_args.args - assert library == cache_marc_files.library - assert collection == cache_marc_files.collection - assert annotator == mock_annotator_cls.return_value - assert update_frequency == settings.update_frequency - assert creation_time > before_call_time - assert creation_time < after_call_time - - def test_last_updated( - self, db: DatabaseTransactionFixture, cache_marc_files: CacheMARCFilesFixture - ): - script = cache_marc_files.script() - - # If there is no cached file, we return None. - assert ( - script.last_updated(cache_marc_files.library, cache_marc_files.collection) - is None - ) - - # If there is a cached file, we return the time it was created. - file1 = MarcFile( - library=cache_marc_files.library, - collection=cache_marc_files.collection, - created=datetime_utc(1984, 5, 8), - key="file1", - ) - db.session.add(file1) - assert ( - script.last_updated(cache_marc_files.library, cache_marc_files.collection) - == file1.created - ) - - # If there are multiple cached files, we return the time of the most recent one. 
- file2 = MarcFile( - library=cache_marc_files.library, - collection=cache_marc_files.collection, - created=utc_now(), - key="file2", - ) - db.session.add(file2) - assert ( - script.last_updated(cache_marc_files.library, cache_marc_files.collection) - == file2.created - ) - - def test_force(self, cache_marc_files: CacheMARCFilesFixture): - script = cache_marc_files.script() - assert script.force is False - - script = cache_marc_files.script(cmd_args=["--force"]) - assert script.force is True - - @pytest.mark.parametrize( - "last_updated, force, update_frequency, run_exporter", - [ - pytest.param(None, False, 10, True, id="never_run_before"), - pytest.param(None, False, 10, True, id="never_run_before_w_force"), - pytest.param( - utc_now() - datetime.timedelta(days=5), - False, - 10, - False, - id="recently_run", - ), - pytest.param( - utc_now() - datetime.timedelta(days=5), - True, - 10, - True, - id="recently_run_w_force", - ), - pytest.param( - utc_now() - datetime.timedelta(days=5), - False, - 0, - True, - id="recently_run_w_frequency_0", - ), - pytest.param( - utc_now() - datetime.timedelta(days=15), - False, - 10, - True, - id="not_recently_run", - ), - pytest.param( - utc_now() - datetime.timedelta(days=15), - True, - 10, - True, - id="not_recently_run_w_force", - ), - pytest.param( - utc_now() - datetime.timedelta(days=15), - False, - 0, - True, - id="not_recently_run_w_frequency_0", - ), - ], - ) - def test_process_collection_skip( - self, - cache_marc_files: CacheMARCFilesFixture, - caplog: LogCaptureFixture, - last_updated: datetime.datetime | None, - force: bool, - update_frequency: int, - run_exporter: bool, - ): - script = cache_marc_files.script() - script.exporter = MagicMock() - now = utc_now() - caplog.set_level(logging.INFO) - - script.force = force - script.last_updated = MagicMock(return_value=last_updated) - script.process_collection( - cache_marc_files.library, - cache_marc_files.collection, - MagicMock(), - update_frequency, - now, - ) - - if run_exporter: - assert script.exporter.records.call_count > 0 - assert "Processed collection" in caplog.text - else: - assert script.exporter.records.call_count == 0 - assert "Skipping collection" in caplog.text - - def test_process_collection_never_called( - self, cache_marc_files: CacheMARCFilesFixture, caplog: LogCaptureFixture - ): - # If the collection has not been processed before, the script processes - # the collection and created a full export. - caplog.set_level(logging.INFO) - script = cache_marc_files.script() - mock_exporter = MagicMock(spec=MARCExporter) - script.exporter = mock_exporter - script.last_updated = MagicMock(return_value=None) - mock_annotator = MagicMock() - creation_time = utc_now() - script.process_collection( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - 10, - creation_time, - ) - mock_exporter.records.assert_called_once_with( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - creation_time=creation_time, - ) - assert "Processed collection" in caplog.text - - def test_process_collection_with_last_updated( - self, cache_marc_files: CacheMARCFilesFixture, caplog: LogCaptureFixture - ): - # If the collection has been processed before, the script processes - # the collection, created a full export and a delta export. 
- caplog.set_level(logging.INFO) - script = cache_marc_files.script() - mock_exporter = MagicMock(spec=MARCExporter) - script.exporter = mock_exporter - last_updated = utc_now() - datetime.timedelta(days=20) - script.last_updated = MagicMock(return_value=last_updated) - mock_annotator = MagicMock() - creation_time = utc_now() - script.process_collection( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - 10, - creation_time, - ) - assert "Processed collection" in caplog.text - assert mock_exporter.records.call_count == 2 - - full_call = call( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - creation_time=creation_time, - ) - - delta_call = call( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - creation_time=creation_time, - since_time=last_updated, - ) - - mock_exporter.records.assert_has_calls([full_call, delta_call]) diff --git a/tests/manager/service/redis/models/test_lock.py b/tests/manager/service/redis/models/test_lock.py index 93cd874709..c317db7a84 100644 --- a/tests/manager/service/redis/models/test_lock.py +++ b/tests/manager/service/redis/models/test_lock.py @@ -253,6 +253,9 @@ def test_acquire(self, json_lock_fixture: JsonLockFixture): json_lock_fixture.assert_locked(json_lock_fixture.lock) def test_release(self, json_lock_fixture: JsonLockFixture): + # If the lock doesn't exist, we can't release it + assert json_lock_fixture.lock.release() is False + # If you acquire a lock another client cannot release it assert json_lock_fixture.lock.acquire() assert json_lock_fixture.other_lock.release() is False @@ -267,6 +270,8 @@ def test_release(self, json_lock_fixture: JsonLockFixture): assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") == {} def test_delete(self, json_lock_fixture: JsonLockFixture): + assert json_lock_fixture.lock.delete() is False + # If you acquire a lock another client cannot delete it assert json_lock_fixture.lock.acquire() assert json_lock_fixture.other_lock.delete() is False @@ -282,6 +287,8 @@ def test_delete(self, json_lock_fixture: JsonLockFixture): assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") is None def test_extend_timeout(self, json_lock_fixture: JsonLockFixture): + assert json_lock_fixture.lock.extend_timeout() is False + # If the lock has a timeout, the acquiring client can extend it, but another client cannot assert json_lock_fixture.lock.acquire() json_lock_fixture.client.pexpire(json_lock_fixture.lock.key, 500) @@ -307,3 +314,22 @@ def test_locked(self, json_lock_fixture: JsonLockFixture): assert json_lock_fixture.lock.release() is True assert json_lock_fixture.lock.locked() is False assert json_lock_fixture.other_lock.locked() is False + + def test__parse_value(self): + assert RedisJsonLock._parse_value(None) is None + assert RedisJsonLock._parse_value([]) is None + assert RedisJsonLock._parse_value(["value"]) == "value" + + def test__parse_multi(self): + assert RedisJsonLock._parse_multi(None) == {} + assert RedisJsonLock._parse_multi({}) == {} + assert RedisJsonLock._parse_multi( + {"key": ["value"], "key2": ["value2"], "key3": []} + ) == {"key": "value", "key2": "value2", "key3": None} + + def test__parse_value_or_raise(self): + with pytest.raises(LockError): + RedisJsonLock._parse_value_or_raise(None) + with pytest.raises(LockError): + RedisJsonLock._parse_value_or_raise([]) + assert RedisJsonLock._parse_value_or_raise(["value"]) == "value" diff --git a/tests/manager/service/redis/models/test_marc.py 
b/tests/manager/service/redis/models/test_marc.py new file mode 100644 index 0000000000..4e5a2778ee --- /dev/null +++ b/tests/manager/service/redis/models/test_marc.py @@ -0,0 +1,406 @@ +import pytest + +from palace.manager.service.redis.models.marc import ( + MarcFileUpload, + MarcFileUploads, + RedisMarcError, +) +from palace.manager.service.redis.redis import Pipeline +from palace.manager.service.storage.s3 import MultipartS3UploadPart +from tests.fixtures.redis import RedisFixture + + +class MarcFileUploadsFixture: + def __init__(self, redis_fixture: RedisFixture): + self._redis_fixture = redis_fixture + + self.mock_collection_id = 1 + + self.uploads = MarcFileUploads( + self._redis_fixture.client, self.mock_collection_id + ) + + self.mock_upload_key_1 = "test1" + self.mock_upload_key_2 = "test2" + self.mock_upload_key_3 = "test3" + + self.mock_unset_upload_key = "test4" + + self.test_data = { + self.mock_upload_key_1: "test", + self.mock_upload_key_2: "another_test", + self.mock_upload_key_3: "another_another_test", + } + + self.part_1 = MultipartS3UploadPart(etag="abc", part_number=1) + self.part_2 = MultipartS3UploadPart(etag="def", part_number=2) + + def load_test_data(self) -> dict[str, int]: + lock_acquired = False + if not self.uploads.locked(): + self.uploads.acquire() + lock_acquired = True + + return_value = self.uploads.append_buffers(self.test_data) + + if lock_acquired: + self.uploads.release() + + return return_value + + def test_data_records(self, *keys: str): + return {key: MarcFileUpload(buffer=self.test_data[key]) for key in keys} + + +@pytest.fixture +def marc_file_uploads_fixture(redis_fixture: RedisFixture): + return MarcFileUploadsFixture(redis_fixture) + + +class TestMarcFileUploads: + def test__pipeline(self, marc_file_uploads_fixture: MarcFileUploadsFixture): + uploads = marc_file_uploads_fixture.uploads + + # Using the _pipeline() context manager makes sure that we hold the lock + with pytest.raises(RedisMarcError) as exc_info: + with uploads._pipeline(): + pass + assert "Must hold lock" in str(exc_info.value) + + uploads.acquire() + + # It also checks that the update_number is correct + uploads._update_number = 1 + with pytest.raises(RedisMarcError) as exc_info: + with uploads._pipeline(): + pass + assert "Update number mismatch" in str(exc_info.value) + + uploads._update_number = 0 + with uploads._pipeline() as pipe: + # If the lock and update number are correct, we should get a pipeline object + assert isinstance(pipe, Pipeline) + + # We are watching the key for this object, so that we know all the data within the + # transaction is consistent, and we are still holding the lock when the pipeline + # executes + assert pipe.watching is True + + # By default it starts the pipeline transaction + assert pipe.explicit_transaction is True + + # We can also start the pipeline without a transaction + with uploads._pipeline(begin_transaction=False) as pipe: + assert pipe.explicit_transaction is False + + def test__execute_pipeline( + self, + marc_file_uploads_fixture: MarcFileUploadsFixture, + redis_fixture: RedisFixture, + ): + client = redis_fixture.client + uploads = marc_file_uploads_fixture.uploads + uploads.acquire() + + # If we try to execute a pipeline without a transaction, we should get an error + with pytest.raises(RedisMarcError) as exc_info: + with uploads._pipeline(begin_transaction=False) as pipe: + uploads._execute_pipeline(pipe, 0) + assert "Pipeline should be in explicit transaction mode" in str(exc_info.value) + + # The _execute_pipeline function 
takes care of extending the timeout and incrementing + # the update number. + [update_number] = client.json().get( + uploads.key, uploads._update_number_json_key + ) + client.pexpire(uploads.key, 500) + with uploads._pipeline() as pipe: + # If we execute the pipeline, we should get a list of results, excluding the + # operations that _execute_pipeline does. + assert uploads._execute_pipeline(pipe, 2) == [] + [new_update_number] = client.json().get( + uploads.key, uploads._update_number_json_key + ) + assert new_update_number == update_number + 2 + assert client.pttl(uploads.key) > 500 + + # If we try to execute a pipeline that has been modified by another process, we should get an error + with uploads._pipeline() as pipe: + client.json().set( + uploads.key, uploads._update_number_json_key, update_number + ) + with pytest.raises(RedisMarcError) as exc_info: + uploads._execute_pipeline(pipe, 1) + assert "Another process is modifying the buffers" in str(exc_info.value) + + def test_append_buffers(self, marc_file_uploads_fixture: MarcFileUploadsFixture): + uploads = marc_file_uploads_fixture.uploads + + # If we try to update buffers without acquiring the lock, we should get an error + with pytest.raises(RedisMarcError) as exc_info: + uploads.append_buffers( + {marc_file_uploads_fixture.mock_upload_key_1: "test"} + ) + assert "Must hold lock" in str(exc_info.value) + + # Acquire the lock and try to update buffers + with uploads.lock() as locked: + assert locked + assert uploads.append_buffers({}) == {} + + assert uploads.append_buffers( + { + marc_file_uploads_fixture.mock_upload_key_1: "test", + marc_file_uploads_fixture.mock_upload_key_2: "another_test", + } + ) == { + marc_file_uploads_fixture.mock_upload_key_1: 4, + marc_file_uploads_fixture.mock_upload_key_2: 12, + } + assert uploads._update_number == 2 + + assert uploads.append_buffers( + { + marc_file_uploads_fixture.mock_upload_key_1: "x", + marc_file_uploads_fixture.mock_upload_key_2: "y", + marc_file_uploads_fixture.mock_upload_key_3: "new", + } + ) == { + marc_file_uploads_fixture.mock_upload_key_1: 5, + marc_file_uploads_fixture.mock_upload_key_2: 13, + marc_file_uploads_fixture.mock_upload_key_3: 3, + } + assert uploads._update_number == 5 + + # If we try to update buffers with an old update number, we should get an error + uploads._update_number = 4 + with pytest.raises(RedisMarcError) as exc_info: + uploads.append_buffers(marc_file_uploads_fixture.test_data) + assert "Update number mismatch" in str(exc_info.value) + + # Exiting the context manager should release the lock + assert not uploads.locked() + + def test_get(self, marc_file_uploads_fixture: MarcFileUploadsFixture): + uploads = marc_file_uploads_fixture.uploads + + assert uploads.get() == {} + assert uploads.get(marc_file_uploads_fixture.mock_upload_key_1) == {} + + marc_file_uploads_fixture.load_test_data() + + # You don't need to acquire the lock to get the uploads, but you should if you + # are using the data to do updates. 
+ + # You can get a subset of the uploads + assert uploads.get( + marc_file_uploads_fixture.mock_upload_key_1, + ) == marc_file_uploads_fixture.test_data_records( + marc_file_uploads_fixture.mock_upload_key_1 + ) + + # Or multiple uploads, any that don't exist are not included in the result dict + assert uploads.get( + [ + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.mock_upload_key_2, + marc_file_uploads_fixture.mock_unset_upload_key, + ] + ) == marc_file_uploads_fixture.test_data_records( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.mock_upload_key_2, + ) + + # Or you can get all the uploads + assert uploads.get() == marc_file_uploads_fixture.test_data_records( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.mock_upload_key_2, + marc_file_uploads_fixture.mock_upload_key_3, + ) + + def test_set_upload_id(self, marc_file_uploads_fixture: MarcFileUploadsFixture): + uploads = marc_file_uploads_fixture.uploads + + # must hold lock to do update + with pytest.raises(RedisMarcError) as exc_info: + uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "xyz") + assert "Must hold lock" in str(exc_info.value) + + uploads.acquire() + + # We are unable to set an upload id for an item that hasn't been initialized + with pytest.raises(RedisMarcError) as exc_info: + uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "xyz") + assert "Failed to set upload ID" in str(exc_info.value) + + marc_file_uploads_fixture.load_test_data() + uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "def") + uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_2, "abc") + + all_uploads = uploads.get() + assert ( + all_uploads[marc_file_uploads_fixture.mock_upload_key_1].upload_id == "def" + ) + assert ( + all_uploads[marc_file_uploads_fixture.mock_upload_key_2].upload_id == "abc" + ) + + # We can't change the upload id for a library that has already been set + with pytest.raises(RedisMarcError) as exc_info: + uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "ghi") + assert "Failed to set upload ID" in str(exc_info.value) + + all_uploads = uploads.get() + assert ( + all_uploads[marc_file_uploads_fixture.mock_upload_key_1].upload_id == "def" + ) + assert ( + all_uploads[marc_file_uploads_fixture.mock_upload_key_2].upload_id == "abc" + ) + + def test_clear_uploads(self, marc_file_uploads_fixture: MarcFileUploadsFixture): + uploads = marc_file_uploads_fixture.uploads + + # must hold lock to do update + with pytest.raises(RedisMarcError) as exc_info: + uploads.clear_uploads() + assert "Must hold lock" in str(exc_info.value) + + uploads.acquire() + + # We are unable to clear the uploads for an item that hasn't been initialized + with pytest.raises(RedisMarcError) as exc_info: + uploads.clear_uploads() + assert "Failed to clear uploads" in str(exc_info.value) + + marc_file_uploads_fixture.load_test_data() + assert uploads.get() != {} + + uploads.clear_uploads() + assert uploads.get() == {} + + def test_get_upload_ids(self, marc_file_uploads_fixture: MarcFileUploadsFixture): + uploads = marc_file_uploads_fixture.uploads + + # If the id is not set, we should get None + assert uploads.get_upload_ids( + [marc_file_uploads_fixture.mock_upload_key_1] + ) == {marc_file_uploads_fixture.mock_upload_key_1: None} + + marc_file_uploads_fixture.load_test_data() + + # If the buffer has been set, but the upload id has not, we should still get None + assert uploads.get_upload_ids( + 
[marc_file_uploads_fixture.mock_upload_key_1] + ) == {marc_file_uploads_fixture.mock_upload_key_1: None} + + with uploads.lock() as locked: + assert locked + uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "abc") + uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_2, "def") + assert uploads.get_upload_ids(marc_file_uploads_fixture.mock_upload_key_1) == { + marc_file_uploads_fixture.mock_upload_key_1: "abc" + } + assert uploads.get_upload_ids( + [ + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.mock_upload_key_2, + ] + ) == { + marc_file_uploads_fixture.mock_upload_key_1: "abc", + marc_file_uploads_fixture.mock_upload_key_2: "def", + } + + def test_add_part_and_clear_buffer( + self, marc_file_uploads_fixture: MarcFileUploadsFixture + ): + uploads = marc_file_uploads_fixture.uploads + + # If we try to add parts without acquiring the lock, we should get an error + with pytest.raises(RedisMarcError) as exc_info: + uploads.add_part_and_clear_buffer( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.part_1, + ) + assert "Must hold lock" in str(exc_info.value) + + # Acquire the lock + uploads.acquire() + + # We are unable to add parts to a library whose buffers haven't been initialized + with pytest.raises(RedisMarcError) as exc_info: + uploads.add_part_and_clear_buffer( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.part_1, + ) + assert "Failed to add part and clear buffer" in str(exc_info.value) + + marc_file_uploads_fixture.load_test_data() + + # We are able to add parts to a library that exists + uploads.add_part_and_clear_buffer( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.part_1, + ) + uploads.add_part_and_clear_buffer( + marc_file_uploads_fixture.mock_upload_key_2, + marc_file_uploads_fixture.part_1, + ) + uploads.add_part_and_clear_buffer( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.part_2, + ) + + all_uploads = uploads.get() + # The parts are added in order and the buffers are cleared + assert all_uploads[marc_file_uploads_fixture.mock_upload_key_1].parts == [ + marc_file_uploads_fixture.part_1, + marc_file_uploads_fixture.part_2, + ] + assert all_uploads[marc_file_uploads_fixture.mock_upload_key_2].parts == [ + marc_file_uploads_fixture.part_1 + ] + assert all_uploads[marc_file_uploads_fixture.mock_upload_key_1].buffer == "" + assert all_uploads[marc_file_uploads_fixture.mock_upload_key_2].buffer == "" + + def test_get_part_num_and_buffer( + self, marc_file_uploads_fixture: MarcFileUploadsFixture + ): + uploads = marc_file_uploads_fixture.uploads + + # If the key has not been initialized, we get an exception + with pytest.raises(RedisMarcError) as exc_info: + uploads.get_part_num_and_buffer(marc_file_uploads_fixture.mock_upload_key_1) + assert "Failed to get part number and buffer data" in str(exc_info.value) + + marc_file_uploads_fixture.load_test_data() + + # If the buffer has been set, but no parts have been added + assert uploads.get_part_num_and_buffer( + marc_file_uploads_fixture.mock_upload_key_1 + ) == ( + 0, + marc_file_uploads_fixture.test_data[ + marc_file_uploads_fixture.mock_upload_key_1 + ], + ) + + with uploads.lock() as locked: + assert locked + uploads.add_part_and_clear_buffer( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.part_1, + ) + uploads.add_part_and_clear_buffer( + marc_file_uploads_fixture.mock_upload_key_1, + marc_file_uploads_fixture.part_2, + ) + 
uploads.append_buffers( + { + marc_file_uploads_fixture.mock_upload_key_1: "1234567", + } + ) + + assert uploads.get_part_num_and_buffer( + marc_file_uploads_fixture.mock_upload_key_1 + ) == (2, "1234567") diff --git a/tests/manager/service/storage/test_s3.py b/tests/manager/service/storage/test_s3.py index 28086f7a17..c946aa01e3 100644 --- a/tests/manager/service/storage/test_s3.py +++ b/tests/manager/service/storage/test_s3.py @@ -1,28 +1,15 @@ from __future__ import annotations import functools -import uuid -from collections.abc import Generator from io import BytesIO -from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest from botocore.exceptions import BotoCoreError, ClientError -from pydantic import AnyHttpUrl from palace.manager.core.config import CannotLoadConfiguration -from palace.manager.service.configuration.service_configuration import ( - ServiceConfiguration, -) -from palace.manager.service.storage.container import Storage from palace.manager.service.storage.s3 import S3Service -from tests.fixtures.config import FixtureTestUrlConfiguration - -if TYPE_CHECKING: - from mypy_boto3_s3 import S3Client - - from tests.fixtures.s3 import S3ServiceFixture +from tests.fixtures.s3 import S3ServiceFixture, S3ServiceIntegrationFixture class TestS3Service: @@ -239,88 +226,6 @@ def test_multipart_upload_exception(self, s3_service_fixture: S3ServiceFixture): upload.upload_part(b"foo") -class S3UploaderIntegrationConfiguration(FixtureTestUrlConfiguration): - url: AnyHttpUrl - user: str - password: str - - class Config(ServiceConfiguration.Config): - env_prefix = "PALACE_TEST_MINIO_" - - -class S3ServiceIntegrationFixture: - def __init__(self): - self.container = Storage() - self.configuration = S3UploaderIntegrationConfiguration.from_env() - self.analytics_bucket = self.random_name("analytics") - self.public_access_bucket = self.random_name("public") - self.container.config.from_dict( - { - "access_key": self.configuration.user, - "secret_key": self.configuration.password, - "endpoint_url": self.configuration.url, - "region": "us-east-1", - "analytics_bucket": self.analytics_bucket, - "public_access_bucket": self.public_access_bucket, - "url_template": self.configuration.url + "/{bucket}/{key}", - } - ) - self.buckets = [] - self.create_buckets() - - @classmethod - def random_name(cls, prefix: str = "test"): - return f"{prefix}-{uuid.uuid4()}" - - @property - def s3_client(self) -> S3Client: - return self.container.s3_client() - - @property - def public(self) -> S3Service: - return self.container.public() - - @property - def analytics(self) -> S3Service: - return self.container.analytics() - - def create_bucket(self, bucket_name: str) -> None: - client = self.s3_client - client.create_bucket(Bucket=bucket_name) - self.buckets.append(bucket_name) - - def get_bucket(self, bucket_name: str) -> str: - if bucket_name == "public": - return self.public_access_bucket - elif bucket_name == "analytics": - return self.analytics_bucket - else: - raise ValueError(f"Unknown bucket name: {bucket_name}") - - def create_buckets(self) -> None: - for bucket in [self.analytics_bucket, self.public_access_bucket]: - self.create_bucket(bucket) - - def close(self): - for bucket in self.buckets: - response = self.s3_client.list_objects(Bucket=bucket) - - for object in response.get("Contents", []): - object_key = object["Key"] - self.s3_client.delete_object(Bucket=bucket, Key=object_key) - - self.s3_client.delete_bucket(Bucket=bucket) - - -@pytest.fixture -def 
s3_service_integration_fixture() -> ( - Generator[S3ServiceIntegrationFixture, None, None] -): - fixture = S3ServiceIntegrationFixture() - yield fixture - fixture.close() - - @pytest.mark.minio class TestS3ServiceIntegration: def test_delete(self, s3_service_integration_fixture: S3ServiceIntegrationFixture): From 252d5852b39178919905f2164ecba2604b002ef9 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 4 Sep 2024 09:35:32 -0300 Subject: [PATCH 3/7] Code review feedback: Update settings comments --- src/palace/manager/marc/settings.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/palace/manager/marc/settings.py b/src/palace/manager/marc/settings.py index a6517fb73f..4412876fe4 100644 --- a/src/palace/manager/marc/settings.py +++ b/src/palace/manager/marc/settings.py @@ -12,9 +12,9 @@ class MarcExporterSettings(BaseSettings): # This setting (in days) controls how often MARC files should be - # automatically updated. Since the crontab in docker isn't easily - # configurable, we can run a script daily but check this to decide - # whether to do anything. + # automatically updated. We run the celery task to update the MARC + # files on a schedule, but this setting easily allows admins to + # generate files more or less often. update_frequency: NonNegativeInt = FormField( 30, form=ConfigurationFormItem( From b75320c361f0a049a541af990ac6e13cc42bea76 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 4 Sep 2024 09:36:29 -0300 Subject: [PATCH 4/7] Rename the MarcFileUploads class to MarcFileUploadSession --- src/palace/manager/celery/tasks/marc.py | 8 +- src/palace/manager/marc/uploader.py | 34 ++- .../manager/service/redis/models/marc.py | 24 +- tests/manager/celery/tasks/test_marc.py | 15 +- tests/manager/marc/test_uploader.py | 7 +- .../manager/service/redis/models/test_marc.py | 274 ++++++++++-------- 6 files changed, 203 insertions(+), 159 deletions(-) diff --git a/src/palace/manager/celery/tasks/marc.py b/src/palace/manager/celery/tasks/marc.py index 920ce34a70..81e50babaf 100644 --- a/src/palace/manager/celery/tasks/marc.py +++ b/src/palace/manager/celery/tasks/marc.py @@ -7,7 +7,7 @@ from palace.manager.marc.exporter import LibraryInfo, MarcExporter from palace.manager.marc.uploader import MarcUploader from palace.manager.service.celery.celery import QueueNames -from palace.manager.service.redis.models.marc import MarcFileUploads +from palace.manager.service.redis.models.marc import MarcFileUploadSession from palace.manager.util.datetime_helpers import utc_now @@ -26,7 +26,7 @@ def marc_export(task: Task, force: bool = False) -> None: # Collection.id should never be able to be None here, but mypy doesn't know that. # So we assert it for mypy's benefit. 
assert collection.id is not None - lock = MarcFileUploads(task.services.redis.client(), collection.id) + lock = MarcFileUploadSession(task.services.redis.client(), collection.id) with lock.lock() as acquired: if not acquired: task.log.info( @@ -92,7 +92,9 @@ def marc_export_collection( libraries_info = [LibraryInfo.parse_obj(l) for l in libraries] uploader = MarcUploader( storage_service, - MarcFileUploads(task.services.redis.client(), collection_id, update_number), + MarcFileUploadSession( + task.services.redis.client(), collection_id, update_number + ), ) with uploader.begin(): if not uploader.locked: diff --git a/src/palace/manager/marc/uploader.py b/src/palace/manager/marc/uploader.py index 976e5be0f4..132356bbef 100644 --- a/src/palace/manager/marc/uploader.py +++ b/src/palace/manager/marc/uploader.py @@ -5,7 +5,7 @@ from celery.exceptions import Ignore, Retry from typing_extensions import Self -from palace.manager.service.redis.models.marc import MarcFileUploads +from palace.manager.service.redis.models.marc import MarcFileUploadSession from palace.manager.service.storage.s3 import S3Service from palace.manager.sqlalchemy.model.resource import Representation from palace.manager.util.log import LoggerMixin @@ -18,12 +18,14 @@ class MarcUploader(LoggerMixin): between steps to redis, and flushing them to S3 when the buffer is large enough. This class orchestrates the upload process, delegating the redis operation to the - `MarcFileUploads` class, and the S3 upload to the `S3Service` class. + `MarcFileUploadSession` class, and the S3 upload to the `S3Service` class. """ - def __init__(self, storage_service: S3Service, marc_uploads: MarcFileUploads): + def __init__( + self, storage_service: S3Service, upload_session: MarcFileUploadSession + ): self.storage_service = storage_service - self.marc_uploads = marc_uploads + self.upload_session = upload_session self._buffers: defaultdict[str, str] = defaultdict(str) self._locked = False @@ -33,30 +35,30 @@ def locked(self) -> bool: @property def update_number(self) -> int: - return self.marc_uploads.update_number + return self.upload_session.update_number def add_record(self, key: str, record: bytes) -> None: self._buffers[key] += record.decode() def _s3_sync(self, needs_upload: Sequence[str]) -> None: - upload_ids = self.marc_uploads.get_upload_ids(needs_upload) + upload_ids = self.upload_session.get_upload_ids(needs_upload) for key in needs_upload: if upload_ids.get(key) is None: upload_id = self.storage_service.multipart_create( key, content_type=Representation.MARC_MEDIA_TYPE ) - self.marc_uploads.set_upload_id(key, upload_id) + self.upload_session.set_upload_id(key, upload_id) upload_ids[key] = upload_id - part_number, data = self.marc_uploads.get_part_num_and_buffer(key) + part_number, data = self.upload_session.get_part_num_and_buffer(key) upload_part = self.storage_service.multipart_upload( key, upload_ids[key], part_number, data.encode() ) - self.marc_uploads.add_part_and_clear_buffer(key, upload_part) + self.upload_session.add_part_and_clear_buffer(key, upload_part) def sync(self) -> None: # First sync our buffers to redis - buffer_lengths = self.marc_uploads.append_buffers(self._buffers) + buffer_lengths = self.upload_session.append_buffers(self._buffers) self._buffers.clear() # Then, if any of our redis buffers are large enough, upload them to S3 @@ -72,7 +74,7 @@ def sync(self) -> None: self._s3_sync(needs_upload) def _abort(self) -> None: - in_progress = self.marc_uploads.get() + in_progress = self.upload_session.get() for key, 
upload in in_progress.items(): if upload.upload_id is None: # This upload has not started, so there is nothing to abort. @@ -94,7 +96,7 @@ def complete(self) -> set[str]: # Make sure any local data we have is synced self.sync() - in_progress = self.marc_uploads.get() + in_progress = self.upload_session.get() for key, upload in in_progress.items(): if upload.upload_id is None: # We haven't started the upload. At this point there is no reason to start a @@ -118,17 +120,17 @@ def complete(self) -> set[str]: # Delete our in-progress uploads data from redis if in_progress: - self.marc_uploads.clear_uploads() + self.upload_session.clear_uploads() # Return the keys that were uploaded return set(in_progress.keys()) def delete(self) -> None: - self.marc_uploads.delete() + self.upload_session.delete() @contextmanager def begin(self) -> Generator[Self, None, None]: - self._locked = self.marc_uploads.acquire() + self._locked = self.upload_session.acquire() try: yield self except Exception as e: @@ -138,5 +140,5 @@ def begin(self) -> Generator[Self, None, None]: self._abort() raise finally: - self.marc_uploads.release() + self.upload_session.release() self._locked = False diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py index 7612f05f27..3e447ac259 100644 --- a/src/palace/manager/service/redis/models/marc.py +++ b/src/palace/manager/service/redis/models/marc.py @@ -16,7 +16,7 @@ from palace.manager.util.log import LoggerMixin -class RedisMarcError(LockError): +class MarcFileUploadSessionError(LockError): pass @@ -26,7 +26,7 @@ class MarcFileUpload(BaseModel): parts: list[MultipartS3UploadPart] = [] -class MarcFileUploads(RedisJsonLock, LoggerMixin): +class MarcFileUploadSession(RedisJsonLock, LoggerMixin): """ This class is used as a lock for the Celery MARC export task, to ensure that only one task can upload MARC files for a given collection at a time. It increments an update @@ -115,7 +115,7 @@ def _pipeline( if ( remote_random := fetched_data.get(self._lock_json_key) ) != self._random_value: - raise RedisMarcError( + raise MarcFileUploadSessionError( f"Must hold lock to append to buffer. " f"Expected: {self._random_value}, got: {remote_random}" ) @@ -123,7 +123,7 @@ def _pipeline( if ( remote_update_number := fetched_data.get(self._update_number_json_key) ) != self._update_number: - raise RedisMarcError( + raise MarcFileUploadSessionError( f"Update number mismatch. " f"Expected: {self._update_number}, got: {remote_update_number}" ) @@ -133,7 +133,7 @@ def _pipeline( def _execute_pipeline(self, pipe: Pipeline, updates: int) -> list[Any]: if not pipe.explicit_transaction: - raise RedisMarcError( + raise MarcFileUploadSessionError( "Pipeline should be in explicit transaction mode before executing." ) pipe.json().numincrby(self.key, self._update_number_json_key, updates) @@ -141,7 +141,7 @@ def _execute_pipeline(self, pipe: Pipeline, updates: int) -> list[Any]: try: pipe_results = pipe.execute() except WatchError as e: - raise RedisMarcError( + raise MarcFileUploadSessionError( "Failed to update buffers. Another process is modifying the buffers." 
) from e self._update_number = self._parse_value_or_raise(pipe_results[-2]) @@ -175,7 +175,7 @@ def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]: pipe_results = self._execute_pipeline(pipe, len(data)) if not all(pipe_results): - raise RedisMarcError("Failed to append buffers.") + raise MarcFileUploadSessionError("Failed to append buffers.") return { k: set_results[k] if v is True else self._parse_value_or_raise(v) @@ -197,7 +197,7 @@ def add_part_and_clear_buffer(self, key: str, part: MultipartS3UploadPart) -> No pipe_results = self._execute_pipeline(pipe, 1) if not all(pipe_results): - raise RedisMarcError("Failed to add part and clear buffer.") + raise MarcFileUploadSessionError("Failed to add part and clear buffer.") def set_upload_id(self, key: str, upload_id: str) -> None: with self._pipeline() as pipe: @@ -210,7 +210,7 @@ def set_upload_id(self, key: str, upload_id: str) -> None: pipe_results = self._execute_pipeline(pipe, 1) if not all(pipe_results): - raise RedisMarcError("Failed to set upload ID.") + raise MarcFileUploadSessionError("Failed to set upload ID.") def clear_uploads(self) -> None: with self._pipeline() as pipe: @@ -218,7 +218,7 @@ def clear_uploads(self) -> None: pipe_results = self._execute_pipeline(pipe, 1) if not all(pipe_results): - raise RedisMarcError("Failed to clear uploads.") + raise MarcFileUploadSessionError("Failed to clear uploads.") def _get_specific( self, @@ -263,7 +263,9 @@ def get_part_num_and_buffer(self, key: str) -> tuple[int, str]: pipe.json().arrlen(self.key, self._parts_path(key)) results = pipe.execute() except ResponseError as e: - raise RedisMarcError("Failed to get part number and buffer data.") from e + raise MarcFileUploadSessionError( + "Failed to get part number and buffer data." 
+ ) from e buffer_data: str = self._parse_value_or_raise(results[0]) part_number: int = self._parse_value_or_raise(results[1]) diff --git a/tests/manager/celery/tasks/test_marc.py b/tests/manager/celery/tasks/test_marc.py index 3f38c4d6b9..a878cdac9a 100644 --- a/tests/manager/celery/tasks/test_marc.py +++ b/tests/manager/celery/tasks/test_marc.py @@ -9,7 +9,10 @@ from palace.manager.marc.exporter import MarcExporter from palace.manager.marc.uploader import MarcUploader from palace.manager.service.logging.configuration import LogLevel -from palace.manager.service.redis.models.marc import MarcFileUploads, RedisMarcError +from palace.manager.service.redis.models.marc import ( + MarcFileUploadSession, + MarcFileUploadSessionError, +) from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.marcfile import MarcFile from palace.manager.sqlalchemy.model.work import Work @@ -56,7 +59,7 @@ def test_marc_export( # Collection 1 should be skipped because it is locked assert marc_exporter_fixture.collection1.id is not None - MarcFileUploads( + MarcFileUploadSession( redis_fixture.client, marc_exporter_fixture.collection1.id ).acquire() @@ -113,7 +116,7 @@ def marc_files(self) -> list[MarcFile]: def redis_data(self, collection: Collection) -> dict[str, Any] | None: assert collection.id is not None - uploads = MarcFileUploads(self.redis_fixture.client, collection.id) + uploads = MarcFileUploadSession(self.redis_fixture.client, collection.id) return self.redis_fixture.client.json().get(uploads.key) def setup_minio_storage(self) -> None: @@ -252,7 +255,7 @@ def test_locked( caplog.set_level(LogLevel.info) collection = marc_exporter_fixture.collection1 assert collection.id is not None - MarcFileUploads(redis_fixture.client, collection.id).acquire() + MarcFileUploadSession(redis_fixture.client, collection.id).acquire() marc_export_collection_fixture.setup_mock_storage() with patch.object(MarcExporter, "query_works") as query: marc_export_collection_fixture.export_collection(collection) @@ -275,12 +278,12 @@ def test_outdated_task_run( # Acquire the lock and start an upload, this simulates another task having done work # that the current task doesn't know about. 
- uploads = MarcFileUploads(redis_fixture.client, collection.id) + uploads = MarcFileUploadSession(redis_fixture.client, collection.id) with uploads.lock() as locked: assert locked uploads.append_buffers({"test": "data"}) - with pytest.raises(RedisMarcError, match="Update number mismatch"): + with pytest.raises(MarcFileUploadSessionError, match="Update number mismatch"): marc_export_collection_fixture.export_collection(collection) assert marc_export_collection_fixture.marc_files() == [] diff --git a/tests/manager/marc/test_uploader.py b/tests/manager/marc/test_uploader.py index 90fd32623e..1c6087aad9 100644 --- a/tests/manager/marc/test_uploader.py +++ b/tests/manager/marc/test_uploader.py @@ -4,7 +4,10 @@ from celery.exceptions import Ignore, Retry from palace.manager.marc.uploader import MarcUploader -from palace.manager.service.redis.models.marc import MarcFileUpload, MarcFileUploads +from palace.manager.service.redis.models.marc import ( + MarcFileUpload, + MarcFileUploadSession, +) from palace.manager.sqlalchemy.model.resource import Representation from tests.fixtures.redis import RedisFixture from tests.fixtures.s3 import S3ServiceFixture @@ -31,7 +34,7 @@ def __init__( self.mock_collection_id = 52 - self.uploads = MarcFileUploads(self.redis_client, self.mock_collection_id) + self.uploads = MarcFileUploadSession(self.redis_client, self.mock_collection_id) self.uploader = MarcUploader(self.mock_s3_service, self.uploads) diff --git a/tests/manager/service/redis/models/test_marc.py b/tests/manager/service/redis/models/test_marc.py index 4e5a2778ee..5b64725089 100644 --- a/tests/manager/service/redis/models/test_marc.py +++ b/tests/manager/service/redis/models/test_marc.py @@ -2,21 +2,21 @@ from palace.manager.service.redis.models.marc import ( MarcFileUpload, - MarcFileUploads, - RedisMarcError, + MarcFileUploadSession, + MarcFileUploadSessionError, ) from palace.manager.service.redis.redis import Pipeline from palace.manager.service.storage.s3 import MultipartS3UploadPart from tests.fixtures.redis import RedisFixture -class MarcFileUploadsFixture: +class MarcFileUploadSessionFixture: def __init__(self, redis_fixture: RedisFixture): self._redis_fixture = redis_fixture self.mock_collection_id = 1 - self.uploads = MarcFileUploads( + self.uploads = MarcFileUploadSession( self._redis_fixture.client, self.mock_collection_id ) @@ -53,16 +53,18 @@ def test_data_records(self, *keys: str): @pytest.fixture -def marc_file_uploads_fixture(redis_fixture: RedisFixture): - return MarcFileUploadsFixture(redis_fixture) +def marc_file_upload_session_fixture(redis_fixture: RedisFixture): + return MarcFileUploadSessionFixture(redis_fixture) -class TestMarcFileUploads: - def test__pipeline(self, marc_file_uploads_fixture: MarcFileUploadsFixture): - uploads = marc_file_uploads_fixture.uploads +class TestMarcFileUploadSession: + def test__pipeline( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads # Using the _pipeline() context manager makes sure that we hold the lock - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: with uploads._pipeline(): pass assert "Must hold lock" in str(exc_info.value) @@ -71,7 +73,7 @@ def test__pipeline(self, marc_file_uploads_fixture: MarcFileUploadsFixture): # It also checks that the update_number is correct uploads._update_number = 1 - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as 
exc_info: with uploads._pipeline(): pass assert "Update number mismatch" in str(exc_info.value) @@ -95,15 +97,15 @@ def test__pipeline(self, marc_file_uploads_fixture: MarcFileUploadsFixture): def test__execute_pipeline( self, - marc_file_uploads_fixture: MarcFileUploadsFixture, + marc_file_upload_session_fixture: MarcFileUploadSessionFixture, redis_fixture: RedisFixture, ): client = redis_fixture.client - uploads = marc_file_uploads_fixture.uploads + uploads = marc_file_upload_session_fixture.uploads uploads.acquire() # If we try to execute a pipeline without a transaction, we should get an error - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: with uploads._pipeline(begin_transaction=False) as pipe: uploads._execute_pipeline(pipe, 0) assert "Pipeline should be in explicit transaction mode" in str(exc_info.value) @@ -129,17 +131,19 @@ def test__execute_pipeline( client.json().set( uploads.key, uploads._update_number_json_key, update_number ) - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: uploads._execute_pipeline(pipe, 1) assert "Another process is modifying the buffers" in str(exc_info.value) - def test_append_buffers(self, marc_file_uploads_fixture: MarcFileUploadsFixture): - uploads = marc_file_uploads_fixture.uploads + def test_append_buffers( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads # If we try to update buffers without acquiring the lock, we should get an error - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: uploads.append_buffers( - {marc_file_uploads_fixture.mock_upload_key_1: "test"} + {marc_file_upload_session_fixture.mock_upload_key_1: "test"} ) assert "Must hold lock" in str(exc_info.value) @@ -150,177 +154,197 @@ def test_append_buffers(self, marc_file_uploads_fixture: MarcFileUploadsFixture) assert uploads.append_buffers( { - marc_file_uploads_fixture.mock_upload_key_1: "test", - marc_file_uploads_fixture.mock_upload_key_2: "another_test", + marc_file_upload_session_fixture.mock_upload_key_1: "test", + marc_file_upload_session_fixture.mock_upload_key_2: "another_test", } ) == { - marc_file_uploads_fixture.mock_upload_key_1: 4, - marc_file_uploads_fixture.mock_upload_key_2: 12, + marc_file_upload_session_fixture.mock_upload_key_1: 4, + marc_file_upload_session_fixture.mock_upload_key_2: 12, } assert uploads._update_number == 2 assert uploads.append_buffers( { - marc_file_uploads_fixture.mock_upload_key_1: "x", - marc_file_uploads_fixture.mock_upload_key_2: "y", - marc_file_uploads_fixture.mock_upload_key_3: "new", + marc_file_upload_session_fixture.mock_upload_key_1: "x", + marc_file_upload_session_fixture.mock_upload_key_2: "y", + marc_file_upload_session_fixture.mock_upload_key_3: "new", } ) == { - marc_file_uploads_fixture.mock_upload_key_1: 5, - marc_file_uploads_fixture.mock_upload_key_2: 13, - marc_file_uploads_fixture.mock_upload_key_3: 3, + marc_file_upload_session_fixture.mock_upload_key_1: 5, + marc_file_upload_session_fixture.mock_upload_key_2: 13, + marc_file_upload_session_fixture.mock_upload_key_3: 3, } assert uploads._update_number == 5 # If we try to update buffers with an old update number, we should get an error uploads._update_number = 4 - with pytest.raises(RedisMarcError) as exc_info: - uploads.append_buffers(marc_file_uploads_fixture.test_data) + with 
pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.append_buffers(marc_file_upload_session_fixture.test_data) assert "Update number mismatch" in str(exc_info.value) # Exiting the context manager should release the lock assert not uploads.locked() - def test_get(self, marc_file_uploads_fixture: MarcFileUploadsFixture): - uploads = marc_file_uploads_fixture.uploads + def test_get(self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture): + uploads = marc_file_upload_session_fixture.uploads assert uploads.get() == {} - assert uploads.get(marc_file_uploads_fixture.mock_upload_key_1) == {} + assert uploads.get(marc_file_upload_session_fixture.mock_upload_key_1) == {} - marc_file_uploads_fixture.load_test_data() + marc_file_upload_session_fixture.load_test_data() # You don't need to acquire the lock to get the uploads, but you should if you # are using the data to do updates. # You can get a subset of the uploads assert uploads.get( - marc_file_uploads_fixture.mock_upload_key_1, - ) == marc_file_uploads_fixture.test_data_records( - marc_file_uploads_fixture.mock_upload_key_1 + marc_file_upload_session_fixture.mock_upload_key_1, + ) == marc_file_upload_session_fixture.test_data_records( + marc_file_upload_session_fixture.mock_upload_key_1 ) # Or multiple uploads, any that don't exist are not included in the result dict assert uploads.get( [ - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.mock_upload_key_2, - marc_file_uploads_fixture.mock_unset_upload_key, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, + marc_file_upload_session_fixture.mock_unset_upload_key, ] - ) == marc_file_uploads_fixture.test_data_records( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.mock_upload_key_2, + ) == marc_file_upload_session_fixture.test_data_records( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, ) # Or you can get all the uploads - assert uploads.get() == marc_file_uploads_fixture.test_data_records( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.mock_upload_key_2, - marc_file_uploads_fixture.mock_upload_key_3, + assert uploads.get() == marc_file_upload_session_fixture.test_data_records( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, + marc_file_upload_session_fixture.mock_upload_key_3, ) - def test_set_upload_id(self, marc_file_uploads_fixture: MarcFileUploadsFixture): - uploads = marc_file_uploads_fixture.uploads + def test_set_upload_id( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads # must hold lock to do update - with pytest.raises(RedisMarcError) as exc_info: - uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "xyz") + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "xyz" + ) assert "Must hold lock" in str(exc_info.value) uploads.acquire() # We are unable to set an upload id for an item that hasn't been initialized - with pytest.raises(RedisMarcError) as exc_info: - uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "xyz") + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "xyz" + ) assert "Failed to set upload ID" in 
str(exc_info.value) - marc_file_uploads_fixture.load_test_data() - uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "def") - uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_2, "abc") + marc_file_upload_session_fixture.load_test_data() + uploads.set_upload_id(marc_file_upload_session_fixture.mock_upload_key_1, "def") + uploads.set_upload_id(marc_file_upload_session_fixture.mock_upload_key_2, "abc") all_uploads = uploads.get() assert ( - all_uploads[marc_file_uploads_fixture.mock_upload_key_1].upload_id == "def" + all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].upload_id + == "def" ) assert ( - all_uploads[marc_file_uploads_fixture.mock_upload_key_2].upload_id == "abc" + all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].upload_id + == "abc" ) # We can't change the upload id for a library that has already been set - with pytest.raises(RedisMarcError) as exc_info: - uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "ghi") + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "ghi" + ) assert "Failed to set upload ID" in str(exc_info.value) all_uploads = uploads.get() assert ( - all_uploads[marc_file_uploads_fixture.mock_upload_key_1].upload_id == "def" + all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].upload_id + == "def" ) assert ( - all_uploads[marc_file_uploads_fixture.mock_upload_key_2].upload_id == "abc" + all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].upload_id + == "abc" ) - def test_clear_uploads(self, marc_file_uploads_fixture: MarcFileUploadsFixture): - uploads = marc_file_uploads_fixture.uploads + def test_clear_uploads( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads # must hold lock to do update - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: uploads.clear_uploads() assert "Must hold lock" in str(exc_info.value) uploads.acquire() # We are unable to clear the uploads for an item that hasn't been initialized - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: uploads.clear_uploads() assert "Failed to clear uploads" in str(exc_info.value) - marc_file_uploads_fixture.load_test_data() + marc_file_upload_session_fixture.load_test_data() assert uploads.get() != {} uploads.clear_uploads() assert uploads.get() == {} - def test_get_upload_ids(self, marc_file_uploads_fixture: MarcFileUploadsFixture): - uploads = marc_file_uploads_fixture.uploads + def test_get_upload_ids( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads # If the id is not set, we should get None assert uploads.get_upload_ids( - [marc_file_uploads_fixture.mock_upload_key_1] - ) == {marc_file_uploads_fixture.mock_upload_key_1: None} + [marc_file_upload_session_fixture.mock_upload_key_1] + ) == {marc_file_upload_session_fixture.mock_upload_key_1: None} - marc_file_uploads_fixture.load_test_data() + marc_file_upload_session_fixture.load_test_data() # If the buffer has been set, but the upload id has not, we should still get None assert uploads.get_upload_ids( - [marc_file_uploads_fixture.mock_upload_key_1] - ) == {marc_file_uploads_fixture.mock_upload_key_1: None} + [marc_file_upload_session_fixture.mock_upload_key_1] + ) == 
{marc_file_upload_session_fixture.mock_upload_key_1: None} with uploads.lock() as locked: assert locked - uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_1, "abc") - uploads.set_upload_id(marc_file_uploads_fixture.mock_upload_key_2, "def") - assert uploads.get_upload_ids(marc_file_uploads_fixture.mock_upload_key_1) == { - marc_file_uploads_fixture.mock_upload_key_1: "abc" - } + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "abc" + ) + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_2, "def" + ) + assert uploads.get_upload_ids( + marc_file_upload_session_fixture.mock_upload_key_1 + ) == {marc_file_upload_session_fixture.mock_upload_key_1: "abc"} assert uploads.get_upload_ids( [ - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.mock_upload_key_2, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, ] ) == { - marc_file_uploads_fixture.mock_upload_key_1: "abc", - marc_file_uploads_fixture.mock_upload_key_2: "def", + marc_file_upload_session_fixture.mock_upload_key_1: "abc", + marc_file_upload_session_fixture.mock_upload_key_2: "def", } def test_add_part_and_clear_buffer( - self, marc_file_uploads_fixture: MarcFileUploadsFixture + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture ): - uploads = marc_file_uploads_fixture.uploads + uploads = marc_file_upload_session_fixture.uploads # If we try to add parts without acquiring the lock, we should get an error - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: uploads.add_part_and_clear_buffer( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.part_1, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, ) assert "Must hold lock" in str(exc_info.value) @@ -328,79 +352,87 @@ def test_add_part_and_clear_buffer( uploads.acquire() # We are unable to add parts to a library whose buffers haven't been initialized - with pytest.raises(RedisMarcError) as exc_info: + with pytest.raises(MarcFileUploadSessionError) as exc_info: uploads.add_part_and_clear_buffer( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.part_1, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, ) assert "Failed to add part and clear buffer" in str(exc_info.value) - marc_file_uploads_fixture.load_test_data() + marc_file_upload_session_fixture.load_test_data() # We are able to add parts to a library that exists uploads.add_part_and_clear_buffer( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.part_1, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, ) uploads.add_part_and_clear_buffer( - marc_file_uploads_fixture.mock_upload_key_2, - marc_file_uploads_fixture.part_1, + marc_file_upload_session_fixture.mock_upload_key_2, + marc_file_upload_session_fixture.part_1, ) uploads.add_part_and_clear_buffer( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.part_2, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_2, ) all_uploads = uploads.get() # The parts are added in order and the buffers are cleared - assert all_uploads[marc_file_uploads_fixture.mock_upload_key_1].parts == [ - marc_file_uploads_fixture.part_1, - marc_file_uploads_fixture.part_2, + assert all_uploads[ + 
marc_file_upload_session_fixture.mock_upload_key_1 + ].parts == [ + marc_file_upload_session_fixture.part_1, + marc_file_upload_session_fixture.part_2, ] - assert all_uploads[marc_file_uploads_fixture.mock_upload_key_2].parts == [ - marc_file_uploads_fixture.part_1 - ] - assert all_uploads[marc_file_uploads_fixture.mock_upload_key_1].buffer == "" - assert all_uploads[marc_file_uploads_fixture.mock_upload_key_2].buffer == "" + assert all_uploads[ + marc_file_upload_session_fixture.mock_upload_key_2 + ].parts == [marc_file_upload_session_fixture.part_1] + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].buffer == "" + ) + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].buffer == "" + ) def test_get_part_num_and_buffer( - self, marc_file_uploads_fixture: MarcFileUploadsFixture + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture ): - uploads = marc_file_uploads_fixture.uploads + uploads = marc_file_upload_session_fixture.uploads # If the key has not been initialized, we get an exception - with pytest.raises(RedisMarcError) as exc_info: - uploads.get_part_num_and_buffer(marc_file_uploads_fixture.mock_upload_key_1) + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.get_part_num_and_buffer( + marc_file_upload_session_fixture.mock_upload_key_1 + ) assert "Failed to get part number and buffer data" in str(exc_info.value) - marc_file_uploads_fixture.load_test_data() + marc_file_upload_session_fixture.load_test_data() # If the buffer has been set, but no parts have been added assert uploads.get_part_num_and_buffer( - marc_file_uploads_fixture.mock_upload_key_1 + marc_file_upload_session_fixture.mock_upload_key_1 ) == ( 0, - marc_file_uploads_fixture.test_data[ - marc_file_uploads_fixture.mock_upload_key_1 + marc_file_upload_session_fixture.test_data[ + marc_file_upload_session_fixture.mock_upload_key_1 ], ) with uploads.lock() as locked: assert locked uploads.add_part_and_clear_buffer( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.part_1, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, ) uploads.add_part_and_clear_buffer( - marc_file_uploads_fixture.mock_upload_key_1, - marc_file_uploads_fixture.part_2, + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_2, ) uploads.append_buffers( { - marc_file_uploads_fixture.mock_upload_key_1: "1234567", + marc_file_upload_session_fixture.mock_upload_key_1: "1234567", } ) assert uploads.get_part_num_and_buffer( - marc_file_uploads_fixture.mock_upload_key_1 + marc_file_upload_session_fixture.mock_upload_key_1 ) == (2, "1234567") From 3576991cc0fe945adaa2af6f543ec3d78e06a365 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 4 Sep 2024 09:48:24 -0300 Subject: [PATCH 5/7] Rename MarcUploader to MarcUploadManager --- src/palace/manager/celery/tasks/marc.py | 20 +- src/palace/manager/marc/exporter.py | 8 +- src/palace/manager/marc/uploader.py | 6 +- .../manager/service/redis/models/marc.py | 2 +- tests/manager/celery/tasks/test_marc.py | 4 +- tests/manager/marc/test_exporter.py | 12 +- tests/manager/marc/test_uploader.py | 187 ++++++++++-------- 7 files changed, 128 insertions(+), 111 deletions(-) diff --git a/src/palace/manager/celery/tasks/marc.py b/src/palace/manager/celery/tasks/marc.py index 81e50babaf..9fa82acfc6 100644 --- a/src/palace/manager/celery/tasks/marc.py +++ b/src/palace/manager/celery/tasks/marc.py @@ -5,7 +5,7 @@ from 
palace.manager.celery.task import Task from palace.manager.marc.exporter import LibraryInfo, MarcExporter -from palace.manager.marc.uploader import MarcUploader +from palace.manager.marc.uploader import MarcUploadManager from palace.manager.service.celery.celery import QueueNames from palace.manager.service.redis.models.marc import MarcFileUploadSession from palace.manager.util.datetime_helpers import utc_now @@ -90,14 +90,14 @@ def marc_export_collection( base_url = task.services.config.sitewide.base_url() storage_service = task.services.storage.public() libraries_info = [LibraryInfo.parse_obj(l) for l in libraries] - uploader = MarcUploader( + upload_manager = MarcUploadManager( storage_service, MarcFileUploadSession( task.services.redis.client(), collection_id, update_number ), ) - with uploader.begin(): - if not uploader.locked: + with upload_manager.begin(): + if not upload_manager.locked: task.log.info( f"Skipping collection {collection_id} because another task is already processing it." ) @@ -112,11 +112,11 @@ def marc_export_collection( ) for work in works: MarcExporter.process_work( - work, libraries_info, base_url, uploader=uploader + work, libraries_info, base_url, upload_manager=upload_manager ) - # Sync the uploader to ensure that all the data is written to storage. - uploader.sync() + # Sync the upload_manager to ensure that all the data is written to storage. + upload_manager.sync() if len(works) == batch_size: # This task is complete, but there are more works waiting to be exported. So we requeue ourselves @@ -128,7 +128,7 @@ def marc_export_collection( libraries=[l.dict() for l in libraries_info], batch_size=batch_size, last_work_id=works[-1].id, - update_number=uploader.update_number, + update_number=upload_manager.update_number, ) ) @@ -136,11 +136,11 @@ def marc_export_collection( with task.transaction() as session: collection = MarcExporter.collection(session, collection_id) collection_name = collection.name if collection else "unknown" - completed_uploads = uploader.complete() + completed_uploads = upload_manager.complete() MarcExporter.create_marc_upload_records( session, start_time, collection_id, libraries_info, completed_uploads ) - uploader.delete() + upload_manager.remove_session() task.log.info( f"Finished generating MARC records for collection '{collection_name}' ({collection_id})." 
) diff --git a/src/palace/manager/marc/exporter.py b/src/palace/manager/marc/exporter.py index 7745ba79cb..13a587a7ba 100644 --- a/src/palace/manager/marc/exporter.py +++ b/src/palace/manager/marc/exporter.py @@ -16,7 +16,7 @@ MarcExporterLibrarySettings, MarcExporterSettings, ) -from palace.manager.marc.uploader import MarcUploader +from palace.manager.marc.uploader import MarcUploadManager from palace.manager.service.integration_registry.catalog_services import ( CatalogServicesRegistry, ) @@ -305,7 +305,7 @@ def process_work( libraries_info: Iterable[LibraryInfo], base_url: str, *, - uploader: MarcUploader, + upload_manager: MarcUploadManager, annotator: type[Annotator] = Annotator, ) -> None: pool = work.active_license_pool() @@ -325,7 +325,7 @@ def process_work( library_info.include_genres, ) - uploader.add_record( + upload_manager.add_record( library_info.s3_key_full, library_record.as_marc(), ) @@ -336,7 +336,7 @@ def process_work( and work.last_update_time and work.last_update_time > library_info.last_updated ): - uploader.add_record( + upload_manager.add_record( library_info.s3_key_delta, annotator.set_revised(library_record).as_marc(), ) diff --git a/src/palace/manager/marc/uploader.py b/src/palace/manager/marc/uploader.py index 132356bbef..81677977dd 100644 --- a/src/palace/manager/marc/uploader.py +++ b/src/palace/manager/marc/uploader.py @@ -11,7 +11,7 @@ from palace.manager.util.log import LoggerMixin -class MarcUploader(LoggerMixin): +class MarcUploadManager(LoggerMixin): """ This class is used to manage the upload of MARC files to S3. The upload is done in multiple parts, so that the Celery task can be broken up into multiple steps, saving the progress @@ -90,7 +90,7 @@ def _abort(self) -> None: ) # Delete our in-progress uploads from redis as well - self.delete() + self.remove_session() def complete(self) -> set[str]: # Make sure any local data we have is synced @@ -125,7 +125,7 @@ def complete(self) -> set[str]: # Return the keys that were uploaded return set(in_progress.keys()) - def delete(self) -> None: + def remove_session(self) -> None: self.upload_session.delete() @contextmanager diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py index 3e447ac259..92578c85d2 100644 --- a/src/palace/manager/service/redis/models/marc.py +++ b/src/palace/manager/service/redis/models/marc.py @@ -38,7 +38,7 @@ class MarcFileUploadSession(RedisJsonLock, LoggerMixin): them to S3 when the buffer is full. This object is focused on the redis part of this operation, the actual s3 upload orchestration - is handled by the `MarcUploader` class. + is handled by the `MarcUploadManager` class. 
""" def __init__( diff --git a/tests/manager/celery/tasks/test_marc.py b/tests/manager/celery/tasks/test_marc.py index a878cdac9a..4779672ad6 100644 --- a/tests/manager/celery/tasks/test_marc.py +++ b/tests/manager/celery/tasks/test_marc.py @@ -7,7 +7,7 @@ from palace.manager.celery.tasks import marc from palace.manager.marc.exporter import MarcExporter -from palace.manager.marc.uploader import MarcUploader +from palace.manager.marc.uploader import MarcUploadManager from palace.manager.service.logging.configuration import LogLevel from palace.manager.service.redis.models.marc import ( MarcFileUploadSession, @@ -235,7 +235,7 @@ def test_exception_handled( collection = marc_exporter_fixture.collection1 marc_export_collection_fixture.works(collection) - with patch.object(MarcUploader, "complete") as complete: + with patch.object(MarcUploadManager, "complete") as complete: complete.side_effect = Exception("Test Exception") with pytest.raises(Exception, match="Test Exception"): marc_export_collection_fixture.export_collection(collection) diff --git a/tests/manager/marc/test_exporter.py b/tests/manager/marc/test_exporter.py index 4d40b5f2c0..c09c6c80a5 100644 --- a/tests/manager/marc/test_exporter.py +++ b/tests/manager/marc/test_exporter.py @@ -8,7 +8,7 @@ from palace.manager.marc.exporter import LibraryInfo, MarcExporter from palace.manager.marc.settings import MarcExporterLibrarySettings -from palace.manager.marc.uploader import MarcUploader +from palace.manager.marc.uploader import MarcUploadManager from palace.manager.sqlalchemy.model.discovery_service_registration import ( DiscoveryServiceRegistration, ) @@ -334,18 +334,18 @@ def test_process_work(self, marc_exporter_fixture: MarcExporterFixture) -> None: work = marc_exporter_fixture.work(collection) enabled_libraries = marc_exporter_fixture.enabled_libraries(collection) - mock_uploader = create_autospec(MarcUploader) + mock_upload_manager = create_autospec(MarcUploadManager) process_work = partial( MarcExporter.process_work, work, enabled_libraries, "http://base.url", - uploader=mock_uploader, + upload_manager=mock_upload_manager, ) process_work() - mock_uploader.add_record.assert_has_calls( + mock_upload_manager.add_record.assert_has_calls( [ call(enabled_libraries[0].s3_key_full, ANY), call(enabled_libraries[0].s3_key_delta, ANY), @@ -354,10 +354,10 @@ def test_process_work(self, marc_exporter_fixture: MarcExporterFixture) -> None: ) # If the work has no license pools, it is skipped. 
- mock_uploader.reset_mock() + mock_upload_manager.reset_mock() work.license_pools = [] process_work() - mock_uploader.add_record.assert_not_called() + mock_upload_manager.add_record.assert_not_called() def test_create_marc_upload_records( self, marc_exporter_fixture: MarcExporterFixture diff --git a/tests/manager/marc/test_uploader.py b/tests/manager/marc/test_uploader.py index 1c6087aad9..bb7898e34c 100644 --- a/tests/manager/marc/test_uploader.py +++ b/tests/manager/marc/test_uploader.py @@ -3,7 +3,7 @@ import pytest from celery.exceptions import Ignore, Retry -from palace.manager.marc.uploader import MarcUploader +from palace.manager.marc.uploader import MarcUploadManager from palace.manager.service.redis.models.marc import ( MarcFileUpload, MarcFileUploadSession, @@ -13,7 +13,7 @@ from tests.fixtures.s3 import S3ServiceFixture -class MarcUploaderFixture: +class MarcUploadManagerFixture: def __init__( self, redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture ): @@ -35,24 +35,26 @@ def __init__( self.mock_collection_id = 52 self.uploads = MarcFileUploadSession(self.redis_client, self.mock_collection_id) - self.uploader = MarcUploader(self.mock_s3_service, self.uploads) + self.uploader = MarcUploadManager(self.mock_s3_service, self.uploads) @pytest.fixture -def marc_uploader_fixture( +def marc_upload_manager_fixture( redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture ): - return MarcUploaderFixture(redis_fixture, s3_service_fixture) + return MarcUploadManagerFixture(redis_fixture, s3_service_fixture) -class TestMarcUploader: +class TestMarcUploadManager: def test_begin( - self, marc_uploader_fixture: MarcUploaderFixture, redis_fixture: RedisFixture + self, + marc_upload_manager_fixture: MarcUploadManagerFixture, + redis_fixture: RedisFixture, ): - uploader = marc_uploader_fixture.uploader + uploader = marc_upload_manager_fixture.uploader assert uploader.locked is False - assert marc_uploader_fixture.uploads.locked(by_us=True) is False + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False with uploader.begin() as u: # The context manager returns the uploader object @@ -62,11 +64,11 @@ def test_begin( assert uploader.locked is True # The lock is also reflected in the uploads object - assert marc_uploader_fixture.uploads.locked(by_us=True) is True # type: ignore[unreachable] + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is True # type: ignore[unreachable] # The lock is released after the context manager exits assert uploader.locked is False # type: ignore[unreachable] - assert marc_uploader_fixture.uploads.locked(by_us=True) is False + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False # If an exception occurs, the lock is deleted and the exception is raised by calling # the _abort method @@ -77,7 +79,8 @@ def test_begin( assert uploader.locked is True raise Exception() assert ( - redis_fixture.client.json().get(marc_uploader_fixture.uploads.key) is None + redis_fixture.client.json().get(marc_upload_manager_fixture.uploads.key) + is None ) mock_abort.assert_called_once() @@ -89,40 +92,44 @@ def test_begin( with uploader.begin(): assert uploader.locked is True raise exception() - assert marc_uploader_fixture.uploads.locked(by_us=True) is False + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False assert ( - redis_fixture.client.json().get(marc_uploader_fixture.uploads.key) + redis_fixture.client.json().get(marc_upload_manager_fixture.uploads.key) is not None ) 
mock_abort.assert_not_called() - def test_add_record(self, marc_uploader_fixture: MarcUploaderFixture): - uploader = marc_uploader_fixture.uploader + def test_add_record(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.uploader uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1, ) assert ( - uploader._buffers[marc_uploader_fixture.test_key1] - == marc_uploader_fixture.test_record1.decode() + uploader._buffers[marc_upload_manager_fixture.test_key1] + == marc_upload_manager_fixture.test_record1.decode() ) uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1, ) assert ( - uploader._buffers[marc_uploader_fixture.test_key1] - == marc_uploader_fixture.test_record1.decode() * 2 + uploader._buffers[marc_upload_manager_fixture.test_key1] + == marc_upload_manager_fixture.test_record1.decode() * 2 ) - def test_sync(self, marc_uploader_fixture: MarcUploaderFixture): - uploader = marc_uploader_fixture.uploader + def test_sync(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.uploader uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1, ) uploader.add_record( - marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 * 2 + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2 * 2, ) with uploader.begin(): uploader.sync() @@ -131,27 +138,30 @@ def test_sync(self, marc_uploader_fixture: MarcUploaderFixture): assert uploader._buffers == {} # And pushes the local records to redis - assert marc_uploader_fixture.uploads.get() == { - marc_uploader_fixture.test_key1: MarcFileUpload( - buffer=marc_uploader_fixture.test_record1 + assert marc_upload_manager_fixture.uploads.get() == { + marc_upload_manager_fixture.test_key1: MarcFileUpload( + buffer=marc_upload_manager_fixture.test_record1 ), - marc_uploader_fixture.test_key2: MarcFileUpload( - buffer=marc_uploader_fixture.test_record2 * 2 + marc_upload_manager_fixture.test_key2: MarcFileUpload( + buffer=marc_upload_manager_fixture.test_record2 * 2 ), } # Because the buffer did not contain enough data, it was not uploaded to S3 - assert marc_uploader_fixture.mock_s3_service.upload_in_progress == {} + assert marc_upload_manager_fixture.mock_s3_service.upload_in_progress == {} # Add enough data for test_key1 to be uploaded to S3 uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 2 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 2, ) uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 2 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 2, ) uploader.add_record( - marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2, ) with uploader.begin(): @@ -162,61 +172,66 @@ def test_sync(self, marc_uploader_fixture: MarcUploaderFixture): # Because the data for test_key1 was large enough, it was uploaded to S3, and its redis data structure was # updated to reflect this. 
test_key2 was not large enough to upload, so it remains in redis and not in s3. - redis_data = marc_uploader_fixture.uploads.get() - assert redis_data[marc_uploader_fixture.test_key2] == MarcFileUpload( - buffer=marc_uploader_fixture.test_record2 * 3 + redis_data = marc_upload_manager_fixture.uploads.get() + assert redis_data[marc_upload_manager_fixture.test_key2] == MarcFileUpload( + buffer=marc_upload_manager_fixture.test_record2 * 3 ) - redis_data_test1 = redis_data[marc_uploader_fixture.test_key1] + redis_data_test1 = redis_data[marc_upload_manager_fixture.test_key1] assert redis_data_test1.buffer == "" - assert len(marc_uploader_fixture.mock_s3_service.upload_in_progress) == 1 + assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 1 assert ( - marc_uploader_fixture.test_key1 - in marc_uploader_fixture.mock_s3_service.upload_in_progress + marc_upload_manager_fixture.test_key1 + in marc_upload_manager_fixture.mock_s3_service.upload_in_progress ) - upload = marc_uploader_fixture.mock_s3_service.upload_in_progress[ - marc_uploader_fixture.test_key1 + upload = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ + marc_upload_manager_fixture.test_key1 ] assert upload.upload_id is not None assert upload.content_type is Representation.MARC_MEDIA_TYPE [part] = upload.parts - assert part.content == marc_uploader_fixture.test_record1 * 5 + assert part.content == marc_upload_manager_fixture.test_record1 * 5 # And the s3 part data and upload_id is synced to redis assert redis_data_test1.parts == [part.part_data] assert redis_data_test1.upload_id == upload.upload_id - def test_complete(self, marc_uploader_fixture: MarcUploaderFixture): - uploader = marc_uploader_fixture.uploader + def test_complete(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.uploader # Wrap the clear method so we can check if it was called mock_clear_uploads = MagicMock( - wraps=marc_uploader_fixture.uploads.clear_uploads + wraps=marc_upload_manager_fixture.uploads.clear_uploads ) - marc_uploader_fixture.uploads.clear_uploads = mock_clear_uploads + marc_upload_manager_fixture.uploads.clear_uploads = mock_clear_uploads # Set up the records for the test uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 5 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 5, ) uploader.add_record( - marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 * 5 + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2 * 5, ) with uploader.begin(): uploader.sync() uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 5 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 5, ) with uploader.begin(): uploader.sync() uploader.add_record( - marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2, ) uploader.add_record( - marc_uploader_fixture.test_key3, marc_uploader_fixture.test_record3 + marc_upload_manager_fixture.test_key3, + marc_upload_manager_fixture.test_record3, ) # Complete the uploads @@ -225,9 +240,9 @@ def test_complete(self, marc_uploader_fixture: MarcUploaderFixture): # The complete method should return the keys that were completed assert completed == { - marc_uploader_fixture.test_key1, - marc_uploader_fixture.test_key2, - marc_uploader_fixture.test_key3, + 
marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_key3, } # The local buffers should be empty @@ -237,43 +252,45 @@ def test_complete(self, marc_uploader_fixture: MarcUploaderFixture): mock_clear_uploads.assert_called_once() # The s3 service should have the completed uploads - assert len(marc_uploader_fixture.mock_s3_service.uploads) == 3 - assert len(marc_uploader_fixture.mock_s3_service.upload_in_progress) == 0 + assert len(marc_upload_manager_fixture.mock_s3_service.uploads) == 3 + assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 0 - test_key1_upload = marc_uploader_fixture.mock_s3_service.uploads[ - marc_uploader_fixture.test_key1 + test_key1_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ + marc_upload_manager_fixture.test_key1 ] - assert test_key1_upload.key == marc_uploader_fixture.test_key1 - assert test_key1_upload.content == marc_uploader_fixture.test_record1 * 10 + assert test_key1_upload.key == marc_upload_manager_fixture.test_key1 + assert test_key1_upload.content == marc_upload_manager_fixture.test_record1 * 10 assert test_key1_upload.media_type == Representation.MARC_MEDIA_TYPE - test_key2_upload = marc_uploader_fixture.mock_s3_service.uploads[ - marc_uploader_fixture.test_key2 + test_key2_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ + marc_upload_manager_fixture.test_key2 ] - assert test_key2_upload.key == marc_uploader_fixture.test_key2 - assert test_key2_upload.content == marc_uploader_fixture.test_record2 * 6 + assert test_key2_upload.key == marc_upload_manager_fixture.test_key2 + assert test_key2_upload.content == marc_upload_manager_fixture.test_record2 * 6 assert test_key2_upload.media_type == Representation.MARC_MEDIA_TYPE - test_key3_upload = marc_uploader_fixture.mock_s3_service.uploads[ - marc_uploader_fixture.test_key3 + test_key3_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ + marc_upload_manager_fixture.test_key3 ] - assert test_key3_upload.key == marc_uploader_fixture.test_key3 - assert test_key3_upload.content == marc_uploader_fixture.test_record3 + assert test_key3_upload.key == marc_upload_manager_fixture.test_key3 + assert test_key3_upload.content == marc_upload_manager_fixture.test_record3 assert test_key3_upload.media_type == Representation.MARC_MEDIA_TYPE def test__abort( self, - marc_uploader_fixture: MarcUploaderFixture, + marc_upload_manager_fixture: MarcUploadManagerFixture, caplog: pytest.LogCaptureFixture, ): - uploader = marc_uploader_fixture.uploader + uploader = marc_upload_manager_fixture.uploader # Set up the records for the test uploader.add_record( - marc_uploader_fixture.test_key1, marc_uploader_fixture.test_record1 * 10 + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 10, ) uploader.add_record( - marc_uploader_fixture.test_key2, marc_uploader_fixture.test_record2 * 10 + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2 * 10, ) with uploader.begin(): uploader.sync() @@ -281,17 +298,17 @@ def test__abort( # Mock the multipart_abort method so we can check if it was called and have it # raise an exception on the first call mock_abort = MagicMock(side_effect=[Exception("Boom"), None]) - marc_uploader_fixture.mock_s3_service.multipart_abort = mock_abort + marc_upload_manager_fixture.mock_s3_service.multipart_abort = mock_abort # Wrap the delete method so we can check if it was called - mock_delete = 
MagicMock(wraps=marc_uploader_fixture.uploads.delete) - marc_uploader_fixture.uploads.delete = mock_delete + mock_delete = MagicMock(wraps=marc_upload_manager_fixture.uploads.delete) + marc_upload_manager_fixture.uploads.delete = mock_delete - upload_id_1 = marc_uploader_fixture.mock_s3_service.upload_in_progress[ - marc_uploader_fixture.test_key1 + upload_id_1 = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ + marc_upload_manager_fixture.test_key1 ].upload_id - upload_id_2 = marc_uploader_fixture.mock_s3_service.upload_in_progress[ - marc_uploader_fixture.test_key2 + upload_id_2 = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ + marc_upload_manager_fixture.test_key2 ].upload_id # Abort the uploads, the original exception should propagate, and the exception @@ -302,14 +319,14 @@ def test__abort( assert str(exc_info.value) == "Bang" assert ( - f"Failed to abort upload {marc_uploader_fixture.test_key1} (UploadID: {upload_id_1}) due to exception (Boom)" + f"Failed to abort upload {marc_upload_manager_fixture.test_key1} (UploadID: {upload_id_1}) due to exception (Boom)" in caplog.text ) mock_abort.assert_has_calls( [ - call(marc_uploader_fixture.test_key1, upload_id_1), - call(marc_uploader_fixture.test_key2, upload_id_2), + call(marc_upload_manager_fixture.test_key1, upload_id_1), + call(marc_upload_manager_fixture.test_key2, upload_id_2), ] ) From 2439737a569467655056a02caae715aabdfd3d17 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 4 Sep 2024 10:50:00 -0300 Subject: [PATCH 6/7] Set a state on the upload session, so we can tell if it is being processed --- src/palace/manager/celery/tasks/marc.py | 21 +++- .../manager/service/redis/models/marc.py | 41 +++++- tests/manager/celery/tasks/test_marc.py | 119 +++++++++++------- .../manager/service/redis/models/test_marc.py | 33 ++++- 4 files changed, 162 insertions(+), 52 deletions(-) diff --git a/src/palace/manager/celery/tasks/marc.py b/src/palace/manager/celery/tasks/marc.py index 9fa82acfc6..2d164adcb2 100644 --- a/src/palace/manager/celery/tasks/marc.py +++ b/src/palace/manager/celery/tasks/marc.py @@ -7,7 +7,10 @@ from palace.manager.marc.exporter import LibraryInfo, MarcExporter from palace.manager.marc.uploader import MarcUploadManager from palace.manager.service.celery.celery import QueueNames -from palace.manager.service.redis.models.marc import MarcFileUploadSession +from palace.manager.service.redis.models.marc import ( + MarcFileUploadSession, + MarcFileUploadState, +) from palace.manager.util.datetime_helpers import utc_now @@ -26,14 +29,25 @@ def marc_export(task: Task, force: bool = False) -> None: # Collection.id should never be able to be None here, but mypy doesn't know that. # So we assert it for mypy's benefit. assert collection.id is not None - lock = MarcFileUploadSession(task.services.redis.client(), collection.id) - with lock.lock() as acquired: + upload_session = MarcFileUploadSession( + task.services.redis.client(), collection.id + ) + with upload_session.lock() as acquired: if not acquired: task.log.info( f"Skipping collection {collection.name} ({collection.id}) because another task holds its lock." ) continue + if ( + upload_state := upload_session.state() + ) != MarcFileUploadState.INITIAL: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it is already being " + f"processed (state: {upload_state})." 
+ ) + continue + libraries_info = MarcExporter.enabled_libraries( session, registry, collection.id ) @@ -62,6 +76,7 @@ def marc_export(task: Task, force: bool = False) -> None: task.log.info( f"Generating MARC records for collection {collection.name} ({collection.id})." ) + upload_session.set_state(MarcFileUploadState.QUEUED) marc_export_collection.delay( collection_id=collection.id, start_time=start_time, diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py index 92578c85d2..d340cbeb3b 100644 --- a/src/palace/manager/service/redis/models/marc.py +++ b/src/palace/manager/service/redis/models/marc.py @@ -3,9 +3,11 @@ import json from collections.abc import Callable, Generator, Mapping, Sequence from contextlib import contextmanager +from enum import auto from functools import cached_property from typing import Any +from backports.strenum import StrEnum from pydantic import BaseModel from redis import ResponseError, WatchError @@ -26,6 +28,12 @@ class MarcFileUpload(BaseModel): parts: list[MultipartS3UploadPart] = [] +class MarcFileUploadState(StrEnum): + INITIAL = auto() + QUEUED = auto() + UPLOADING = auto() + + class MarcFileUploadSession(RedisJsonLock, LoggerMixin): """ This class is used as a lock for the Celery MARC export task, to ensure that only one @@ -71,7 +79,9 @@ def _initial_value(self) -> str: """ The initial value to use for the locks JSON object. """ - return json.dumps({"uploads": {}, "update_number": 0}) + return json.dumps( + {"uploads": {}, "update_number": 0, "state": MarcFileUploadState.INITIAL} + ) @property def _update_number_json_key(self) -> str: @@ -81,6 +91,10 @@ def _update_number_json_key(self) -> str: def _uploads_json_key(self) -> str: return "$.uploads" + @property + def _state_json_key(self) -> str: + return "$.state" + @staticmethod def _upload_initial_value(buffer_data: str) -> dict[str, Any]: return MarcFileUpload(buffer=buffer_data).dict(exclude_none=True) @@ -116,7 +130,7 @@ def _pipeline( remote_random := fetched_data.get(self._lock_json_key) ) != self._random_value: raise MarcFileUploadSessionError( - f"Must hold lock to append to buffer. " + f"Must hold lock to update upload session. " f"Expected: {self._random_value}, got: {remote_random}" ) # Check that the update number is correct @@ -131,11 +145,18 @@ def _pipeline( pipe.multi() yield pipe - def _execute_pipeline(self, pipe: Pipeline, updates: int) -> list[Any]: + def _execute_pipeline( + self, + pipe: Pipeline, + updates: int, + *, + state: MarcFileUploadState = MarcFileUploadState.UPLOADING, + ) -> list[Any]: if not pipe.explicit_transaction: raise MarcFileUploadSessionError( "Pipeline should be in explicit transaction mode before executing." ) + pipe.json().set(self.key, path=self._state_json_key, obj=state) pipe.json().numincrby(self.key, self._update_number_json_key, updates) pipe.pexpire(self.key, self._lock_timeout_ms) try: @@ -145,7 +166,8 @@ def _execute_pipeline(self, pipe: Pipeline, updates: int) -> list[Any]: "Failed to update buffers. Another process is modifying the buffers." 
) from e self._update_number = self._parse_value_or_raise(pipe_results[-2]) - return pipe_results[:-2] + + return pipe_results[:-3] def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]: if not data: @@ -271,3 +293,14 @@ def get_part_num_and_buffer(self, key: str) -> tuple[int, str]: part_number: int = self._parse_value_or_raise(results[1]) return part_number, buffer_data + + def state(self) -> MarcFileUploadState | None: + get_results = self._redis_client.json().get(self.key, self._state_json_key) + state: str | None = self._parse_value(get_results) + if state is None: + return None + return MarcFileUploadState(state) + + def set_state(self, state: MarcFileUploadState) -> None: + with self._pipeline() as pipe: + self._execute_pipeline(pipe, 0, state=state) diff --git a/tests/manager/celery/tasks/test_marc.py b/tests/manager/celery/tasks/test_marc.py index 4779672ad6..3b796de2ed 100644 --- a/tests/manager/celery/tasks/test_marc.py +++ b/tests/manager/celery/tasks/test_marc.py @@ -12,6 +12,7 @@ from palace.manager.service.redis.models.marc import ( MarcFileUploadSession, MarcFileUploadSessionError, + MarcFileUploadState, ) from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.marcfile import MarcFile @@ -26,19 +27,54 @@ from tests.fixtures.services import ServicesFixture -def test_marc_export( - db: DatabaseTransactionFixture, - redis_fixture: RedisFixture, - marc_exporter_fixture: MarcExporterFixture, - celery_fixture: CeleryFixture, -): - marc_exporter_fixture.configure_export() - with (patch.object(marc, "marc_export_collection") as marc_export_collection,): - # Because none of the collections have works, we should skip all of them. - marc.marc_export.delay().wait() - marc_export_collection.delay.assert_not_called() - - # Runs against all the expected collections +class TestMarcExport: + def test_no_works( + self, + db: DatabaseTransactionFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + celery_fixture: CeleryFixture, + ): + marc_exporter_fixture.configure_export() + with patch.object(marc, "marc_export_collection") as marc_export_collection: + # Because none of the collections have works, we should skip all of them. 
+ marc.marc_export.delay().wait() + marc_export_collection.delay.assert_not_called() + + def test_normal_run( + self, + db: DatabaseTransactionFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + celery_fixture: CeleryFixture, + ): + marc_exporter_fixture.configure_export() + with patch.object(marc, "marc_export_collection") as marc_export_collection: + # Runs against all the expected collections + collections = [ + marc_exporter_fixture.collection1, + marc_exporter_fixture.collection2, + marc_exporter_fixture.collection3, + ] + for collection in collections: + marc_exporter_fixture.work(collection) + marc.marc_export.delay().wait() + marc_export_collection.delay.assert_has_calls( + [ + call(collection_id=collection.id, start_time=ANY, libraries=ANY) + for collection in collections + ], + any_order=True, + ) + + def test_skip_collections( + self, + db: DatabaseTransactionFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + celery_fixture: CeleryFixture, + ): + marc_exporter_fixture.configure_export() collections = [ marc_exporter_fixture.collection1, marc_exporter_fixture.collection2, @@ -46,39 +82,34 @@ def test_marc_export( ] for collection in collections: marc_exporter_fixture.work(collection) - marc.marc_export.delay().wait() - marc_export_collection.delay.assert_has_calls( - [ - call(collection_id=collection.id, start_time=ANY, libraries=ANY) - for collection in collections - ], - any_order=True, - ) + with patch.object(marc, "marc_export_collection") as marc_export_collection: + # Collection 1 should be skipped because it is locked + assert marc_exporter_fixture.collection1.id is not None + MarcFileUploadSession( + redis_fixture.client, marc_exporter_fixture.collection1.id + ).acquire() + + # Collection 2 should be skipped because it was updated recently + create( + db.session, + MarcFile, + library=marc_exporter_fixture.library1, + collection=marc_exporter_fixture.collection2, + created=utc_now(), + key="test-file-2.mrc", + ) - marc_export_collection.reset_mock() - - # Collection 1 should be skipped because it is locked - assert marc_exporter_fixture.collection1.id is not None - MarcFileUploadSession( - redis_fixture.client, marc_exporter_fixture.collection1.id - ).acquire() - - # Collection 2 should be skipped because it was updated recently - create( - db.session, - MarcFile, - library=marc_exporter_fixture.library1, - collection=marc_exporter_fixture.collection2, - created=utc_now(), - key="test-file-2.mrc", - ) + # Collection 3 should be skipped because its state is not INITIAL + assert marc_exporter_fixture.collection3.id is not None + upload_session = MarcFileUploadSession( + redis_fixture.client, marc_exporter_fixture.collection3.id + ) + with upload_session.lock() as acquired: + assert acquired + upload_session.set_state(MarcFileUploadState.QUEUED) - marc.marc_export.delay().wait() - marc_export_collection.delay.assert_called_once_with( - collection_id=marc_exporter_fixture.collection3.id, - start_time=ANY, - libraries=ANY, - ) + marc.marc_export.delay().wait() + marc_export_collection.delay.assert_not_called() class MarcExportCollectionFixture: diff --git a/tests/manager/service/redis/models/test_marc.py b/tests/manager/service/redis/models/test_marc.py index 5b64725089..3013b2906d 100644 --- a/tests/manager/service/redis/models/test_marc.py +++ b/tests/manager/service/redis/models/test_marc.py @@ -4,6 +4,7 @@ MarcFileUpload, MarcFileUploadSession, MarcFileUploadSessionError, + MarcFileUploadState, ) from 
palace.manager.service.redis.redis import Pipeline from palace.manager.service.storage.s3 import MultipartS3UploadPart @@ -111,11 +112,12 @@ def test__execute_pipeline( assert "Pipeline should be in explicit transaction mode" in str(exc_info.value) # The _execute_pipeline function takes care of extending the timeout and incrementing - # the update number. + # the update number and setting the state of the session [update_number] = client.json().get( uploads.key, uploads._update_number_json_key ) client.pexpire(uploads.key, 500) + old_state = uploads.state() with uploads._pipeline() as pipe: # If we execute the pipeline, we should get a list of results, excluding the # operations that _execute_pipeline does. @@ -125,6 +127,8 @@ def test__execute_pipeline( ) assert new_update_number == update_number + 2 assert client.pttl(uploads.key) > 500 + assert uploads.state() != old_state + assert uploads.state() == MarcFileUploadState.UPLOADING # If we try to execute a pipeline that has been modified by another process, we should get an error with uploads._pipeline() as pipe: @@ -436,3 +440,30 @@ def test_get_part_num_and_buffer( assert uploads.get_part_num_and_buffer( marc_file_upload_session_fixture.mock_upload_key_1 ) == (2, "1234567") + + def test_state( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If the session doesn't exist, the state should be None + assert uploads.state() is None + + # Once the state is created, by locking for example, the state should be SessionState.INITIAL + with uploads.lock(): + assert uploads.state() == MarcFileUploadState.INITIAL + + def test_set_state( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If we don't hold the lock, we can't set the state + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_state(MarcFileUploadState.UPLOADING) + assert "Must hold lock" in str(exc_info.value) + + # Once the state is created, by locking for example, we can set the state + with uploads.lock(): + uploads.set_state(MarcFileUploadState.UPLOADING) + assert uploads.state() == MarcFileUploadState.UPLOADING From c1098b06f80fb2a5e7a66796a2e642b2ddd74591 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 4 Sep 2024 10:58:11 -0300 Subject: [PATCH 7/7] Fix import --- src/palace/manager/service/redis/models/marc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py index d340cbeb3b..6f443b7ceb 100644 --- a/src/palace/manager/service/redis/models/marc.py +++ b/src/palace/manager/service/redis/models/marc.py @@ -1,13 +1,13 @@ from __future__ import annotations import json +import sys from collections.abc import Callable, Generator, Mapping, Sequence from contextlib import contextmanager from enum import auto from functools import cached_property from typing import Any -from backports.strenum import StrEnum from pydantic import BaseModel from redis import ResponseError, WatchError @@ -17,6 +17,12 @@ from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.util.log import LoggerMixin +# TODO: Remove this when we drop support for Python 3.10 +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from backports.strenum import StrEnum + class MarcFileUploadSessionError(LockError): pass
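
Reviewer note (not part of the patch): the session-state addition in PATCH 6/7 appears to be what lets `marc_export` distinguish "nobody is working on this collection" from "a chain of `marc_export_collection` tasks is already processing it", since the Redis lock itself is only held for short stretches. Below is a minimal sketch of the intended call pattern, assuming the names introduced in this series; `maybe_queue_export` itself is a hypothetical helper for illustration, not something the patch adds.

```python
from palace.manager.service.redis.models.marc import (
    MarcFileUploadSession,
    MarcFileUploadState,
)


def maybe_queue_export(redis_client, collection_id: int) -> bool:
    """Return True if this collection should be handed off for MARC export."""
    session = MarcFileUploadSession(redis_client, collection_id)
    with session.lock() as acquired:
        if not acquired:
            # Another task currently holds the lock for this collection.
            return False
        if session.state() != MarcFileUploadState.INITIAL:
            # The session is already QUEUED or UPLOADING, so a previous
            # scheduler run has handed it off to marc_export_collection.
            return False
        # Mark the session as queued before dispatching the export task, so a
        # later scheduler run skips this collection even once the lock is no
        # longer held.
        session.set_state(MarcFileUploadState.QUEUED)
        return True
```

The design choice worth noting is that the lock alone cannot carry this signal: it is released between the scheduler task and the per-collection worker tasks, whereas the `state` field lives in the session's JSON object and persists until the session is removed.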