From a435561ebab92410b8cb82e7a18dde5510222387 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Wed, 8 Jan 2025 09:40:11 +0200 Subject: [PATCH] bug: Batch size is not utilised properly when external indices are used, following the last-batch fix. (#3128) Closes https://github.com/scverse/scvi-tools/issues/3123 Continues https://github.com/scverse/scvi-tools/pull/3036 --- CHANGELOG.md | 4 +++- src/scvi/dataloaders/_data_splitting.py | 12 ++++++------ tests/model/test_scvi.py | 7 ++----- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e362b329f5..9e914cc8ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Fixed -- Fixed bug in distributed `scvi.dataloaders._concat_dataloader` {pr}`3053`. +- Fixed bug in distributed {class}`scvi.dataloaders.ConcatDataLoader` {pr}`3053`. #### Changed @@ -37,6 +37,8 @@ to [Semantic Versioning]. Full commit history is available in the #### Fixed +- Fixed `batch_size` handling in {class}`scvi.dataloaders.DataSplitter` by reading it with `get` instead of `pop`, so it is no longer removed from the data-loader kwargs during split validation {pr}`3128`. + #### Changed - Updated the CI workflow with internet, private and optional tests {pr}`3082`. 
diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index b4b54d8c7f..9ea0146acb 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -251,7 +251,7 @@ def __init__( self.n_train, self.n_val = validate_data_split_with_external_indexing( self.adata_manager.adata.n_obs, self.external_indexing, - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.data_loader_kwargs.get("batch_size", settings.batch_size), self.drop_last, ) else: @@ -259,7 +259,7 @@ def __init__( self.adata_manager.adata.n_obs, self.train_size, self.validation_size, - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.data_loader_kwargs.get("batch_size", settings.batch_size), self.drop_last, self.train_size_is_none, ) @@ -434,7 +434,7 @@ def setup(self, stage: str | None = None): n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing( n_labeled_idx, [labeled_idx_train, labeled_idx_val, labeled_idx_test], - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.data_loader_kwargs.get("batch_size", settings.batch_size), self.drop_last, ) else: @@ -442,7 +442,7 @@ def setup(self, stage: str | None = None): n_labeled_idx, self.train_size, self.validation_size, - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.data_loader_kwargs.get("batch_size", settings.batch_size), self.drop_last, self.train_size_is_none, ) @@ -475,7 +475,7 @@ def setup(self, stage: str | None = None): n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing( n_unlabeled_idx, [unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test], - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.data_loader_kwargs.get("batch_size", settings.batch_size), self.drop_last, ) else: @@ -483,7 +483,7 @@ def setup(self, stage: str | None = None): n_unlabeled_idx, self.train_size, self.validation_size, - 
self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.data_loader_kwargs.get("batch_size", settings.batch_size), self.drop_last, self.train_size_is_none, ) diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py index 0f1e4c876a..b76a6f6b3d 100644 --- a/tests/model/test_scvi.py +++ b/tests/model/test_scvi.py @@ -474,8 +474,7 @@ def test_scvi_n_obs_error(n_latent: int = 5): with pytest.raises(ValueError): model.train(1, train_size=1.0) with pytest.raises(ValueError): - # Warning is emitted if last batch less than 3 cells + failure. - model.train(1, train_size=1.0, batch_size=127) + model.train(1, train_size=1.0, batch_size=128) model.train(1, train_size=1.0, datasplitter_kwargs={"drop_last": True}) adata = synthetic_iid() @@ -484,9 +483,7 @@ def test_scvi_n_obs_error(n_latent: int = 5): model = SCVI(adata, n_latent=n_latent) with pytest.raises(ValueError): model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1 - model.train( - 1, train_size=0.9, datasplitter_kwargs={"drop_last": True} - ) # np.ceil(n_cells * 0.9) % 128 == 1 + model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True}) model.train(1) assert model.is_trained is True