Skip to content

Commit

Permalink
Don't persist signed manifest URLs in StepFunction output
Browse files Browse the repository at this point in the history
  • Loading branch information
hannes-ucsc committed Dec 12, 2024
1 parent cfc605b commit 3a64490
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/azul/service/manifest_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def get_manifest_async(self,
manifest_key = self.service.sign_manifest_key(manifest_key)
url = self.manifest_url_func(fetch=False, token_or_key=manifest_key.encode())
else:
url = furl(manifest.location)
url = furl(self.service.get_manifest_url(manifest))
body = {
'Status': 302,
'Location': str(url),
Expand Down
48 changes: 26 additions & 22 deletions src/azul/service/manifest_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,9 @@ class Manifest:
"""
Contains the details of a prepared manifest.
"""
#: The URL of the manifest file.
location: str
#: The S3 object key under which the manifest is stored in the storage
#: bucket
object_key: str

#: True if an existing manifest was reused or False if a new manifest was
#: generated.
Expand All @@ -415,7 +416,7 @@ class Manifest:
#: The format of the manifest
format: ManifestFormat

#: The key under which the manifest is stored
#: Uniquely identifies this manifest
manifest_key: ManifestKey

#: The proposed file name of the manifest when downloading it to a user's
Expand All @@ -424,7 +425,7 @@ class Manifest:

def to_json(self) -> JSON:
return {
'location': self.location,
'object_key': self.object_key,
'was_cached': self.was_cached,
'format': self.format.value,
'manifest_key': self.manifest_key.to_json(),
Expand All @@ -433,7 +434,7 @@ def to_json(self) -> JSON:

@classmethod
def from_json(cls, json: JSON) -> 'Manifest':
return cls(location=json['location'],
return cls(object_key=json['object_key'],
was_cached=json['was_cached'],
format=ManifestFormat(json['format']),
manifest_key=ManifestKey.from_json(json['manifest_key']),
Expand Down Expand Up @@ -637,10 +638,10 @@ def _generate_manifest(self,
) -> Manifest | ManifestPartition:
partition = generator.write(manifest_key, partition)
if partition.is_last:
return self._presign_manifest(generator_cls=type(generator),
manifest_key=manifest_key,
file_name=partition.file_name,
was_cached=False)
return self._make_manifest(generator_cls=type(generator),
manifest_key=manifest_key,
file_name=partition.file_name,
was_cached=False)
else:
return partition

Expand Down Expand Up @@ -695,27 +696,30 @@ def _get_cached_manifest(self,
if file_name is None:
raise CachedManifestNotFound(manifest_key)
else:
return self._presign_manifest(generator_cls=generator_cls,
manifest_key=manifest_key,
file_name=file_name,
was_cached=True)

def _presign_manifest(self,
generator_cls: Type['ManifestGenerator'],
manifest_key: ManifestKey,
file_name: Optional[str],
was_cached: bool
) -> Manifest:
return self._make_manifest(generator_cls=generator_cls,
manifest_key=manifest_key,
file_name=file_name,
was_cached=True)

def _make_manifest(self,
generator_cls: Type['ManifestGenerator'],
manifest_key: ManifestKey,
file_name: Optional[str],
was_cached: bool
) -> Manifest:
if not generator_cls.use_content_disposition_file_name:
file_name = None
object_key = generator_cls.s3_object_key(manifest_key)
presigned_url = self.storage_service.get_presigned_url(object_key, file_name)
return Manifest(location=presigned_url,
return Manifest(object_key=object_key,
was_cached=was_cached,
format=generator_cls.format(),
manifest_key=manifest_key,
file_name=file_name)

def get_manifest_url(self, manifest: Manifest) -> str:
return self.storage_service.get_presigned_url(key=manifest.object_key,
file_name=manifest.file_name)

file_name_tag = 'azul_file_name'

def _get_cached_manifest_file_name(self,
Expand Down
22 changes: 13 additions & 9 deletions test/service/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,11 @@ def _get_manifest(self,
) -> Response:
manifest, num_partitions = self._get_manifest_object(format, filters)
self.assertEqual(1, num_partitions)
response = requests.get(manifest.location, stream=stream)
url = furl(self._service.get_manifest_url(manifest))
response = requests.get(str(url), stream=stream)
# Moto doesn't support signed S3 URLs with Content-Disposition baked in,
# so we'll retroactively inject it into the response header.
location = furl(manifest.location)
content_disposition = location.args.get('response-content-disposition')
content_disposition = url.args.get('response-content-disposition')
if content_disposition is not None:
response.headers['content-disposition'] = content_disposition
return response
Expand Down Expand Up @@ -1373,9 +1373,9 @@ def test_manifest_content_disposition_header(self):
manifest, num_partitions = self._get_manifest_object(format, filters)
self.assertFalse(manifest.was_cached)
self.assertEqual(1, num_partitions)
query = furl(manifest.location).query
url = furl(self._service.get_manifest_url(manifest))
expected_cd = f'attachment;filename="{expected_name}.tsv"'
actual_cd = query.params['response-content-disposition']
actual_cd = url.args['response-content-disposition']
self.assertEqual(expected_cd, actual_cd)

def test_verbatim_jsonl_manifest(self):
Expand Down Expand Up @@ -1591,11 +1591,11 @@ def test_get_cached_manifest(self, _time_until_object_expires: MagicMock):
# seconds, the signed URL is going to have a different expiration.
manifest = attrs.evolve(manifest,
was_cached=True,
location=cached_manifest_1.location)
object_key=cached_manifest_1.object_key)
self.assertEqual(manifest, cached_manifest_1)
cached_manifest_2 = self._service.get_cached_manifest_with_key(manifest_key)
cached_manifest_1 = attrs.evolve(cached_manifest_1,
location=cached_manifest_2.location)
object_key=cached_manifest_2.object_key)
self.assertEqual(cached_manifest_1, cached_manifest_2)
_time_until_object_expires.assert_called_once()
_time_until_object_expires.reset_mock()
Expand Down Expand Up @@ -1623,7 +1623,9 @@ class TestManifestResponse(DCP1ManifestTestCase):
@patch.object(ManifestService, 'get_cached_manifest_with_key')
@patch.object(ManifestService, 'sign_manifest_key')
@patch.object(ManifestService, 'verify_manifest_key')
@patch.object(ManifestService, 'get_manifest_url')
def test_manifest(self,
get_manifest_url,
verify_manifest_key,
sign_manifest_key,
get_cached_manifest_with_key,
Expand All @@ -1642,13 +1644,14 @@ def test(*, format: ManifestFormat, fetch: bool, url: Optional[furl] = None):
signed_manifest_key = SignedManifestKey(value=manifest_key, signature=b'123')
sign_manifest_key.return_value = signed_manifest_key
verify_manifest_key.return_value = manifest_key
manifest = Manifest(location=str(object_url),
manifest = Manifest(object_key='key/of/manifest',
was_cached=False,
format=format,
manifest_key=manifest_key,
file_name=default_file_name)
get_cached_manifest.return_value = manifest
get_cached_manifest_with_key.return_value = manifest
get_manifest_url.return_value = object_url
args = dict(catalog=self.catalog,
format=format.value,
filters='{}')
Expand Down Expand Up @@ -1726,7 +1729,8 @@ def test(self):
with patch.object(PagedManifestGenerator, 'part_size', part_size):
manifest, num_partitions = self._get_manifest_object(ManifestFormat.compact,
filters={})
content = requests.get(manifest.location).content
url = self._service.get_manifest_url(manifest)
content = requests.get(url).content
self.assertGreater(num_partitions, 1)
self.assertGreater(len(content), (num_partitions - 1) * part_size)

Expand Down
5 changes: 4 additions & 1 deletion test/service/test_manifest_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def lambda_name(cls) -> str:
@mock.patch.object(ManifestService, 'get_cached_manifest')
@mock.patch.object(ManifestService, 'verify_manifest_key')
@mock.patch.object(ManifestService, 'get_cached_manifest_with_key')
@mock.patch.object(ManifestService, 'get_manifest_url')
def test(self,
get_manifest_url,
get_cached_manifest_with_key,
verify_manifest_key,
get_cached_manifest,
Expand Down Expand Up @@ -193,7 +195,7 @@ def test(self,

object_url = 'https://url.to.manifest?foo=bar'
file_name = 'some_file_name'
manifest = Manifest(location=object_url,
manifest = Manifest(object_key='key/of/manifest',
was_cached=False,
format=format,
manifest_key=manifest_key,
Expand Down Expand Up @@ -279,6 +281,7 @@ def test(self,
_sfn.describe_execution.return_value = {'status': 'SUCCEEDED'}
elif i == 2:
get_manifest.return_value = manifest
get_manifest_url.return_value = object_url
_sfn.start_execution.assert_not_called()
_sfn.describe_execution.assert_called_once()
_sfn.reset_mock()
Expand Down

0 comments on commit 3a64490

Please sign in to comment.