From d105f9ff6614590244dd3735fe92b7587eac7151 Mon Sep 17 00:00:00 2001
From: Justin Hong <justin.hong@columbia.edu>
Date: Tue, 6 Aug 2024 23:04:15 -0400
Subject: [PATCH] fix[train] flag to make jax deterministic if seed is manually
 set (#2923)

Addresses #2911

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Can Ergen <canergen.ac@gmail.com>
(cherry picked from commit 47229520fd6d93eaca025d200160ba035a8abf66)
---
 src/scvi/_settings.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/scvi/_settings.py b/src/scvi/_settings.py
index 2644737f55..37f2dcff7b 100644
--- a/src/scvi/_settings.py
+++ b/src/scvi/_settings.py
@@ -136,6 +136,11 @@ def seed(self, seed: Union[int, None] = None):
         else:
             torch.backends.cudnn.deterministic = True
             torch.backends.cudnn.benchmark = False
+            # Ensure deterministic CUDA operations for Jax (see https://github.com/google/jax/issues/13672)
+            if "XLA_FLAGS" not in os.environ:
+                os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
+            else:
+                os.environ["XLA_FLAGS"] += " --xla_gpu_deterministic_ops=true"
             seed_everything(seed)
             self._seed = seed