Skip to content

Commit

Permalink
fix[train] flag to make jax deterministic if seed is manually set (#2923
Browse files Browse the repository at this point in the history
)

Addresses #2911

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Can Ergen <[email protected]>
(cherry picked from commit 4722952)
  • Loading branch information
justjhong authored and canergen committed Aug 7, 2024
1 parent 45a7fdd commit d105f9f
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/scvi/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ def seed(self, seed: Union[int, None] = None):
else:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Ensure deterministic CUDA operations for Jax (see https://github.com/google/jax/issues/13672)
if "XLA_FLAGS" not in os.environ:
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
else:
os.environ["XLA_FLAGS"] += " --xla_gpu_deterministic_ops=true"
seed_everything(seed)
self._seed = seed

Expand Down

0 comments on commit d105f9f

Please sign in to comment.