diff --git a/src/palace/manager/celery/tasks/marc.py b/src/palace/manager/celery/tasks/marc.py
index fd095583d..ec91113c4 100644
--- a/src/palace/manager/celery/tasks/marc.py
+++ b/src/palace/manager/celery/tasks/marc.py
@@ -133,32 +133,37 @@ def marc_export_collection(
     # Sync the upload_manager to ensure that all the data is written to storage.
     upload_manager.sync()

-    if len(works) == batch_size:
-        # This task is complete, but there are more works waiting to be exported. So we requeue ourselves
-        # to process the next batch.
-        raise task.replace(
-            marc_export_collection.s(
-                collection_id=collection_id,
-                start_time=start_time,
-                libraries=[l.dict() for l in libraries_info],
-                batch_size=batch_size,
-                last_work_id=works[-1].id,
-                update_number=upload_manager.update_number,
+    if len(works) != batch_size:
+        # We have finished generating MARC records. Cleanup and exit.
+        with task.transaction() as session:
+            collection = MarcExporter.collection(session, collection_id)
+            collection_name = collection.name if collection else "unknown"
+            completed_uploads = upload_manager.complete()
+            MarcExporter.create_marc_upload_records(
+                session,
+                start_time,
+                collection_id,
+                libraries_info,
+                completed_uploads,
             )
+            upload_manager.remove_session()
+            task.log.info(
+                f"Finished generating MARC records for collection '{collection_name}' ({collection_id})."
             )
+        return

-    # If we got here, we have finished generating MARC records. Cleanup and exit.
-    with task.transaction() as session:
-        collection = MarcExporter.collection(session, collection_id)
-        collection_name = collection.name if collection else "unknown"
-        completed_uploads = upload_manager.complete()
-        MarcExporter.create_marc_upload_records(
-            session, start_time, collection_id, libraries_info, completed_uploads
-        )
-        upload_manager.remove_session()
-        task.log.info(
-            f"Finished generating MARC records for collection '{collection_name}' ({collection_id})."
+    # This task is complete, but there are more works waiting to be exported. So we requeue ourselves
+    # to process the next batch.
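+    # Passing last_work_id=works[-1].id lets the requeued task resume the export
+    # immediately after the last work handled in this batch.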
+    raise task.replace(
+        marc_export_collection.s(
+            collection_id=collection_id,
+            start_time=start_time,
+            libraries=[l.dict() for l in libraries_info],
+            batch_size=batch_size,
+            last_work_id=works[-1].id,
+            update_number=upload_manager.update_number,
         )
+    )


 @shared_task(queue=QueueNames.default, bind=True)
diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py
index 6f443b7ce..b7ee003c9 100644
--- a/src/palace/manager/service/redis/models/marc.py
+++ b/src/palace/manager/service/redis/models/marc.py
@@ -11,6 +11,7 @@
 from pydantic import BaseModel
 from redis import ResponseError, WatchError

+from palace.manager.core.exceptions import PalaceValueError
 from palace.manager.service.redis.models.lock import LockError, RedisJsonLock
 from palace.manager.service.redis.redis import Pipeline, Redis
 from palace.manager.service.storage.s3 import MultipartS3UploadPart
@@ -40,7 +41,85 @@ class MarcFileUploadState(StrEnum):
     UPLOADING = auto()


-class MarcFileUploadSession(RedisJsonLock, LoggerMixin):
+class PathEscapeMixin:
+    """
+    Mixin providing methods for escaping and unescaping paths for use in redis.
+
+    This is necessary because there appears to be a bug in the AWS ElastiCache
+    implementation of JSONPath, where slash or tilde characters within a string
+    literal used as a key cause issues. This bug is not present in the open source
+    redis implementation, which does the sane thing and requires no special escaping.
+
+    Hopefully at some point AWS will fix these issues and we can drop this mixin,
+    so I have tried to encapsulate its logic here.
+
+    In AWS, when a tilde is used in a key, the key is never updated, despite a
+    success being returned. And when a slash is used in a key, it is interpreted as
+    a nested path, nesting a new key for every slash in the path. This is not the
+    behavior we want, so we need to escape these characters.
+
+    We can test whether this has been fixed by running the test suite against AWS
+    ElastiCache with this mixin removed. If the tests pass, the mixin can be dropped.
+
+    Characters are escaped by prefixing them with a backtick character, followed by
+    a single character from _MAPPING that represents the escaped character. The
+    backtick character itself is escaped by prefixing it with another backtick.
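+
+    For example (an illustrative round trip, derived from the mapping below rather
+    than an exhaustive spec):
+
+        _escape_path("a/b~c`d")       # -> "a`sb`tc``d"
+        _unescape_path("a`sb`tc``d")  # -> "a/b~c`d"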
+    """
+
+    _ESCAPE_CHAR = "`"
+
+    _MAPPING = {
+        "/": "s",
+        "~": "t",
+    }
+
+    @cached_property
+    def _FORWARD_MAPPING(self) -> dict[str, str]:
+        mapping = {k: "".join((self._ESCAPE_CHAR, v)) for k, v in self._MAPPING.items()}
+        mapping[self._ESCAPE_CHAR] = "".join((self._ESCAPE_CHAR, self._ESCAPE_CHAR))
+        return mapping
+
+    @cached_property
+    def _REVERSE_MAPPING(self) -> dict[str, str]:
+        mapping = {v: k for k, v in self._MAPPING.items()}
+        mapping[self._ESCAPE_CHAR] = self._ESCAPE_CHAR
+        return mapping
+
+    def _escape_path(self, path: str) -> str:
+        escaped = json.dumps("".join([self._FORWARD_MAPPING.get(c, c) for c in path]))
+        return escaped[1:-1]
+
+    def _unescape_path(self, path: str) -> str:
+        # Normal redis paths are always double-quoted, so we can use json.loads to
+        # unescape them. This does not happen in the AWS ElastiCache implementation,
+        # so we handle it manually in order to support both implementations.
+        try:
+            path = json.loads(f'"{path}"')
+        except json.JSONDecodeError:
+            pass
+
+        in_escape = False
+        unescaped = []
+        for char in path:
+            if in_escape:
+                if char not in self._REVERSE_MAPPING:
+                    raise PalaceValueError(
+                        f"Invalid escape sequence '{self._ESCAPE_CHAR}{char}'"
+                    )
+                unescaped.append(self._REVERSE_MAPPING[char])
+                in_escape = False
+            else:
+                if char == self._ESCAPE_CHAR:
+                    in_escape = True
+                else:
+                    unescaped.append(char)
+
+        if in_escape:
+            raise PalaceValueError("Unterminated escape sequence.")
+
+        return "".join(unescaped)
+
+
+class MarcFileUploadSession(RedisJsonLock, PathEscapeMixin, LoggerMixin):
     """
     This class is used as a lock for the Celery MARC export task, to ensure that only one
     task can upload MARC files for a given collection at a time. It increments an update
@@ -106,7 +185,8 @@ def _upload_initial_value(buffer_data: str) -> dict[str, Any]:
         return MarcFileUpload(buffer=buffer_data).dict(exclude_none=True)

     def _upload_path(self, upload_key: str) -> str:
-        return f"{self._uploads_json_key}['{upload_key}']"
+        upload_key = self._escape_path(upload_key)
+        return f'{self._uploads_json_key}["{upload_key}"]'

     def _buffer_path(self, upload_key: str) -> str:
         upload_path = self._upload_path(upload_key)
@@ -166,7 +246,7 @@ def _execute_pipeline(
         pipe.json().numincrby(self.key, self._update_number_json_key, updates)
         pipe.pexpire(self.key, self._lock_timeout_ms)
         try:
-            pipe_results = pipe.execute()
+            pipe_results = pipe.execute(raise_on_error=False)
         except WatchError as e:
             raise MarcFileUploadSessionError(
                 "Failed to update buffers. Another process is modifying the buffers."
@@ -175,15 +255,30 @@ def _execute_pipeline(

         return pipe_results[:-3]

+    @staticmethod
+    def _validate_results(results: list[Any]) -> bool:
+        """
+        Validate that all the results of the pipeline are successful
+        and not a ResponseError.
+
+        NOTE: The AWS ElastiCache implementation returns slightly different results than redis.
+        In redis, unsuccessful results when a key is not found are `None`, but in AWS they are
+        returned as a `ResponseError`, which is why we check for both in this function.
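+
+        For example (illustrative values, not captured from a live session): a
+        pipeline that touches a missing key may yield [True, 1, None] on redis,
+        but [True, 1, ResponseError(...)] on AWS ElastiCache; both fail validation.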
+        """
+        return all(r and not isinstance(r, ResponseError) for r in results)
+
     def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]:
         if not data:
             return {}

         set_results = {}
         with self._pipeline(begin_transaction=False) as pipe:
-            existing_uploads: list[str] = self._parse_value_or_raise(
-                pipe.json().objkeys(self.key, self._uploads_json_key)
-            )
+            existing_uploads: list[str] = [
+                self._unescape_path(r)
+                for r in self._parse_value_or_raise(
+                    pipe.json().objkeys(self.key, self._uploads_json_key)
+                )
+            ]
             pipe.multi()
             for key, value in data.items():
                 if value == "":
@@ -193,16 +288,17 @@ def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]:
                         self.key, path=self._buffer_path(key), value=value
                     )
                 else:
+                    path = self._upload_path(key)
                     pipe.json().set(
                         self.key,
-                        path=self._upload_path(key),
+                        path=path,
                         obj=self._upload_initial_value(value),
                     )
                 set_results[key] = len(value)

             pipe_results = self._execute_pipeline(pipe, len(data))

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
             raise MarcFileUploadSessionError("Failed to append buffers.")

         return {
@@ -224,7 +320,7 @@ def add_part_and_clear_buffer(self, key: str, part: MultipartS3UploadPart) -> No
             )
             pipe_results = self._execute_pipeline(pipe, 1)

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
             raise MarcFileUploadSessionError("Failed to add part and clear buffer.")

     def set_upload_id(self, key: str, upload_id: str) -> None:
@@ -237,7 +333,7 @@ def set_upload_id(self, key: str, upload_id: str) -> None:
             )
             pipe_results = self._execute_pipeline(pipe, 1)

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
             raise MarcFileUploadSessionError("Failed to set upload ID.")

     def clear_uploads(self) -> None:
@@ -245,7 +341,7 @@ def clear_uploads(self) -> None:
             pipe.json().clear(self.key, self._uploads_json_key)
             pipe_results = self._execute_pipeline(pipe, 1)

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
             raise MarcFileUploadSessionError("Failed to clear uploads.")

     def _get_specific(
@@ -269,7 +365,7 @@ def _get_all(self, key: str) -> dict[str, Any]:
         if results is None:
             return {}

-        return results
+        return {self._unescape_path(k): v for k, v in results.items()}

     def get(self, keys: str | Sequence[str] | None = None) -> dict[str, MarcFileUpload]:
         if keys is None:
@@ -285,15 +381,14 @@ def get_upload_ids(self, keys: str | Sequence[str]) -> dict[str, str]:
         return self._get_specific(keys, self._upload_id_path)

     def get_part_num_and_buffer(self, key: str) -> tuple[int, str]:
-        try:
-            with self._redis_client.pipeline() as pipe:
-                pipe.json().get(self.key, self._buffer_path(key))
-                pipe.json().arrlen(self.key, self._parts_path(key))
-                results = pipe.execute()
-        except ResponseError as e:
+        with self._redis_client.pipeline() as pipe:
+            pipe.json().get(self.key, self._buffer_path(key))
+            pipe.json().arrlen(self.key, self._parts_path(key))
+            results = pipe.execute(raise_on_error=False)
+        if not self._validate_results(results):
             raise MarcFileUploadSessionError(
                 "Failed to get part number and buffer data."
-            ) from e
+            )

         buffer_data: str = self._parse_value_or_raise(results[0])
         part_number: int = self._parse_value_or_raise(results[1])
diff --git a/tests/manager/service/redis/models/test_marc.py b/tests/manager/service/redis/models/test_marc.py
index 3013b2906..e10bf3ef0 100644
--- a/tests/manager/service/redis/models/test_marc.py
+++ b/tests/manager/service/redis/models/test_marc.py
@@ -1,10 +1,15 @@
+import re
+import string
+
 import pytest

+from palace.manager.core.exceptions import PalaceValueError
 from palace.manager.service.redis.models.marc import (
     MarcFileUpload,
     MarcFileUploadSession,
     MarcFileUploadSessionError,
     MarcFileUploadState,
+    PathEscapeMixin,
 )
 from palace.manager.service.redis.redis import Pipeline
 from palace.manager.service.storage.s3 import MultipartS3UploadPart
@@ -21,9 +26,10 @@ def __init__(self, redis_fixture: RedisFixture):
             self._redis_fixture.client, self.mock_collection_id
         )

-        self.mock_upload_key_1 = "test1"
-        self.mock_upload_key_2 = "test2"
-        self.mock_upload_key_3 = "test3"
+        # Some keys with special characters to make sure they are handled correctly.
+        self.mock_upload_key_1 = "test/test1/?$xyz.abc"
+        self.mock_upload_key_2 = "test/test2.ext`"
+        self.mock_upload_key_3 = "//t/e/s/t3!!\\~'\"`"

         self.mock_unset_upload_key = "test4"

@@ -49,7 +55,7 @@ def load_test_data(self) -> dict[str, int]:

         return return_value

-    def test_data_records(self, *keys: str):
+    def test_data_records(self, *keys: str) -> dict[str, MarcFileUpload]:
         return {key: MarcFileUpload(buffer=self.test_data[key]) for key in keys}

@@ -467,3 +473,59 @@ def test_set_state(
         with uploads.lock():
             uploads.set_state(MarcFileUploadState.UPLOADING)
             assert uploads.state() == MarcFileUploadState.UPLOADING
+
+
+class TestPathEscapeMixin:
+    @pytest.mark.parametrize(
+        "path",
+        [
+            "",
+            "test",
+            string.printable,
+            "test/test1/?$xyz.abc",
+            "`",
+            "```",
+            "/~`\\",
+            "`\\~/``/",
+            "a",
+            "/",
+            "~",
+            " ",
+        ],
+    )
+    def test_escape_path(self, path: str) -> None:
+        # Test a round trip
+        escaper = PathEscapeMixin()
+        escaped = escaper._escape_path(path)
+        unescaped = escaper._unescape_path(escaped)
+        assert unescaped == path
+
+        # Test that we can handle escaping the escaped path multiple times
+        escaped = path
+        for _ in range(10):
+            escaped = escaper._escape_path(escaped)
+
+        unescaped = escaped
+        for _ in range(10):
+            unescaped = escaper._unescape_path(unescaped)
+
+        assert unescaped == path
+
+    def test_unescape(self) -> None:
+        escaper = PathEscapeMixin()
+        assert escaper._unescape_path("") == ""
+
+        with pytest.raises(
+            PalaceValueError, match=re.escape("Invalid escape sequence '`?'")
+        ):
+            escaper._unescape_path("test `?")
+
+        with pytest.raises(
+            PalaceValueError, match=re.escape("Invalid escape sequence '` '")
+        ):
+            escaper._unescape_path("``` test")
+
+        with pytest.raises(
+            PalaceValueError, match=re.escape("Unterminated escape sequence")
+        ):
+            escaper._unescape_path("`")