Skip to content

Commit

Permalink
bug: Batch Size is not utilised properly with the use of external ind…
Browse files Browse the repository at this point in the history
…ices following last batch fix. (#3128)

close #3123

continuing: #3036
  • Loading branch information
ori-kron-wis authored Jan 8, 2025
1 parent a6deb6b commit a435561
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ to [Semantic Versioning]. Full commit history is available in the

#### Fixed

- Fixed bug in distributed `scvi.dataloaders._concat_dataloader` {pr}`3053`.
- Fixed bug in distributed {class}`scvi.dataloaders.ConcatDataLoader` {pr}`3053`.

#### Changed

Expand All @@ -37,6 +37,8 @@ to [Semantic Versioning]. Full commit history is available in the

#### Fixed

- Fixed batch_size pop to get in {class}`scvi.dataloaders.DataSplitter` {pr}`3128`.

#### Changed

- Updated the CI workflow with internet, private and optional tests {pr}`3082`.
Expand Down
12 changes: 6 additions & 6 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,15 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
Expand Down Expand Up @@ -434,15 +434,15 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
Expand Down Expand Up @@ -475,15 +475,15 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
Expand Down
7 changes: 2 additions & 5 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,7 @@ def test_scvi_n_obs_error(n_latent: int = 5):
with pytest.raises(ValueError):
model.train(1, train_size=1.0)
with pytest.raises(ValueError):
# Warning is emitted if last batch less than 3 cells + failure.
model.train(1, train_size=1.0, batch_size=127)
model.train(1, train_size=1.0, batch_size=128)
model.train(1, train_size=1.0, datasplitter_kwargs={"drop_last": True})

adata = synthetic_iid()
Expand All @@ -484,9 +483,7 @@ def test_scvi_n_obs_error(n_latent: int = 5):
model = SCVI(adata, n_latent=n_latent)
with pytest.raises(ValueError):
model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(
1, train_size=0.9, datasplitter_kwargs={"drop_last": True}
) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True})
model.train(1)
assert model.is_trained is True

Expand Down

0 comments on commit a435561

Please sign in to comment.