Fix issues caused by differences between redis and elasticache

jonathangreen committed Sep 10, 2024
1 parent b136123 · commit 83484e3
Showing 3 changed files with 207 additions and 45 deletions.
49 changes: 27 additions & 22 deletions src/palace/manager/celery/tasks/marc.py
@@ -133,32 +133,37 @@ def marc_export_collection(
        # Sync the upload_manager to ensure that all the data is written to storage.
        upload_manager.sync()

-        if len(works) == batch_size:
-            # This task is complete, but there are more works waiting to be exported. So we requeue ourselves
-            # to process the next batch.
-            raise task.replace(
-                marc_export_collection.s(
-                    collection_id=collection_id,
-                    start_time=start_time,
-                    libraries=[l.dict() for l in libraries_info],
-                    batch_size=batch_size,
-                    last_work_id=works[-1].id,
-                    update_number=upload_manager.update_number,
-                )
-            )
+        if len(works) != batch_size:
+            # We have finished generating MARC records. Cleanup and exit.
+            with task.transaction() as session:
+                collection = MarcExporter.collection(session, collection_id)
+                collection_name = collection.name if collection else "unknown"
+                completed_uploads = upload_manager.complete()
+                MarcExporter.create_marc_upload_records(
+                    session,
+                    start_time,
+                    collection_id,
+                    libraries_info,
+                    completed_uploads,
+                )
+                upload_manager.remove_session()
+            task.log.info(
+                f"Finished generating MARC records for collection '{collection_name}' ({collection_id})."
+            )
+            return

-        # If we got here, we have finished generating MARC records. Cleanup and exit.
-        with task.transaction() as session:
-            collection = MarcExporter.collection(session, collection_id)
-            collection_name = collection.name if collection else "unknown"
-            completed_uploads = upload_manager.complete()
-            MarcExporter.create_marc_upload_records(
-                session, start_time, collection_id, libraries_info, completed_uploads
-            )
-            upload_manager.remove_session()
-        task.log.info(
-            f"Finished generating MARC records for collection '{collection_name}' ({collection_id})."
-        )
+        # This task is complete, but there are more works waiting to be exported. So we requeue ourselves
+        # to process the next batch.
+        raise task.replace(
+            marc_export_collection.s(
+                collection_id=collection_id,
+                start_time=start_time,
+                libraries=[l.dict() for l in libraries_info],
+                batch_size=batch_size,
+                last_work_id=works[-1].id,
+                update_number=upload_manager.update_number,
+            )
+        )
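For readers unfamiliar with the Celery idiom above: handle the "finished" case first and return, then unconditionally replace the task with a new invocation for the next batch. A minimal sketch of the pattern, assuming a task bound with bind=True (fetch_next_batch and process are hypothetical helpers, not part of this codebase):

from celery import shared_task

@shared_task(bind=True)
def export_batch(self, last_id=None, batch_size=500):
    works = fetch_next_batch(last_id, batch_size)  # hypothetical helper
    process(works)  # hypothetical helper

    if len(works) != batch_size:
        # A short batch means nothing is left to fetch: we are done.
        return

    # A full batch means more work may remain, so hand off to a fresh task
    # that picks up after the last processed id.
    raise self.replace(export_batch.s(last_id=works[-1].id, batch_size=batch_size))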


@shared_task(queue=QueueNames.default, bind=True)
133 changes: 114 additions & 19 deletions src/palace/manager/service/redis/models/marc.py
@@ -11,6 +11,7 @@
from pydantic import BaseModel
from redis import ResponseError, WatchError

+from palace.manager.core.exceptions import PalaceValueError
from palace.manager.service.redis.models.lock import LockError, RedisJsonLock
from palace.manager.service.redis.redis import Pipeline, Redis
from palace.manager.service.storage.s3 import MultipartS3UploadPart
@@ -40,7 +41,85 @@ class MarcFileUploadState(StrEnum):
UPLOADING = auto()


-class MarcFileUploadSession(RedisJsonLock, LoggerMixin):
+class PathEscapeMixin:
    """
    Mixin to provide methods for escaping and unescaping paths for use in redis.

    This is necessary because there appears to be a bug in the AWS elasticache
    implementation of JSONPATH, where slash or tilde characters within a string
    literal used as a key cause issues. This bug is not present in the open source
    redis implementation, which does the sane thing and requires no special
    escaping. Hopefully at some point AWS will fix these issues and we can drop
    this mixin, so I tried to encapsulate the logic for it here.

    In AWS, when a tilde is used in a key, the key is never updated, despite a
    success being returned. And when a slash is used in a key, the key is
    interpreted as a nested path, nesting a new key for every slash in the path.
    This is not the behavior we want, so we need to escape these characters.

    We can test whether this is fixed in the future by running the test suite
    against AWS elasticache with this mixin removed. If the tests pass, this
    mixin can be removed.

    Characters are escaped by prefixing them with a backtick character, followed
    by a single character from _MAPPING that represents the escaped character.
    The backtick character itself is escaped by prefixing it with another
    backtick character.
    """

_ESCAPE_CHAR = "`"

_MAPPING = {
"/": "s",
"~": "t",
}

@cached_property
def _FORWARD_MAPPING(self) -> dict[str, str]:
mapping = {k: "".join((self._ESCAPE_CHAR, v)) for k, v in self._MAPPING.items()}
mapping[self._ESCAPE_CHAR] = "".join((self._ESCAPE_CHAR, self._ESCAPE_CHAR))
return mapping

@cached_property
def _REVERSE_MAPPING(self) -> dict[str, str]:
mapping = {v: k for k, v in self._MAPPING.items()}
mapping[self._ESCAPE_CHAR] = self._ESCAPE_CHAR
return mapping

def _escape_path(self, path: str) -> str:
escaped = json.dumps("".join([self._FORWARD_MAPPING.get(c, c) for c in path]))
return escaped[1:-1]

def _unescape_path(self, path: str) -> str:
        # Normal redis paths are always double-quoted, so we can use json.loads to
        # unescape them. This does not happen in the AWS elasticache implementation,
        # so we handle it manually in order to support both implementations.
try:
path = json.loads(f'"{path}"')
except json.JSONDecodeError:
pass

in_escape = False
unescaped = []
for char in path:
if in_escape:
if char not in self._REVERSE_MAPPING:
raise PalaceValueError(
f"Invalid escape sequence '{self._ESCAPE_CHAR}{char}'"
)
unescaped.append(self._REVERSE_MAPPING[char])
in_escape = False
else:
if char == self._ESCAPE_CHAR:
in_escape = True
else:
unescaped.append(char)

if in_escape:
raise PalaceValueError("Unterminated escape sequence.")

return "".join(unescaped)


+class MarcFileUploadSession(RedisJsonLock, PathEscapeMixin, LoggerMixin):
"""
This class is used as a lock for the Celery MARC export task, to ensure that only one
task can upload MARC files for a given collection at a time. It increments an update
@@ -106,7 +185,8 @@ def _upload_initial_value(buffer_data: str) -> dict[str, Any]:
return MarcFileUpload(buffer=buffer_data).dict(exclude_none=True)

def _upload_path(self, upload_key: str) -> str:
return f"{self._uploads_json_key}['{upload_key}']"
upload_key = self._escape_path(upload_key)
return f'{self._uploads_json_key}["{upload_key}"]'
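For illustration, assuming the uploads object lives at the JSONPath $.uploads (the actual value of _uploads_json_key is defined elsewhere in this class and not shown in this hunk), a key containing slashes is now escaped before being embedded in the path:

# Hypothetical session instance and key, for illustration only.
session._upload_path("test/file.mrc")
# Before this change: $.uploads['test/file.mrc']   (elasticache nests a new key per slash)
# After this change:  $.uploads["test`sfile.mrc"]  (a single literal key on both backends)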

def _buffer_path(self, upload_key: str) -> str:
upload_path = self._upload_path(upload_key)
@@ -166,7 +246,7 @@ def _execute_pipeline(
pipe.json().numincrby(self.key, self._update_number_json_key, updates)
pipe.pexpire(self.key, self._lock_timeout_ms)
try:
-            pipe_results = pipe.execute()
+            pipe_results = pipe.execute(raise_on_error=False)
except WatchError as e:
raise MarcFileUploadSessionError(
"Failed to update buffers. Another process is modifying the buffers."
@@ -175,15 +255,30 @@

return pipe_results[:-3]

    @staticmethod
    def _validate_results(results: list[Any]) -> bool:
        """
        Validate that all results of the pipeline are successful and that none
        of them is a ResponseError.

        NOTE: The AWS elasticache implementation returns slightly different
        results than redis. In redis, unsuccessful results when a key is not
        found are `None`, but in AWS they are returned as a `ResponseError`,
        which is why this function checks for both.
        """
        return all(r and not isinstance(r, ResponseError) for r in results)
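For context on why _validate_results must check both conditions, the same missing key surfaces differently on the two backends. A minimal sketch, with the client and key names hypothetical:

from redis import Redis, ResponseError

client = Redis()  # hypothetical connection
with client.pipeline() as pipe:
    pipe.json().get("some-key", "$.missing")  # a key/path that does not exist
    results = pipe.execute(raise_on_error=False)

# Open source redis reports the miss as None:   results == [None]
# AWS elasticache returns an error object:      results == [ResponseError(...)]
# Either way, all(r and not isinstance(r, ResponseError) for r in results) is False.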

def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]:
if not data:
return {}

set_results = {}
with self._pipeline(begin_transaction=False) as pipe:
-            existing_uploads: list[str] = self._parse_value_or_raise(
-                pipe.json().objkeys(self.key, self._uploads_json_key)
-            )
+            existing_uploads: list[str] = [
+                self._unescape_path(r)
+                for r in self._parse_value_or_raise(
+                    pipe.json().objkeys(self.key, self._uploads_json_key)
+                )
+            ]
pipe.multi()
for key, value in data.items():
if value == "":
@@ -193,16 +288,17 @@ def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]:
self.key, path=self._buffer_path(key), value=value
)
                else:
+                    path = self._upload_path(key)
                    pipe.json().set(
                        self.key,
-                        path=self._upload_path(key),
+                        path=path,
                        obj=self._upload_initial_value(value),
                    )
set_results[key] = len(value)

pipe_results = self._execute_pipeline(pipe, len(data))

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
raise MarcFileUploadSessionError("Failed to append buffers.")

return {
@@ -224,7 +320,7 @@ def add_part_and_clear_buffer(self, key: str, part: MultipartS3UploadPart) -> None:
)
pipe_results = self._execute_pipeline(pipe, 1)

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
raise MarcFileUploadSessionError("Failed to add part and clear buffer.")

def set_upload_id(self, key: str, upload_id: str) -> None:
@@ -237,15 +333,15 @@ def set_upload_id(self, key: str, upload_id: str) -> None:
)
pipe_results = self._execute_pipeline(pipe, 1)

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
raise MarcFileUploadSessionError("Failed to set upload ID.")

def clear_uploads(self) -> None:
with self._pipeline() as pipe:
pipe.json().clear(self.key, self._uploads_json_key)
pipe_results = self._execute_pipeline(pipe, 1)

-        if not all(pipe_results):
+        if not self._validate_results(pipe_results):
raise MarcFileUploadSessionError("Failed to clear uploads.")

def _get_specific(
Expand All @@ -269,7 +365,7 @@ def _get_all(self, key: str) -> dict[str, Any]:
if results is None:
return {}

-        return results
+        return {self._unescape_path(k): v for k, v in results.items()}

def get(self, keys: str | Sequence[str] | None = None) -> dict[str, MarcFileUpload]:
if keys is None:
@@ -285,15 +381,14 @@ def get_upload_ids(self, keys: str | Sequence[str]) -> dict[str, str]:
return self._get_specific(keys, self._upload_id_path)

def get_part_num_and_buffer(self, key: str) -> tuple[int, str]:
-        try:
-            with self._redis_client.pipeline() as pipe:
-                pipe.json().get(self.key, self._buffer_path(key))
-                pipe.json().arrlen(self.key, self._parts_path(key))
-                results = pipe.execute()
-        except ResponseError as e:
+        with self._redis_client.pipeline() as pipe:
+            pipe.json().get(self.key, self._buffer_path(key))
+            pipe.json().arrlen(self.key, self._parts_path(key))
+            results = pipe.execute(raise_on_error=False)
+        if not self._validate_results(results):
            raise MarcFileUploadSessionError(
                "Failed to get part number and buffer data."
-            ) from e
+            )

buffer_data: str = self._parse_value_or_raise(results[0])
part_number: int = self._parse_value_or_raise(results[1])
70 changes: 66 additions & 4 deletions tests/manager/service/redis/models/test_marc.py
@@ -1,10 +1,15 @@
+import re
+import string
+
import pytest

+from palace.manager.core.exceptions import PalaceValueError
from palace.manager.service.redis.models.marc import (
    MarcFileUpload,
    MarcFileUploadSession,
    MarcFileUploadSessionError,
    MarcFileUploadState,
+    PathEscapeMixin,
)
from palace.manager.service.redis.redis import Pipeline
from palace.manager.service.storage.s3 import MultipartS3UploadPart
@@ -21,9 +26,10 @@ def __init__(self, redis_fixture: RedisFixture):
self._redis_fixture.client, self.mock_collection_id
)

self.mock_upload_key_1 = "test1"
self.mock_upload_key_2 = "test2"
self.mock_upload_key_3 = "test3"
# Some keys with special characters to make sure they are handled correctly.
self.mock_upload_key_1 = "test/test1/?$xyz.abc"
self.mock_upload_key_2 = "test/test2.ext`"
self.mock_upload_key_3 = "//t/e/s/t3!!\\~'\"`"

self.mock_unset_upload_key = "test4"

@@ -49,7 +55,7 @@ def load_test_data(self) -> dict[str, int]:

return return_value

-    def test_data_records(self, *keys: str):
+    def test_data_records(self, *keys: str) -> dict[str, MarcFileUpload]:
return {key: MarcFileUpload(buffer=self.test_data[key]) for key in keys}


@@ -467,3 +473,59 @@ def test_set_state(
with uploads.lock():
uploads.set_state(MarcFileUploadState.UPLOADING)
assert uploads.state() == MarcFileUploadState.UPLOADING


class TestPathEscapeMixin:
@pytest.mark.parametrize(
"path",
[
"",
"test",
string.printable,
"test/test1/?$xyz.abc",
"`",
"```",
"/~`\\",
"`\\~/``/",
"a",
"/",
"~",
" ",
],
)
def test_escape_path(self, path: str) -> None:
# Test a round trip
escaper = PathEscapeMixin()
escaped = escaper._escape_path(path)
unescaped = escaper._unescape_path(escaped)
assert unescaped == path

# Test that we can handle escaping the escaped path multiple times
escaped = path
for _ in range(10):
escaped = escaper._escape_path(escaped)

unescaped = escaped
for _ in range(10):
unescaped = escaper._unescape_path(unescaped)

assert unescaped == path

def test_unescape(self) -> None:
escaper = PathEscapeMixin()
assert escaper._unescape_path("") == ""

with pytest.raises(
PalaceValueError, match=re.escape("Invalid escape sequence '`?'")
):
escaper._unescape_path("test `?")

with pytest.raises(
PalaceValueError, match=re.escape("Invalid escape sequence '` '")
):
escaper._unescape_path("``` test")

with pytest.raises(
PalaceValueError, match=re.escape("Unterminated escape sequence")
):
escaper._unescape_path("`")
