Skip to content

Commit

Permalink
Add wrapper for catching OOMs during backward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Mar 6, 2024
1 parent b97cbe9 commit d11a5ed
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
* Fixes PDB download failure due to missing protein data [#77](https://github.com/a-r-j/ProteinWorkshop/pull/77)
* Add support for handling training/validation OOMs gracefully [#81](https://github.com/a-r-j/ProteinWorkshop/pull/81)
* Add support for handling backward OOMs gracefully [#82](https://github.com/a-r-j/ProteinWorkshop/pull/82)

### Framework

Expand Down
54 changes: 54 additions & 0 deletions proteinworkshop/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from proteinworkshop.models.utils import get_loss
from proteinworkshop.types import EncoderOutput, Label, ModelOutput
from proteinworkshop.utils.memory_utils import clean_up_torch_gpu_memory


class BaseModel(L.LightningModule, abc.ABC):
Expand Down Expand Up @@ -743,3 +744,56 @@ def test_step(
:rtype: torch.Tensor
"""
return self._do_step(batch, batch_idx, "test")

def backward(self, loss: torch.Tensor, *args: Any, **kwargs: Dict[str, Any]):
"""Overrides Lightning's `backward` hook to add an out-of-memory (OOM) check.
:param loss: The loss value to backpropagate.
:param args: Additional positional arguments to pass to `torch.Tensor.backward`.
:param kwargs: Additional keyword arguments to pass to `torch.Tensor.backward`.
"""
# by default, do not skip the current batch
skip_flag = torch.zeros(
(), device=self.device, dtype=torch.bool
) # NOTE: for skipping batches in a multi-device setting

try:
loss.backward(*args, **kwargs)
except Exception as e:
skip_flag = torch.ones((), device=self.device, dtype=torch.bool)
logger.warning(f"Failed the backward pass. Skipping it for the current rank due to: {e}")
for p in self.trainer.model.parameters():
if p.grad is not None:
del p.grad
logger.warning("Finished cleaning up all gradients following the failed backward pass.")
if "out of memory" not in str(e) and not torch_dist.is_initialized():
raise e

# NOTE: for skipping batches in a multi-device setting
# credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
if torch_dist.is_initialized():
# if any rank skips a batch, then all other ranks need to skip
# their batches as well so DDP can properly keep all ranks synced
world_size = torch_dist.get_world_size()
torch_dist.barrier()
result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
torch_dist.all_gather(result, skip_flag)
any_skipped = torch.sum(torch.stack(result)).bool().item()
if any_skipped:
logger.warning(
"Skipping backward for all ranks after detecting a failed backward pass."
)
del loss # delete the computation graph
logger.warning(
"Finished cleaning up the computation graph following one of the rank's failed backward pass."
)
for p in self.trainer.model.parameters():
if p.grad is not None:
del p.grad
logger.warning(
"Finished cleaning up all gradients following one of the rank's failed backward pass."
)
clean_up_torch_gpu_memory()
logger.warning(
"Finished manually freeing up memory following one of the rank's failed backward pass."
)
85 changes: 85 additions & 0 deletions proteinworkshop/utils/memory_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Implement CUDA memory monitoring and management utilities."""
import gc

import torch

from beartype.typing import Union
from loguru import logger as log


def gpu_memory_usage(device: Union[int, torch.device] = 0) -> float:
"""
Get GPU memory usage in GB.
From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
:param device: GPU device as an index or a `device` object.
:return: GPU memory usage in GB.
"""
return torch.cuda.memory_allocated(device) / 1024.0**3


def gpu_memory_usage_all(device: Union[int, torch.device] = 0) -> tuple[float, float]:
"""
Get GPU memory usage in GB.
From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
:param device: GPU device as an index or a `device` object.
:return: GPU memory usage and cache in GB.
"""
usage = torch.cuda.memory_allocated(device) / 1024.0**3
reserved = torch.cuda.memory_reserved(device) / 1024.0**3
cache = reserved - usage
return usage, cache


def clean_up_torch_gpu_memory(device: Union[int, torch.device] = 0):
"""
Clean up PyTorch GPU memory systematically.
From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
:param device: GPU device as an index or a `device` object.
"""
try:
gc.collect()
torch.cuda.empty_cache()
finally:
gc.collect()
torch.cuda.empty_cache()

if (mem := gpu_memory_usage()) > 3.0:
log.warning(f"GPU memory usage is still high, with `mem={mem}`!")
cnt = 0
for obj in get_tensors():
obj.detach()
obj.grad = None
obj.storage().resize_(0)
cnt += 1
gc.collect()
torch.cuda.empty_cache()
usage, cache = gpu_memory_usage_all(device=device)
log.warning(
f"Forcibly cleared {cnt} tensors: {mem:.03f}GB -> {usage:.03f}GB (+{cache:.03f}GB cache)"
)


def get_tensors(gpu_only: bool = True):
"""
Get all tensors in memory.
From: https://github.com/pytorch/pytorch/issues/82218#issuecomment-1675254117
:param gpu_only: If True, only return tensors on GPU.
:return: Generator of tensors.
"""
for obj in gc.get_objects():
try:
if torch.is_tensor(obj):
tensor = obj
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
tensor = obj.data
else:
continue

if tensor.is_cuda or not gpu_only:
yield tensor
except Exception: # nosec B112 pylint: disable=broad-exception-caught
continue

0 comments on commit d11a5ed

Please sign in to comment.