Skip to content

Commit

Permalink
ci: Manual Backport of PR 3128 to 1.2.x (#3130)
Browse files Browse the repository at this point in the history
pop get bug fix data splitter backport 1.2.x

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: ori-kron-wis <[email protected]>
Co-authored-by: Lumberbot (aka Jack) <[email protected]>
Co-authored-by: Martin Kim <[email protected]>
Co-authored-by: Can Ergen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justin Hong <[email protected]>
Co-authored-by: Yishen Miao <[email protected]>
Co-authored-by: Ramon Viñas <[email protected]>
Co-authored-by: Martin Kim <[email protected]>
Co-authored-by: Ethan Weinberger <[email protected]>
Co-authored-by: Ethan Weinberger <[email protected]>
Co-authored-by: access <[email protected]>
  • Loading branch information
14 people authored Jan 8, 2025
1 parent 32b0b5c commit 935b56b
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ to [Semantic Versioning]. Full commit history is available in the

#### Fixed

- Fixed batch_size pop to get in {class}`scvi.dataloaders.DataSplitter` {pr}`3128`.

#### Changed

- Updated the CI workflow with internet, private and optional tests {pr}`3082`.
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = ["hatchling"]

[project]
name = "scvi-tools"
version = "1.2.2.post1"
version = "1.2.2.post2"
description = "Deep probabilistic analysis of single-cell omics data."
readme = "README.md"
requires-python = ">=3.10"
Expand Down
12 changes: 6 additions & 6 deletions src/scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,15 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
Expand Down Expand Up @@ -434,15 +434,15 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
Expand Down Expand Up @@ -475,15 +475,15 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.data_loader_kwargs.get("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
Expand Down
7 changes: 2 additions & 5 deletions tests/model/test_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,7 @@ def test_scvi_n_obs_error(n_latent: int = 5):
with pytest.raises(ValueError):
model.train(1, train_size=1.0)
with pytest.raises(ValueError):
# Warning is emitted if last batch less than 3 cells + failure.
model.train(1, train_size=1.0, batch_size=127)
model.train(1, train_size=1.0, batch_size=128)
model.train(1, train_size=1.0, datasplitter_kwargs={"drop_last": True})

adata = synthetic_iid()
Expand All @@ -484,9 +483,7 @@ def test_scvi_n_obs_error(n_latent: int = 5):
model = SCVI(adata, n_latent=n_latent)
with pytest.raises(ValueError):
model.train(1, train_size=0.9) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(
1, train_size=0.9, datasplitter_kwargs={"drop_last": True}
) # np.ceil(n_cells * 0.9) % 128 == 1
model.train(1, train_size=0.9, datasplitter_kwargs={"drop_last": True})
model.train(1)
assert model.is_trained is True

Expand Down

0 comments on commit 935b56b

Please sign in to comment.