fix(train): in case of last batch <=2, move to validation if possible (#3036)

If train_size is None and the size of the last batch during
training is <= 2, we adaptively move those samples from the training set to the
validation set when possible. If train_size is set by the user, we do not apply
this fix; instead, the user can change train_size, change the selected indices,
or use the drop-last-batch option.
Closes #3035
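
In rough terms, the new behavior amounts to the following (a minimal standalone sketch of the split arithmetic under the defaults in this diff, not the library code itself):

```python
from math import ceil

def sketch_adaptive_split(n_samples: int, batch_size: int, train_size: float | None = None):
    """Sketch: if train_size is None and the last training minibatch would
    hold 1-2 samples, hand those samples to the validation set instead."""
    frac = 0.9 if train_size is None else train_size
    n_train = ceil(frac * n_samples)
    n_val = n_samples - n_train
    leftover = n_train % batch_size
    if train_size is None and 0 < leftover < 3 and n_val > 0:
        n_train -= leftover  # shrink training by the tiny last batch
        n_val += leftover    # and grow validation by the same amount
    return n_train, n_val

# ceil(0.9 * 286) = 258 and 258 % 128 == 2, so two cells move to validation:
print(sketch_adaptive_split(286, 128))  # (256, 30)
```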

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ori-kron-wis and pre-commit-ci[bot] authored Nov 19, 2024
1 parent 54ba452 commit b08e5df
Showing 18 changed files with 141 additions and 41 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,11 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Added adaptive handling for a last training minibatch of 1-2 cells when
  `datasplitter_kwargs={"drop_last": False}` and `train_size = None`, by moving those cells into
  the validation set, if one is available. {pr}`3036`.
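
For context, a hypothetical way to opt out of the adaptive move when training a scvi-tools model (assuming a model object named `model`):

```python
# Pin train_size explicitly (the adaptive move only fires when it is None)...
model.train(train_size=0.9)
# ...or drop the last non-full minibatch instead of moving it to validation:
model.train(datasplitter_kwargs={"drop_last": True})
```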

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
114 changes: 92 additions & 22 deletions src/scvi/dataloaders/_data_splitting.py
@@ -21,7 +21,14 @@
from scvi.utils._docstrings import devices_dsp


def validate_data_split(n_samples: int, train_size: float, validation_size: float | None = None):
def validate_data_split(
n_samples: int,
train_size: float,
validation_size: float | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
train_size_is_none: bool | int = True,
):
"""Check data splitting parameters and return n_train and n_val.
Parameters
@@ -32,21 +39,18 @@ def validate_data_split(n_samples: int, train_size: float, validation_size: floa
Size of train set. Needs to be: 0 < train_size <= 1.
validation_size
Size of validation set. Needs to be: 0 <= validation_size < 1.
batch_size
Batch size of each iteration. If `None`, do not minibatch.
drop_last
Whether to drop the last non-full batch.
train_size_is_none
Whether the user did not explicitly set train_size.
"""
if train_size > 1.0 or train_size <= 0.0:
raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1")

n_train = ceil(train_size * n_samples)

if n_train % settings.batch_size < 3 and n_train % settings.batch_size > 0:
warnings.warn(
f"Last batch will have a small size of {n_train % settings.batch_size}"
f"samples. Consider changing settings.batch_size or batch_size in model.train"
f"currently {settings.batch_size} to avoid errors during model training.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

if validation_size is None:
n_val = n_samples - n_train
elif validation_size >= 1.0 or validation_size < 0.0:
@@ -59,16 +63,40 @@ def validate_data_split_with_external_indexing(
if n_train == 0:
raise ValueError(
f"With n_samples={n_samples}, train_size={train_size} and "
f"validation_size={validation_size}, the resulting train set will be empty. Adjust"
f"validation_size={validation_size}, the resulting train set will be empty. Adjust "
"any of the aforementioned parameters."
)

if batch_size is not None:
num_of_cells = n_train % batch_size
if (num_of_cells < 3 and num_of_cells > 0) and drop_last is False:
if not train_size_is_none:
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {batch_size} to avoid errors during model training.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
else:
n_train -= num_of_cells
if n_val > 0:
n_val += num_of_cells
warnings.warn(
f"{num_of_cells} cells moved from training set to validation set."
f" if you want to avoid it please use train_size parameter during train.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val
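
As a concrete walk-through of the added branch (illustrative numbers; the call mirrors how `DataSplitter` invokes this function below):

```python
# 286 cells: n_train = ceil(0.9 * 286) = 258, n_val = 28. Since
# 258 % 128 == 2 (< 3), drop_last=False and train_size_is_none=True,
# the two trailing cells are moved and a UserWarning is emitted.
n_train, n_val = validate_data_split(
    n_samples=286,
    train_size=0.9,
    validation_size=None,
    batch_size=128,
    drop_last=False,
    train_size_is_none=True,
)
assert (n_train, n_val) == (256, 30)
```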


def validate_data_split_with_external_indexing(
n_samples: int,
external_indexing: list[np.array, np.array, np.array] | None = None,
batch_size: int | None = None,
drop_last: bool | int = False,
):
"""Check data splitting parameters and return n_train and n_val.
@@ -79,6 +107,10 @@ def validate_data_split_with_external_indexing(
external_indexing
A list of data split indices in the order of training, validation, and test sets.
Validation and test set are not required and can be left empty.
batch_size
Batch size of each iteration. If `None`, do not minibatch.
drop_last
Whether to drop the last non-full batch.
"""
if not isinstance(external_indexing, list):
raise ValueError("External indexing is not of list type")
@@ -132,6 +164,18 @@ def validate_data_split_with_external_indexing(
n_train = len(external_indexing[0])
n_val = len(external_indexing[1])

if batch_size is not None:
num_of_cells = n_train % batch_size
if (num_of_cells < 3 and num_of_cells > 0) and drop_last is False:
warnings.warn(
f"Last batch will have a small size of {num_of_cells} "
f"samples. Consider changing settings.batch_size or batch_size in model.train "
f"from currently {settings.batch_size} to avoid errors during model training "
f"or change the given external indices accordingly.",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)

return n_train, n_val
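
By contrast, with user-provided indices nothing is moved; a 1-2 cell last batch only triggers the warning. A hedged sketch (indices are illustrative and assumed to pass the function's validation checks):

```python
import numpy as np

train_idx = np.arange(258)          # 258 % 128 == 2 -> small-last-batch warning
val_idx = np.arange(258, 286)
test_idx = np.arange(286, 300)
n_train, n_val = validate_data_split_with_external_indexing(
    n_samples=300,
    external_indexing=[train_idx, val_idx, test_idx],
    batch_size=128,
    drop_last=False,
)
# The user-provided split is respected as-is: (258, 28).
```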


@@ -145,7 +189,8 @@ class DataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which in practice is 0.9, potentially moving a small last
batch into the validation cells)
validation_size
float, or None (default is None)
shuffle_set_split
@@ -182,7 +227,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
@@ -192,7 +237,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_is_none = not bool(train_size)
self.train_size = 0.9 if self.train_size_is_none else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.load_sparse_tensor = load_sparse_tensor
@@ -205,10 +251,17 @@ def __init__(
self.n_train, self.n_val = validate_data_split_with_external_indexing(
self.adata_manager.adata.n_obs,
self.external_indexing,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
self.n_train, self.n_val = validate_data_split(
self.adata_manager.adata.n_obs, self.train_size, self.validation_size
self.adata_manager.adata.n_obs,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
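
A minimal usage sketch of the updated splitter (assumes `adata_manager` comes from a model's `setup_anndata`; extra keyword arguments such as `batch_size` land in `data_loader_kwargs` and feed the check above):

```python
from scvi.dataloaders import DataSplitter

# train_size=None (the new default) enables the adaptive move.
splitter = DataSplitter(adata_manager, train_size=None, batch_size=128)
splitter.setup()
print(splitter.n_train, splitter.n_val)
```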

def setup(self, stage: str | None = None):
@@ -298,7 +351,8 @@ class SemiSupervisedDataSplitter(pl.LightningDataModule):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which in practice is 0.9, potentially moving a small last
batch into the validation cells)
validation_size
float, or None (default is None)
shuffle_set_split
@@ -333,7 +387,7 @@ class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`,
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
n_samples_per_label: int | None = None,
@@ -343,7 +397,8 @@ def __init__(
):
super().__init__()
self.adata_manager = adata_manager
self.train_size = float(train_size)
self.train_size_is_none = not bool(train_size)
self.train_size = 0.9 if self.train_size_is_none else float(train_size)
self.validation_size = validation_size
self.shuffle_set_split = shuffle_set_split
self.drop_last = kwargs.pop("drop_last", False)
@@ -379,10 +434,17 @@ def setup(self, stage: str | None = None):
n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing(
n_labeled_idx,
[labeled_idx_train, labeled_idx_val, labeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_labeled_train, n_labeled_val = validate_data_split(
n_labeled_idx, self.train_size, self.validation_size
n_labeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

labeled_permutation = self._labeled_indices
@@ -413,10 +475,17 @@ def setup(self, stage: str | None = None):
n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing(
n_unlabeled_idx,
[unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
else:
n_unlabeled_train, n_unlabeled_val = validate_data_split(
n_unlabeled_idx, self.train_size, self.validation_size
n_unlabeled_idx,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)

unlabeled_permutation = self._unlabeled_indices
@@ -508,7 +577,8 @@ class DeviceBackedDataSplitter(DataSplitter):
adata_manager
:class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``.
train_size
float, or None (default is 0.9)
float, or None (default is None, which in practice is 0.9, potentially moving a small last
batch into the validation cells)
validation_size
float, or None (default is None)
%(param_accelerator)s
@@ -536,7 +606,7 @@ def __init__(
def __init__(
self,
adata_manager: AnnDataManager,
train_size: float = 1.0,
train_size: float | None = None,
validation_size: float | None = None,
accelerator: str = "auto",
device: int | str = "auto",
2 changes: 1 addition & 1 deletion src/scvi/external/cellassign/_model.py
@@ -143,7 +143,7 @@ def train(
lr: float = 3e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 1024,
20 changes: 17 additions & 3 deletions src/scvi/external/contrastivevi/_contrastive_data_splitting.py
@@ -53,7 +53,7 @@ def __init__(
adata_manager: AnnDataManager,
background_indices: list[int],
target_indices: list[int],
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
@@ -78,10 +78,20 @@ def __init__(
self.n_target = len(target_indices)
if external_indexing is None:
self.n_background_train, self.n_background_val = validate_data_split(
self.n_background, self.train_size, self.validation_size
self.n_background,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
self.n_target_train, self.n_target_val = validate_data_split(
self.n_target, self.train_size, self.validation_size
self.n_target,
self.train_size,
self.validation_size,
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
self.train_size_is_none,
)
else:
# we need to intersect the external indexing given with the bg/target indices
@@ -93,6 +103,8 @@ def __init__(
validate_data_split_with_external_indexing(
self.n_background,
[self.background_train_idx, self.background_val_idx, self.background_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
)
self.background_train_idx, self.background_val_idx, self.background_test_idx = (
@@ -107,6 +119,8 @@ def __init__(
self.n_target_train, self.n_target_val = validate_data_split_with_external_indexing(
self.n_target,
[self.target_train_idx, self.target_val_idx, self.target_test_idx],
self.data_loader_kwargs.pop("batch_size", settings.batch_size),
self.drop_last,
)
self.target_train_idx, self.target_val_idx, self.target_test_idx = (
self.target_train_idx.tolist(),
2 changes: 1 addition & 1 deletion src/scvi/external/contrastivevi/_model.py
@@ -136,7 +136,7 @@ def train(
max_epochs: int | None = None,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
load_sparse_tensor: bool = False,
2 changes: 1 addition & 1 deletion src/scvi/external/gimvi/_model.py
@@ -172,7 +172,7 @@ def train(
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
kappa: int = 5,
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
2 changes: 1 addition & 1 deletion src/scvi/external/mrvi/_model.py
@@ -201,7 +201,7 @@ def train(
max_epochs: int | None = None,
accelerator: str | None = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 128,
early_stopping: bool = False,
2 changes: 1 addition & 1 deletion src/scvi/external/scbasset/_model.py
@@ -106,7 +106,7 @@ def train(
lr: float = 0.01,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
2 changes: 1 addition & 1 deletion src/scvi/external/solo/_model.py
@@ -293,7 +293,7 @@ def train(
lr: float = 1e-3,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
2 changes: 1 addition & 1 deletion src/scvi/external/velovi/_model.py
@@ -129,7 +129,7 @@ def train(
weight_decay: float = 1e-2,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
batch_size: int = 256,
early_stopping: bool = True,
2 changes: 1 addition & 1 deletion src/scvi/model/_multivi.py
@@ -231,7 +231,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,
2 changes: 1 addition & 1 deletion src/scvi/model/_peakvi.py
@@ -151,7 +151,7 @@ def train(
lr: float = 1e-4,
accelerator: str = "auto",
devices: int | list[int] | str = "auto",
train_size: float = 0.9,
train_size: float | None = None,
validation_size: float | None = None,
shuffle_set_split: bool = True,
batch_size: int = 128,