Skip to content

Commit

Permalink
Set PYTORCH_MPS_HIGH_WATERMARK_RATIO as float
Browse files Browse the repository at this point in the history
  • Loading branch information
cristid9 committed Dec 4, 2024
1 parent ebfd0d8 commit 63683d4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sdgym/synthesizers/realtabformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_trained_synthesizer(self, data, metadata):

with prevent_tqdm_output():
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = 0.0
model = REaLTabFormer(model_type='tabular')
model.fit(data, device='cpu')
LOGGER.debug('PYTORCH_ENABLE_MPS_FALLBACK')
Expand All @@ -51,7 +51,7 @@ def _get_trained_synthesizer(self, data, metadata):

def _sample_from_synthesizer(self, synthesizer, n_sample):
"""Sample synthetic data with specified sample count."""
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = 0.0
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
LOGGER.debug('PYTORCH_ENABLE_MPS_FALLBACK')
LOGGER.debug(os.getenv('PYTORCH_ENABLE_MPS_FALLBACK'))
Expand Down

0 comments on commit 63683d4

Please sign in to comment.