Skip to content

Commit

Permalink
persistent_workers SCVI (#2924)
Browse files Browse the repository at this point in the history
Exposed persistent_workers (a PyTorch DataLoader parameter, applicable when
num_workers > 0) as an scvi settings option, mirroring the existing
num_workers setting and the approach previously taken in CZI.

---------

Co-authored-by: Can Ergen <[email protected]>
  • Loading branch information
ori-kron-wis and canergen authored Aug 7, 2024
1 parent 7effcfb commit dd0df00
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in
{class}`scvi.dataloaders.AnnDataLoader` {pr}`2924`.
- Add option for using external indexes in data splitting classes that are under `scvi.dataloaders`
by passing `external_indexing=list[train_idx,valid_idx,test_idx]` as well as in all models
available {pr}`2902`.
Expand Down
12 changes: 12 additions & 0 deletions src/scvi/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
seed: Optional[int] = None,
logging_dir: str = "./scvi_log/",
dl_num_workers: int = 0,
dl_persistent_workers: bool = False,
jax_preallocate_gpu_memory: bool = False,
warnings_stacklevel: int = 2,
):
Expand All @@ -61,6 +62,7 @@ def __init__(
self.progress_bar_style = progress_bar_style
self.logging_dir = logging_dir
self.dl_num_workers = dl_num_workers
self.dl_persistent_workers = dl_persistent_workers
self._num_threads = None
self.jax_preallocate_gpu_memory = jax_preallocate_gpu_memory
self.verbosity = verbosity
Expand Down Expand Up @@ -93,6 +95,16 @@ def dl_num_workers(self, dl_num_workers: int):
"""Number of workers for PyTorch data loaders (Default is 0)."""
self._dl_num_workers = dl_num_workers

@property
def dl_persistent_workers(self) -> bool:
    """Whether PyTorch data loaders keep their worker processes alive between epochs (default: False)."""
    return self._dl_persistent_workers

@dl_persistent_workers.setter
def dl_persistent_workers(self, value: bool):
    """Enable or disable persistent_workers for PyTorch data loaders."""
    self._dl_persistent_workers = value

@property
def logging_dir(self) -> Path:
"""Directory for training logs (default `'./scvi_log/'`)."""
Expand Down
2 changes: 2 additions & 0 deletions src/scvi/dataloaders/_ann_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def __init__(
)
if "num_workers" not in kwargs:
kwargs["num_workers"] = settings.dl_num_workers
if "persistent_workers" not in kwargs:
kwargs["persistent_workers"] = settings.dl_persistent_workers

self.kwargs = copy.deepcopy(kwargs)

Expand Down
2 changes: 2 additions & 0 deletions src/scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ def _make_data_loader(

if "num_workers" not in data_loader_kwargs:
data_loader_kwargs.update({"num_workers": settings.dl_num_workers})
if "persistent_workers" not in data_loader_kwargs:
data_loader_kwargs.update({"persistent_workers": settings.dl_persistent_workers})

dl = data_loader_class(
adata_manager,
Expand Down
15 changes: 15 additions & 0 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,3 +1093,18 @@ def test_scvi_normal_likelihood():
model.get_reconstruction_error()
model.get_normalized_expression(transform_batch="batch_1")
model.get_normalized_expression(n_samples=2)


def test_scvi_num_workers():
    """Train and run inference with non-default dataloader worker settings.

    Exercises ``scvi.settings.dl_num_workers`` and
    ``scvi.settings.dl_persistent_workers`` end to end through training,
    likelihood evaluation, and normalized-expression inference.
    """
    adata = synthetic_iid()
    # Snapshot the global settings so this test does not leak its
    # configuration into other tests in the same session.
    prev_num_workers = scvi.settings.dl_num_workers
    prev_persistent_workers = scvi.settings.dl_persistent_workers
    scvi.settings.dl_num_workers = 7
    scvi.settings.dl_persistent_workers = True
    try:
        SCVI.setup_anndata(adata, batch_key="batch")

        model = SCVI(adata)
        model.train(max_epochs=1, accelerator="cpu")
        model.get_elbo()
        model.get_marginal_ll(n_mc_samples=3)
        model.get_reconstruction_error()
        model.get_normalized_expression(transform_batch="batch_1")
        model.get_normalized_expression(n_samples=2)
    finally:
        # Restore the previous global settings even if the test fails.
        scvi.settings.dl_num_workers = prev_num_workers
        scvi.settings.dl_persistent_workers = prev_persistent_workers

0 comments on commit dd0df00

Please sign in to comment.