From dd0df00dc0b6122dcf15554a485b415c8289a73a Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Wed, 7 Aug 2024 06:02:21 +0300 Subject: [PATCH] persistent_workers SCVI (#2924) Exposed persistent_workers (a PyTorch DataLoader parameter that is relevant when num_workers > 0) through the scvi settings, the same way we did for num_workers and the same way it was done in CZI --------- Co-authored-by: Can Ergen --- CHANGELOG.md | 2 ++ src/scvi/_settings.py | 12 ++++++++++++ src/scvi/dataloaders/_ann_dataloader.py | 2 ++ src/scvi/model/base/_base_model.py | 2 ++ tests/model/test_scvi.py | 15 +++++++++++++++ 5 files changed, 33 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76069f7192..c06b954669 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ to [Semantic Versioning]. Full commit history is available in the #### Added +- {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in + {class}`scvi.dataloaders.AnnDataLoader` {pr}`2924`. - Add option for using external indexes in data splitting classes that are under `scvi.dataloaders` by passing `external_indexing=list[train_idx,valid_idx,test_idx]` as well as in all models available {pr}`2902`. 
diff --git a/src/scvi/_settings.py b/src/scvi/_settings.py index 2644737f55..bc989b4db0 100644 --- a/src/scvi/_settings.py +++ b/src/scvi/_settings.py @@ -50,6 +50,7 @@ def __init__( seed: Optional[int] = None, logging_dir: str = "./scvi_log/", dl_num_workers: int = 0, + dl_persistent_workers: bool = False, jax_preallocate_gpu_memory: bool = False, warnings_stacklevel: int = 2, ): @@ -61,6 +62,7 @@ def __init__( self.progress_bar_style = progress_bar_style self.logging_dir = logging_dir self.dl_num_workers = dl_num_workers + self.dl_persistent_workers = dl_persistent_workers self._num_threads = None self.jax_preallocate_gpu_memory = jax_preallocate_gpu_memory self.verbosity = verbosity @@ -93,6 +95,16 @@ def dl_num_workers(self, dl_num_workers: int): """Number of workers for PyTorch data loaders (Default is 0).""" self._dl_num_workers = dl_num_workers + @property + def dl_persistent_workers(self) -> bool: + """Whether to use persistent_workers in PyTorch data loaders (Default is False).""" + return self._dl_persistent_workers + + @dl_persistent_workers.setter + def dl_persistent_workers(self, dl_persistent_workers: bool): + """Whether to use persistent_workers in PyTorch data loaders (Default is False).""" + self._dl_persistent_workers = dl_persistent_workers + @property def logging_dir(self) -> Path: """Directory for training logs (default `'./scvi_log/'`).""" diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 70fb2045cd..d10f8156de 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -101,6 +101,8 @@ def __init__( ) if "num_workers" not in kwargs: kwargs["num_workers"] = settings.dl_num_workers + if "persistent_workers" not in kwargs: + kwargs["persistent_workers"] = settings.dl_persistent_workers self.kwargs = copy.deepcopy(kwargs) diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index e991da3eb3..09d597ef05 100644 --- 
a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -434,6 +434,8 @@ def _make_data_loader( if "num_workers" not in data_loader_kwargs: data_loader_kwargs.update({"num_workers": settings.dl_num_workers}) + if "persistent_workers" not in data_loader_kwargs: + data_loader_kwargs.update({"persistent_workers": settings.dl_persistent_workers}) dl = data_loader_class( adata_manager, diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index b2cc1a6791..537e2b9e5e 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -1093,3 +1093,18 @@ def test_scvi_normal_likelihood(): model.get_reconstruction_error() model.get_normalized_expression(transform_batch="batch_1") model.get_normalized_expression(n_samples=2) + + +def test_scvi_num_workers(): + adata = synthetic_iid() + scvi.settings.dl_num_workers = 7 + scvi.settings.dl_persistent_workers = True + SCVI.setup_anndata(adata, batch_key="batch") + + model = SCVI(adata) + model.train(max_epochs=1, accelerator="cpu") + model.get_elbo() + model.get_marginal_ll(n_mc_samples=3) + model.get_reconstruction_error() + model.get_normalized_expression(transform_batch="batch_1") + model.get_normalized_expression(n_samples=2)