diff --git a/CHANGELOG.md b/CHANGELOG.md
index d882703b..dd8c9834 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@
 * Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
 * Fixes PDB download failure due to missing protein data [#77](https://github.com/a-r-j/ProteinWorkshop/pull/77)
 * Add support for handling training/validation OOMs gracefully [#81](https://github.com/a-r-j/ProteinWorkshop/pull/81)
+* Add support for handling backward OOMs gracefully [#82](https://github.com/a-r-j/ProteinWorkshop/pull/82)
 
 ### Framework
 
diff --git a/proteinworkshop/models/base.py b/proteinworkshop/models/base.py
index b45cb9dd..3ed98b2c 100644
--- a/proteinworkshop/models/base.py
+++ b/proteinworkshop/models/base.py
@@ -17,6 +17,7 @@
 
 from proteinworkshop.models.utils import get_loss
 from proteinworkshop.types import EncoderOutput, Label, ModelOutput
+from proteinworkshop.utils.memory_utils import clean_up_torch_gpu_memory
 
 
 class BaseModel(L.LightningModule, abc.ABC):
@@ -743,3 +744,56 @@ def test_step(
         :rtype: torch.Tensor
         """
         return self._do_step(batch, batch_idx, "test")
+
+    def backward(self, loss: torch.Tensor, *args: Any, **kwargs: Dict[str, Any]):
+        """Overrides Lightning's `backward` hook to add an out-of-memory (OOM) check.
+
+        :param loss: The loss value to backpropagate.
+        :param args: Additional positional arguments to pass to `torch.Tensor.backward`.
+        :param kwargs: Additional keyword arguments to pass to `torch.Tensor.backward`.
+        """
+        # by default, do not skip the current batch
+        skip_flag = torch.zeros(
+            (), device=self.device, dtype=torch.bool
+        )  # NOTE: for skipping batches in a multi-device setting
+
+        try:
+            loss.backward(*args, **kwargs)
+        except Exception as e:
+            skip_flag = torch.ones((), device=self.device, dtype=torch.bool)
+            logger.warning(f"Failed the backward pass. Skipping it for the current rank due to: {e}")
+            for p in self.trainer.model.parameters():
+                if p.grad is not None:
+                    del p.grad
+            logger.warning("Finished cleaning up all gradients following the failed backward pass.")
+            if "out of memory" not in str(e) and not torch_dist.is_initialized():
+                raise e
+
+        # NOTE: for skipping batches in a multi-device setting
+        # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
+        if torch_dist.is_initialized():
+            # if any rank skips a batch, then all other ranks need to skip
+            # their batches as well so DDP can properly keep all ranks synced
+            world_size = torch_dist.get_world_size()
+            torch_dist.barrier()
+            result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
+            torch_dist.all_gather(result, skip_flag)
+            any_skipped = torch.sum(torch.stack(result)).bool().item()
+            if any_skipped:
+                logger.warning(
+                    "Skipping backward for all ranks after detecting a failed backward pass."
+                )
+                del loss  # delete the computation graph
+                logger.warning(
+                    "Finished cleaning up the computation graph following one of the rank's failed backward pass."
+                )
+                for p in self.trainer.model.parameters():
+                    if p.grad is not None:
+                        del p.grad
+                logger.warning(
+                    "Finished cleaning up all gradients following one of the rank's failed backward pass."
+                )
+                clean_up_torch_gpu_memory()
+                logger.warning(
+                    "Finished manually freeing up memory following one of the rank's failed backward pass."
+                )
diff --git a/proteinworkshop/utils/memory_utils.py b/proteinworkshop/utils/memory_utils.py
new file mode 100644
index 00000000..23e8757e
--- /dev/null
+++ b/proteinworkshop/utils/memory_utils.py
@@ -0,0 +1,85 @@
+"""Implement CUDA memory monitoring and management utilities."""
+import gc
+
+import torch
+
+from beartype.typing import Union
+from loguru import logger as log
+
+
+def gpu_memory_usage(device: Union[int, torch.device] = 0) -> float:
+    """
+    Get GPU memory usage in GB.
+    From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
+
+    :param device: GPU device as an index or a `device` object.
+    :return: GPU memory usage in GB.
+    """
+    return torch.cuda.memory_allocated(device) / 1024.0**3
+
+
+def gpu_memory_usage_all(device: Union[int, torch.device] = 0) -> tuple[float, float]:
+    """
+    Get GPU memory usage and cache in GB.
+    From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
+
+    :param device: GPU device as an index or a `device` object.
+    :return: GPU memory usage and cache in GB.
+    """
+    usage = torch.cuda.memory_allocated(device) / 1024.0**3
+    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
+    cache = reserved - usage
+    return usage, cache
+
+
+def clean_up_torch_gpu_memory(device: Union[int, torch.device] = 0):
+    """
+    Clean up PyTorch GPU memory systematically.
+    From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
+
+    :param device: GPU device as an index or a `device` object.
+    """
+    try:
+        gc.collect()
+        torch.cuda.empty_cache()
+    finally:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        if (mem := gpu_memory_usage(device=device)) > 3.0:
+            log.warning(f"GPU memory usage is still high, with `mem={mem}`!")
+            cnt = 0
+            for obj in get_tensors():
+                obj.detach()
+                obj.grad = None
+                obj.storage().resize_(0)
+                cnt += 1
+            gc.collect()
+            torch.cuda.empty_cache()
+            usage, cache = gpu_memory_usage_all(device=device)
+            log.warning(
+                f"Forcibly cleared {cnt} tensors: {mem:.03f}GB -> {usage:.03f}GB (+{cache:.03f}GB cache)"
+            )
+
+
+def get_tensors(gpu_only: bool = True):
+    """
+    Get all tensors in memory.
+    From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
+
+    :param gpu_only: If True, only return tensors on GPU.
+    :return: Generator of tensors.
+    """
+    for obj in gc.get_objects():
+        try:
+            if torch.is_tensor(obj):
+                tensor = obj
+            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
+                tensor = obj.data
+            else:
+                continue
+
+            if tensor.is_cuda or not gpu_only:
+                yield tensor
+        except Exception:  # nosec B112 pylint: disable=broad-exception-caught
+            continue
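For reference, below is a minimal standalone sketch of how the utilities added in proteinworkshop/utils/memory_utils.py might be exercised outside the Lightning `backward` hook. It is not part of the patch above; the device index and tensor size are illustrative assumptions.

# Usage sketch (illustrative; not part of the patch above).
import torch

from proteinworkshop.utils.memory_utils import (
    clean_up_torch_gpu_memory,
    gpu_memory_usage_all,
)

if torch.cuda.is_available():
    # Allocate a throwaway tensor, drop the reference, then reclaim the memory.
    x = torch.randn(256, 1024, 1024, device="cuda:0")  # ~1 GB of fp32 data
    usage, cache = gpu_memory_usage_all(device=0)
    print(f"allocated: {usage:.03f}GB used (+{cache:.03f}GB cache)")

    del x
    clean_up_torch_gpu_memory(device=0)

    usage, cache = gpu_memory_usage_all(device=0)
    print(f"cleaned:   {usage:.03f}GB used (+{cache:.03f}GB cache)")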