diff --git a/CHANGELOG.md b/CHANGELOG.md
index 74c4c454..d882703b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@
 
 * Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
 * Fixes PDB download failure due to missing protein data [#77](https://github.com/a-r-j/ProteinWorkshop/pull/77)
+* Adds support for gracefully handling training/validation OOMs [#81](https://github.com/a-r-j/ProteinWorkshop/pull/81)
 
 ### Framework
 
diff --git a/proteinworkshop/models/base.py b/proteinworkshop/models/base.py
index 3c0149d0..b45cb9dd 100644
--- a/proteinworkshop/models/base.py
+++ b/proteinworkshop/models/base.py
@@ -4,6 +4,7 @@
 import hydra
 import lightning as L
 import torch
+import torch.distributed as torch_dist
 import torch.nn as nn
 import torch.nn.functional as F
 from beartype import beartype as typechecker
@@ -598,9 +599,89 @@ def _do_step(
         self.log_metrics(loss, y_hat, y, stage, batch=batch)
         return loss["total"]
 
+    @typechecker
+    def _do_step_catch_oom(
+        self,
+        batch: Batch,
+        batch_idx: int,
+        stage: Literal["train", "val"],
+    ) -> Optional[torch.Tensor]:
+        """Performs a training/validation step
+        while catching out-of-memory errors.
+        Note that this should not be used for
+        test steps, to ensure proper benchmarking.
+
+        1. Obtains labels from :py:meth:`get_labels`
+        2. Computes model output :py:meth:`forward`
+        3. Computes loss :py:meth:`compute_loss`
+        4. Logs metrics :py:meth:`log_metrics`
+
+        Returns the total loss, or ``None`` if the batch was skipped.
+
+        :param batch: Mini-batch of data.
+        :type batch: Batch
+        :param batch_idx: Index of batch.
+        :type batch_idx: int
+        :param stage: Stage of training (``"train"``, ``"val"``)
+        :type stage: Literal["train", "val"]
+        :return: Loss, or ``None`` if the batch was skipped
+        :rtype: Optional[torch.Tensor]
+        """
+        # by default, do not skip the current batch
+        skip_flag = torch.zeros(
+            (), device=self.device, dtype=torch.bool
+        )  # NOTE: for skipping batches in a multi-device setting
+
+        try:
+            y = self.get_labels(batch)
+            y_hat = self(batch)
+
+            loss = self.compute_loss(y_hat, y)
+            self.log_metrics(loss, y_hat, y, stage, batch=batch)
+
+        except Exception as e:
+            skip_flag = torch.ones((), device=self.device, dtype=torch.bool)
+
+            if "out of memory" in str(e):
+                logger.warning(
+                    f"Ran out of memory in the forward pass. Skipping current {stage} batch with index {batch_idx}."
+                )
+                if not torch_dist.is_initialized():
+                    # NOTE: for skipping batches in a single-device setting
+                    if self.training:
+                        for p in self.trainer.model.parameters():
+                            if p.grad is not None:
+                                del p.grad  # free some memory
+                    return None
+            else:
+                if not torch_dist.is_initialized():
+                    raise e
+
+        # NOTE: for skipping batches in a multi-device setting
+        # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
+        if torch_dist.is_initialized():
+            # if any rank skips a batch, then all other ranks need to skip
+            # their batches as well so DDP can properly keep all ranks synced
+            world_size = torch_dist.get_world_size()
+            torch_dist.barrier()
+            result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
+            torch_dist.all_gather(result, skip_flag)
+            any_skipped = torch.sum(torch.stack(result)).bool().item()
+            if any_skipped:
+                if self.training:
+                    for p in self.trainer.model.parameters():
+                        if p.grad is not None:
+                            del p.grad
+                logger.warning(
+                    f"Failed to perform the forward pass for at least one rank. Skipping {stage} batches for all ranks."
+                )
+                return None
+
+        return loss["total"]
+
     def training_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         """
         Perform training step.
 
@@ -616,13 +697,13 @@ def training_step(
         :param batch_idx: Index of batch.
         :type batch_idx: int
         :return: Loss
-        :rtype: torch.Tensor
+        :rtype: Optional[torch.Tensor]
         """
-        return self._do_step(batch, batch_idx, "train")
+        return self._do_step_catch_oom(batch, batch_idx, "train")
 
     def validation_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         """
         Perform validation step.
 
@@ -638,9 +719,9 @@ def validation_step(
         :param batch_idx: Index of batch.
         :type batch_idx: int
         :return: Loss
-        :rtype: torch.Tensor
+        :rtype: Optional[torch.Tensor]
         """
-        return self._do_step(batch, batch_idx, "val")
+        return self._do_step_catch_oom(batch, batch_idx, "val")
 
     def test_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
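For reference, the core pattern `_do_step_catch_oom` relies on can be illustrated in isolation. The sketch below is a minimal, hypothetical standalone version, not the PR's exact code: `OOMTolerantModule` is an assumed class name, and `get_labels` / `compute_loss` are assumed to be defined elsewhere on the module, as they are in `proteinworkshop/models/base.py`. It shows how a per-rank boolean skip flag, gathered with `torch.distributed.all_gather`, keeps all DDP ranks in lockstep when any one of them hits an OOM:

```python
from typing import Literal, Optional

import lightning as L
import torch
import torch.distributed as dist


class OOMTolerantModule(L.LightningModule):
    """Minimal sketch; get_labels/compute_loss are assumed stand-ins."""

    def _step_catch_oom(
        self, batch, batch_idx: int, stage: Literal["train", "val"]
    ) -> Optional[torch.Tensor]:
        # 0 -> this rank succeeded, 1 -> this rank must skip the batch
        skip_flag = torch.zeros((), device=self.device, dtype=torch.bool)
        loss = None
        try:
            y = self.get_labels(batch)       # assumed helper
            loss = self.compute_loss(self(batch), y)  # assumed helper
        except RuntimeError as e:
            # Only swallow the error if it is an OOM, or if we are in DDP
            # (where re-raising on one rank would desync the others).
            if "out of memory" not in str(e) and not dist.is_initialized():
                raise
            skip_flag = torch.ones((), device=self.device, dtype=torch.bool)

        if dist.is_initialized():
            # Every rank reaches this collective on every step, so the
            # gather cannot deadlock; if any rank failed, all ranks skip.
            flags = [
                torch.zeros_like(skip_flag)
                for _ in range(dist.get_world_size())
            ]
            dist.all_gather(flags, skip_flag)
            skip = bool(torch.stack(flags).any())
        else:
            skip = bool(skip_flag)

        if skip:
            if self.training:
                for p in self.parameters():
                    p.grad = None  # release gradient memory before skipping
            return None  # Lightning skips the optimizer step on None

        return loss
```

One design note: because `all_gather` itself blocks until every rank participates, the explicit `torch_dist.barrier()` the PR places before it is arguably redundant, so the sketch omits it for brevity.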