From d105f9ff6614590244dd3735fe92b7587eac7151 Mon Sep 17 00:00:00 2001 From: Justin Hong <justin.hong@columbia.edu> Date: Tue, 6 Aug 2024 23:04:15 -0400 Subject: [PATCH] fix[train] flag to make jax deterministic if seed is manually set (#2923) Addresses #2911 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Can Ergen <canergen.ac@gmail.com> (cherry picked from commit 47229520fd6d93eaca025d200160ba035a8abf66) --- src/scvi/_settings.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/scvi/_settings.py b/src/scvi/_settings.py index 2644737f55..37f2dcff7b 100644 --- a/src/scvi/_settings.py +++ b/src/scvi/_settings.py @@ -136,6 +136,11 @@ def seed(self, seed: Union[int, None] = None): else: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False + # Ensure deterministic CUDA operations for Jax (see https://github.com/google/jax/issues/13672) + if "XLA_FLAGS" not in os.environ: + os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" + else: + os.environ["XLA_FLAGS"] += " --xla_gpu_deterministic_ops=true" seed_everything(seed) self._seed = seed