diff --git a/sdgym/synthesizers/realtabformer.py b/sdgym/synthesizers/realtabformer.py index cf4e25e2..4b14e469 100644 --- a/sdgym/synthesizers/realtabformer.py +++ b/sdgym/synthesizers/realtabformer.py @@ -37,7 +37,7 @@ def _get_trained_synthesizer(self, data, metadata): with prevent_tqdm_output(): os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" + os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = 0.0 model = REaLTabFormer(model_type='tabular') model.fit(data, device='cpu') LOGGER.debug('PYTORCH_ENABLE_MPS_FALLBACK') @@ -51,7 +51,7 @@ def _get_trained_synthesizer(self, data, metadata): def _sample_from_synthesizer(self, synthesizer, n_sample): """Sample synthetic data with specified sample count.""" - os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" + os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = 0.0 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" LOGGER.debug('PYTORCH_ENABLE_MPS_FALLBACK') LOGGER.debug(os.getenv('PYTORCH_ENABLE_MPS_FALLBACK'))