diff --git a/CHANGELOG.md b/CHANGELOG.md
index 818b3bfd..d882703b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@
 
 * Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
 * Fixes PDB download failure due to missing protein data [#77](https://github.com/a-r-j/ProteinWorkshop/pull/77)
+* Adds support for handling training/validation OOMs gracefully [#81](https://github.com/a-r-j/ProteinWorkshop/pull/81)
 
 ### Framework
 
@@ -16,6 +17,8 @@
 
 ### Command
 * Adds `--force-cuda-version` to `workshop install` [#78](https://github.com/a-r-j/ProteinWorkshop/pull/78)
 
+### Features
+* Fixes `sequence_edges` behaviour when argument `b` is a `Data` object [#80](https://github.com/a-r-j/ProteinWorkshop/pull/80)
 
 ### 0.2.5 (28/12/2023)
diff --git a/README.md b/README.md
index d89ef051..8836ea20 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,8 @@
 
 [Documentation](https://www.proteins.sh)
 
-This [repository](https://github.com/a-r-j/ProteinWorkshop) provides the code for the protein structure representation learning benchmark detailed in the paper [*Evaluating Representation Learning on the Protein Structure Universe*](https://openreview.net/forum?id=sTYuRVrdK3).
+
+This [repository](https://github.com/a-r-j/ProteinWorkshop) provides the code for the protein structure representation learning benchmark detailed in the paper [*Evaluating Representation Learning on the Protein Structure Universe*](https://openreview.net/forum?id=sTYuRVrdK3) (ICLR 2024).
 
 In the benchmark, we implement numerous [featurisation](https://www.proteins.sh/configs/features) schemes, [datasets](https://www.proteins.sh/configs/dataset) for [self-supervised pre-training](https://proteins.sh/quickstart_component/pretrain.html) and [downstream evaluation](https://proteins.sh/quickstart_component/downstream.html), [pre-training](https://proteins.sh/configs/task) tasks, and [auxiliary tasks](https://proteins.sh/configs/task.html#auxiliary-tasks).
 
diff --git a/proteinworkshop/features/edges.py b/proteinworkshop/features/edges.py
index 4c888892..c375cbe2 100644
--- a/proteinworkshop/features/edges.py
+++ b/proteinworkshop/features/edges.py
@@ -104,7 +104,7 @@ def sequence_edges(
         idx_b = torch.arange(1, b.ptr[-1], device=b.ptr.device)
     elif isinstance(b, Data):
         idx_a = torch.arange(0, b.coords.shape[0] - 1, device=b.coords.device)
-        idx_a = torch.arange(1, b.coords.shape[0] - 1, device=b.coords.device)
+        idx_b = torch.arange(1, b.coords.shape[0], device=b.coords.device)
     # Concatenate indices to create edge list
     if direction == "forward":
         e_index = torch.stack([idx_a, idx_b], dim=0)
diff --git a/proteinworkshop/models/base.py b/proteinworkshop/models/base.py
index 3c0149d0..b45cb9dd 100644
--- a/proteinworkshop/models/base.py
+++ b/proteinworkshop/models/base.py
@@ -4,6 +4,7 @@
 import hydra
 import lightning as L
 import torch
+import torch.distributed as torch_dist
 import torch.nn as nn
 import torch.nn.functional as F
 from beartype import beartype as typechecker
@@ -598,9 +599,89 @@ def _do_step(
         self.log_metrics(loss, y_hat, y, stage, batch=batch)
         return loss["total"]
 
+    @typechecker
+    def _do_step_catch_oom(
+        self,
+        batch: Batch,
+        batch_idx: int,
+        stage: Literal["train", "val"],
+    ) -> Optional[torch.Tensor]:
+        """Performs a training/validation step while catching
+        out-of-memory errors. Note that this should not be used
+        for test steps, to ensure proper benchmarking.
+
+        1. Obtains labels from :py:meth:`get_labels`
+        2. Computes model output :py:meth:`forward`
+        3. Computes loss :py:meth:`compute_loss`
+        4. Logs metrics :py:meth:`log_metrics`
+
+        Returns the total loss, or ``None`` if the batch is skipped.
+
+        :param batch: Mini-batch of data.
+        :type batch: Batch
+        :param batch_idx: Index of batch.
+        :type batch_idx: int
+        :param stage: Stage of training (``"train"``, ``"val"``)
+        :type stage: Literal["train", "val"]
+        :return: Loss, or ``None`` if the batch is skipped.
+        :rtype: Optional[torch.Tensor]
+        """
+        # by default, do not skip the current batch
+        skip_flag = torch.zeros(
+            (), device=self.device, dtype=torch.bool
+        )  # NOTE: for skipping batches in a multi-device setting
+
+        try:
+            y = self.get_labels(batch)
+            y_hat = self(batch)
+
+            loss = self.compute_loss(y_hat, y)
+            self.log_metrics(loss, y_hat, y, stage, batch=batch)
+
+        except Exception as e:
+            skip_flag = torch.ones((), device=self.device, dtype=torch.bool)
+
+            if "out of memory" in str(e):
+                logger.warning(
+                    f"Ran out of memory in the forward pass. Skipping current {stage} batch with index {batch_idx}."
+                )
+                if not torch_dist.is_initialized():
+                    # NOTE: for skipping batches in a single-device setting
+                    if self.training:
+                        for p in self.trainer.model.parameters():
+                            if p.grad is not None:
+                                del p.grad  # free some memory
+                    return None
+            else:
+                if not torch_dist.is_initialized():
+                    raise e
+
+        # NOTE: for skipping batches in a multi-device setting
+        # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
+        if torch_dist.is_initialized():
+            # if any rank skips a batch, then all other ranks need to skip
+            # their batches as well so DDP can properly keep all ranks synced
+            world_size = torch_dist.get_world_size()
+            torch_dist.barrier()
+            result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
+            torch_dist.all_gather(result, skip_flag)
+            any_skipped = torch.sum(torch.stack(result)).bool().item()
+            if any_skipped:
+                if self.training:
+                    for p in self.trainer.model.parameters():
+                        if p.grad is not None:
+                            del p.grad
+                logger.warning(
+                    f"Failed to perform the forward pass for at least one rank. Skipping {stage} batches for all ranks."
+                )
+                return None
+
+        return loss["total"]
+
     def training_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         """
         Perform training step.
 
@@ -616,13 +697,13 @@
         :param batch_idx: Index of batch.
         :type batch_idx: int
         :return: Loss
-        :rtype: torch.Tensor
+        :rtype: Optional[torch.Tensor]
         """
-        return self._do_step(batch, batch_idx, "train")
+        return self._do_step_catch_oom(batch, batch_idx, "train")
 
     def validation_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         """
         Perform validation step.
 
@@ -638,9 +719,9 @@
         :param batch_idx: Index of batch.
         :type batch_idx: int
         :return: Loss
-        :rtype: torch.Tensor
+        :rtype: Optional[torch.Tensor]
        """
-        return self._do_step(batch, batch_idx, "val")
+        return self._do_step_catch_oom(batch, batch_idx, "val")
 
     def test_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
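
A note on the OOM-skipping pattern in `_do_step_catch_oom`: the key constraint is that under DDP every rank must make the same skip/continue decision, or the ranks that continued would hang in the gradient all-reduce. The sketch below isolates that pattern outside Lightning as a rough illustration only; `run_forward` is a hypothetical stand-in for the label/forward/loss computation, and a `torch.distributed` process group is assumed to be initialised in the multi-device case.

```python
import torch
import torch.distributed as torch_dist


def step_with_oom_skip(run_forward, batch, device):
    """Run one step; if any rank OOMs, every rank skips the batch.

    ``run_forward`` is a hypothetical callable returning a loss tensor.
    """
    skip_flag = torch.zeros((), device=device, dtype=torch.bool)
    loss = None
    try:
        loss = run_forward(batch)
    except RuntimeError as e:
        if "out of memory" not in str(e) and not torch_dist.is_initialized():
            raise  # unrelated single-device errors should still surface
        skip_flag = torch.ones((), device=device, dtype=torch.bool)

    if torch_dist.is_initialized():
        # All ranks must agree on skipping; otherwise DDP's gradient
        # all-reduce would deadlock on the ranks that kept going.
        world_size = torch_dist.get_world_size()
        flags = [torch.zeros_like(skip_flag) for _ in range(world_size)]
        torch_dist.all_gather(flags, skip_flag)
        if torch.stack(flags).any():
            return None  # Lightning treats a ``None`` loss as "skip this batch"
    elif skip_flag:
        return None

    return loss
```

In the patch itself this pattern runs inside `training_step`/`validation_step`, which additionally drop stale `p.grad` tensors to free memory before returning `None`.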
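The `sequence_edges` fix is easiest to see on a toy example. Below is a minimal sketch mirroring the patched `Data` branch, with `N` standing in for `b.coords.shape[0]` (the number of residues):

```python
import torch

N = 5  # number of residues in a single (non-batched) protein

# Corrected indexing from the patch: residue i connects to residue i + 1.
idx_a = torch.arange(0, N - 1)  # source nodes: 0, 1, 2, 3
idx_b = torch.arange(1, N)      # target nodes: 1, 2, 3, 4

e_index = torch.stack([idx_a, idx_b], dim=0)  # "forward" direction
print(e_index)
# tensor([[0, 1, 2, 3],
#         [1, 2, 3, 4]])
```

Before the patch, the second assignment overwrote `idx_a` and left `idx_b` undefined on this branch, and its upper bound of `b.coords.shape[0] - 1` also dropped the final residue's edge.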