Skip to content

Commit

Permalink
Merge pull request #81 from a-r-j/catch-oom
Browse files Browse the repository at this point in the history
Add support for handling training/validation OOMs gracefully
  • Loading branch information
amorehead authored Mar 5, 2024
2 parents 3bea793 + 8915bca commit b97cbe9
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

* Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
* Fixes PDB download failure due to missing protein data [#77](https://github.com/a-r-j/ProteinWorkshop/pull/77)
* Add support for handling training/validation OOMs gracefully [#81](https://github.com/a-r-j/ProteinWorkshop/pull/81)

### Framework

Expand Down
93 changes: 87 additions & 6 deletions proteinworkshop/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hydra
import lightning as L
import torch
import torch.distributed as torch_dist
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typechecker
Expand Down Expand Up @@ -598,9 +599,89 @@ def _do_step(
self.log_metrics(loss, y_hat, y, stage, batch=batch)
return loss["total"]

@typechecker
def _do_step_catch_oom(
self,
batch: Batch,
batch_idx: int,
stage: Literal["train", "val"],
) -> Optional[torch.Tensor]:
"""Performs a training/validation step
while catching out of memory errors.
Note that this should not be used for
test steps for proper benchmarking.
1. Obtains labels from :py:meth:`get_labels`
2. Computes model output :py:meth:`forward`
3. Computes loss :py:meth:`compute_loss`
4. Logs metrics :py:meth:`log_metrics`
Returns the total loss.
:param batch: Mini-batch of data.
:type batch: Batch
:param batch_idx: Index of batch.
:type batch_idx: int
:param stage: Stage of training (``"train"``, ``"val"``)
:type stage: Literal["train", "val"]
:return: Loss
:rtype: torch.Tensor
"""
# by default, do not skip the current batch
skip_flag = torch.zeros(
(), device=self.device, dtype=torch.bool
) # NOTE: for skipping batches in a multi-device setting

try:
y = self.get_labels(batch)
y_hat = self(batch)

loss = self.compute_loss(y_hat, y)
self.log_metrics(loss, y_hat, y, stage, batch=batch)

except Exception as e:
skip_flag = torch.ones((), device=self.device, dtype=torch.bool)

if "out of memory" in str(e):
logger.warning(
f"Ran out of memory in the forward pass. Skipping current {stage} batch with index {batch_idx}."
)
if not torch_dist.is_initialized():
# NOTE: for skipping batches in a single-device setting
if self.training:
for p in self.trainer.model.parameters():
if p.grad is not None:
del p.grad # free some memory
return None
else:
if not torch_dist.is_initialized():
raise e

# NOTE: for skipping batches in a multi-device setting
# credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
if torch_dist.is_initialized():
# if any rank skips a batch, then all other ranks need to skip
# their batches as well so DDP can properly keep all ranks synced
world_size = torch_dist.get_world_size()
torch_dist.barrier()
result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
torch_dist.all_gather(result, skip_flag)
any_skipped = torch.sum(torch.stack(result)).bool().item()
if any_skipped:
if self.training:
for p in self.trainer.model.parameters():
if p.grad is not None:
del p.grad
logger.warning(
f"Failed to perform the forward pass for at least one rank. Skipping {stage} batches for all ranks."
)
return None

return loss["total"]

def training_step(
self, batch: Union[Batch, ProteinBatch], batch_idx: int
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
"""
Perform training step.
Expand All @@ -616,13 +697,13 @@ def training_step(
:param batch_idx: Index of batch.
:type batch_idx: int
:return: Loss
:rtype: torch.Tensor
:rtype: Optional[torch.Tensor]
"""
return self._do_step(batch, batch_idx, "train")
return self._do_step_catch_oom(batch, batch_idx, "train")

def validation_step(
self, batch: Union[Batch, ProteinBatch], batch_idx: int
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
"""
Perform validation step.
Expand All @@ -638,9 +719,9 @@ def validation_step(
:param batch_idx: Index of batch.
:type batch_idx: int
:return: Loss
:rtype: torch.Tensor
:rtype: Optional[torch.Tensor]
"""
return self._do_step(batch, batch_idx, "val")
return self._do_step_catch_oom(batch, batch_idx, "val")

def test_step(
self, batch: Union[Batch, ProteinBatch], batch_idx: int
Expand Down

0 comments on commit b97cbe9

Please sign in to comment.