call cache bugfix (#101)
morsecodist authored Jan 20, 2023
1 parent a923baf commit ece1205
Showing 3 changed files with 68 additions and 20 deletions.
22 changes: 21 additions & 1 deletion miniwdl-plugins/s3upload/miniwdl_s3upload.py
@@ -89,6 +89,7 @@ def inode(link: str):

_uploaded_files: Dict[Tuple[int, int], str] = {}
_cached_files: Dict[Tuple[int, int], Tuple[str, Env.Bindings[Value.Base]]] = {}
_key_inputs: Dict[str, Env.Bindings[Value.Base]] = {}
_uploaded_files_lock = threading.Lock()


@@ -107,8 +108,18 @@ def cache(v: Union[Value.File, Value.Directory]) -> str:
return _uploaded_files[inode(str(v.value))]

remapped_outputs = Value.rewrite_env_paths(outputs, cache)

input_digest = Value.digest_env(
Value.rewrite_env_paths(
_key_inputs[key], lambda v: _uploaded_files.get(inode(str(v.value)), str(v.value))
)
)
key_parts = key.split('/')
key_parts[-1] = input_digest
s3_cache_key = "/".join(key_parts)

if not missing and cfg.has_option("s3_progressive_upload", "uri_prefix"):
uri = os.path.join(get_s3_put_prefix(cfg), "cache", f"{key}.json")
uri = os.path.join(get_s3_put_prefix(cfg), "cache", f"{s3_cache_key}.json")
s3_object(uri).put(Body=json.dumps(values_to_json(remapped_outputs)).encode())
flag_temporary(uri)
logger.info(_("call cache insert", cache_file=uri))
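The gist of the new logic in `put` above: the trailing segment of the call-cache key is a digest of the step's inputs, and it is recomputed here over the inputs' uploaded S3 URIs rather than their local paths before the entry is written to S3. A minimal sketch of that remapping, using plain dicts in place of miniwdl's `Env.Bindings`, a path-keyed upload map (the plugin keys `_uploaded_files` by inode), and a stand-in digest; names like `remap_cache_key` are illustrative, not part of the plugin:

```python
import hashlib
import json
from typing import Dict


def digest(inputs: Dict[str, str]) -> str:
    """Stand-in for miniwdl's Value.digest_env: hash a canonical JSON form of the inputs."""
    return hashlib.sha256(json.dumps(inputs, sort_keys=True).encode()).hexdigest()[:20]


def remap_cache_key(key: str, local_inputs: Dict[str, str], uploaded: Dict[str, str]) -> str:
    """Replace the trailing digest of a call-cache key with one computed over S3 URIs."""
    # Rewrite each input path to its uploaded S3 URI, falling back to the local path
    # (the plugin does this with Value.rewrite_env_paths and the _uploaded_files map).
    s3_inputs = {name: uploaded.get(path, path) for name, path in local_inputs.items()}
    key_parts = key.split("/")
    key_parts[-1] = digest(s3_inputs)  # swap the local-path digest for the S3-path digest
    return "/".join(key_parts)


# A later run that restores the previous step from cache sees S3 URIs as its inputs,
# so the entry must be stored under the S3-based digest for that run's `get` to find it.
local = {"input_file": "/mnt/work/out_world.txt"}
uploads = {"/mnt/work/out_world.txt": "s3://bucket/run-1/test-1/out_world.txt"}
print(remap_cache_key("add_goodbye/0123abcd", local, uploads))
```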
@@ -118,6 +129,15 @@ class CallCache(cache.CallCache):
def get(
self, key: str, inputs: Env.Bindings[Value.Base], output_types: Env.Bindings[Type.Base]
) -> Optional[Env.Bindings[Value.Base]]:
# HACK: to back the call cache in S3, we need to cache the S3 paths to the outputs.
# On a cache hit, those S3 paths get passed to the next step. However, the cached
# key was computed using local inputs, so the lookup results in a cache miss.
# We need `put` to use a key based on the S3 paths instead, but `put` doesn't have
# access to the step inputs. `put` should always run after a `get`, so here we store
# the inputs keyed by the cache key so `put` can retrieve them.
global _key_inputs
_key_inputs[key] = inputs

if not self._cfg.has_option("s3_progressive_upload", "uri_prefix"):
return super().get(key, inputs, output_types)
uri = urlparse(get_s3_get_prefix(self._cfg))
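The companion change in `get` above exists only to feed `put`: `put` receives the cache key but not the step inputs, and `get` always runs first, so `get` stashes the inputs in the module-level `_key_inputs` dict for `put` to pick up. A rough sketch of that handoff, with a toy class standing in for miniwdl's `cache.CallCache` (the class and its print output are illustrative only):

```python
from typing import Dict, Optional

_key_inputs: Dict[str, dict] = {}  # cache key -> step inputs, stashed by get() for put()


class ToyCallCache:
    """Illustrates the get-then-put ordering the plugin relies on."""

    def get(self, key: str, inputs: dict) -> Optional[dict]:
        # get() is the only hook that sees the inputs, so stash them under the key.
        _key_inputs[key] = inputs
        return None  # pretend the lookup missed so the task runs and put() is called

    def put(self, key: str, outputs: dict) -> None:
        # put() has no inputs argument; recover them from the stash to rebuild the
        # S3-based cache key (see the remapping sketch above).
        inputs = _key_inputs.get(key, {})
        print(f"caching {key}: inputs={inputs} outputs={outputs}")


cc = ToyCallCache()
cc.get("add_goodbye/0123abcd", {"input_file": "s3://bucket/run-1/out_world.txt"})
cc.put("add_goodbye/0123abcd", {"out_goodbye": "s3://bucket/run-1/out_goodbye.txt"})
```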
64 changes: 46 additions & 18 deletions test/test_wdl.py
@@ -20,35 +20,42 @@
call add_world {
input:
hello = hello,
input_file = hello,
docker_image_id = docker_image_id
}
call add_goodbye {
input:
hello_world = add_world.out,
input_file = add_world.out_world,
docker_image_id = docker_image_id
}
call add_farewell {
input:
input_file = add_goodbye.out_goodbye,
docker_image_id = docker_image_id
}
output {
File out = add_world.out
File out_world = add_world.out_world
File out_goodbye = add_goodbye.out_goodbye
File out_farewell = add_farewell.out_farewell
}
}
task add_world {
input {
File hello
File input_file
String docker_image_id
}
command <<<
cat ~{hello} > out.txt
echo world >> out.txt
cat ~{input_file} > out_world.txt
echo world >> out_world.txt
>>>
output {
File out = "out.txt"
File out_world = "out_world.txt"
}
runtime {
@@ -58,12 +65,12 @@
task add_goodbye {
input {
File hello_world
File input_file
String docker_image_id
}
command <<<
cat ~{hello_world} > out_goodbye.txt
cat ~{input_file} > out_goodbye.txt
echo goodbye >> out_goodbye.txt
>>>
@@ -75,6 +82,26 @@
docker: docker_image_id
}
}
task add_farewell {
input {
File input_file
String docker_image_id
}
command <<<
cat ~{input_file} > out_farewell.txt
echo farewell >> out_farewell.txt
>>>
output {
File out_farewell = "out_farewell.txt"
}
runtime {
docker: docker_image_id
}
}
"""

test_fail_wdl = """
@@ -161,7 +188,7 @@

test_stage_io_map = {
"Two": {
"hello_world": "out",
"hello_world": "out_world",
},
}

@@ -301,11 +328,12 @@ def test_simple_sfn_wdl_workflow(self):

output = json.loads(description["output"])
self.assertEqual(output["Result"], {
"swipe_test.out": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out.txt",
"swipe_test.out_world": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_world.txt",
"swipe_test.out_goodbye": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_goodbye.txt",
"swipe_test.out_farewell": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_farewell.txt",
})

outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out.txt")
outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_world.txt")
output_text = outputs_obj.get()["Body"].read().decode()
self.assertEqual(output_text, "hello\nworld\n")

@@ -384,19 +412,19 @@ def test_call_cache(self):
self.sqs.receive_message(
QueueUrl=self.state_change_queue_url, MaxNumberOfMessages=1
)
outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out.txt")
outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_world.txt")
output_text = outputs_obj.get()["Body"].read().decode()
self.assertEqual(output_text, "hello\nworld\n")

self.test_bucket.Object(f"{output_prefix}/test-1/out.txt").put(
self.test_bucket.Object(f"{output_prefix}/test-1/out_goodbye.txt").put(
Body="cache_break\n".encode()
)
self.test_bucket.Object(f"{output_prefix}/test-1/out_goodbye.txt").delete()
self.test_bucket.Object(f"{output_prefix}/test-1/out_farewell.txt").delete()

# clear the cache to simulate the run getting cut off at the step before this one
objects = self.s3_client.list_objects_v2(
Bucket=self.test_bucket.name,
Prefix=f"{output_prefix}/test-1/cache/add_goodbye/",
Prefix=f"{output_prefix}/test-1/cache/add_farewell/",
)["Contents"]
self.test_bucket.Object(objects[0]["Key"]).delete()
objects = self.s3_client.list_objects_v2(
@@ -412,9 +440,9 @@ def test_call_cache(self):
for v in outputs.values():
self.assert_(v.startswith("s3://"), f"{v} does not start with 's3://'")

outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_goodbye.txt")
outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_farewell.txt")
output_text = outputs_obj.get()["Body"].read().decode()
self.assertEqual(output_text, "cache_break\ngoodbye\n")
self.assertEqual(output_text, "cache_break\nfarewell\n")

def test_zip_wdls(self):
output_prefix = "zip-output"
2 changes: 1 addition & 1 deletion version
@@ -1 +1 @@
v1.3.0
v1.3.1
