
Commit: Merge branch 'main' into main

chaitjo authored Mar 5, 2024
2 parents aa61236 + b97cbe9 commit be30d8e
Showing 4 changed files with 93 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md

@@ -8,6 +8,7 @@

 * Adds missing `pos` attribute to GearNet `required_batch_attributes` (fixes [#73](https://github.com/a-r-j/ProteinWorkshop/issues/73)) [#74](https://github.com/a-r-j/ProteinWorkshop/pull/74)
 * Fixes PDB download failure due to missing protein data [#77](https://github.com/a-r-j/ProteinWorkshop/pull/77)
+* Add support for handling training/validation OOMs gracefully [#81](https://github.com/a-r-j/ProteinWorkshop/pull/81)

 ### Framework

@@ -16,6 +17,8 @@
 ### Command
 * Adds `--force-cuda-version` to `workshop install` [#78](https://github.com/a-r-j/ProteinWorkshop/pull/78)

+### Features
+* Fix `sequence_edges` behaviour when argument `b` is a `Data` object [#80](https://github.com/a-r-j/ProteinWorkshop/pull/80)

 ### 0.2.5 (28/12/2023)
3 changes: 2 additions & 1 deletion README.md

@@ -12,7 +12,8 @@

 [Documentation](https://www.proteins.sh)

-This [repository](https://github.com/a-r-j/ProteinWorkshop) provides the code for the protein structure representation learning benchmark detailed in the paper [*Evaluating Representation Learning on the Protein Structure Universe*](https://openreview.net/forum?id=sTYuRVrdK3).
+
+This [repository](https://github.com/a-r-j/ProteinWorkshop) provides the code for the protein structure representation learning benchmark detailed in the paper [*Evaluating Representation Learning on the Protein Structure Universe*](https://openreview.net/forum?id=sTYuRVrdK3) (ICLR 2024).

 In the benchmark, we implement numerous [featurisation](https://www.proteins.sh/configs/features) schemes, [datasets](https://www.proteins.sh/configs/dataset) for [self-supervised pre-training](https://proteins.sh/quickstart_component/pretrain.html) and [downstream evaluation](https://proteins.sh/quickstart_component/downstream.html), [pre-training](https://proteins.sh/configs/task) tasks, and [auxiliary tasks](https://proteins.sh/configs/task.html#auxiliary-tasks).
2 changes: 1 addition & 1 deletion proteinworkshop/features/edges.py

@@ -104,7 +104,7 @@ def sequence_edges(
         idx_b = torch.arange(1, b.ptr[-1], device=b.ptr.device)
     elif isinstance(b, Data):
-        idx_a = torch.arange(1, b.coords.shape[0] - 1, device=b.coords.device)
+        idx_a = torch.arange(0, b.coords.shape[0] - 1, device=b.coords.device)
         idx_b = torch.arange(1, b.coords.shape[0], device=b.coords.device)
     # Concatenate indices to create edge list
     if direction == "forward":
         e_index = torch.stack([idx_a, idx_b], dim=0)
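For orientation (not part of the commit), here is a minimal sketch of how this indexing pairs consecutive residues, with `n_residues` standing in for `b.coords.shape[0]`:

    import torch

    n_residues = 5  # illustrative stand-in for b.coords.shape[0]

    # Pair each residue i with its successor i + 1 along the chain.
    idx_a = torch.arange(0, n_residues - 1)  # sources: 0, 1, 2, 3
    idx_b = torch.arange(1, n_residues)      # targets: 1, 2, 3, 4

    # Both index tensors must have the same length to be stacked
    # into a [2, num_edges] edge list.
    edge_index = torch.stack([idx_a, idx_b], dim=0)
    # tensor([[0, 1, 2, 3],
    #         [1, 2, 3, 4]])

Starting `idx_a` anywhere other than 0 drops the first residue's edge and leaves the two tensors with mismatched lengths, which appears to be the off-by-one the change above addresses.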
93 changes: 87 additions & 6 deletions proteinworkshop/models/base.py

@@ -4,6 +4,7 @@
 import hydra
 import lightning as L
 import torch
+import torch.distributed as torch_dist
 import torch.nn as nn
 import torch.nn.functional as F
 from beartype import beartype as typechecker

@@ -598,9 +599,89 @@ def _do_step(
         self.log_metrics(loss, y_hat, y, stage, batch=batch)
         return loss["total"]

+    @typechecker
+    def _do_step_catch_oom(
+        self,
+        batch: Batch,
+        batch_idx: int,
+        stage: Literal["train", "val"],
+    ) -> Optional[torch.Tensor]:
+        """Performs a training/validation step while catching
+        out-of-memory errors. Note that this should not be used
+        for test steps, for proper benchmarking.
+
+        1. Obtains labels from :py:meth:`get_labels`
+        2. Computes model output :py:meth:`forward`
+        3. Computes loss :py:meth:`compute_loss`
+        4. Logs metrics :py:meth:`log_metrics`
+
+        Returns the total loss, or ``None`` if the batch was skipped.
+
+        :param batch: Mini-batch of data.
+        :type batch: Batch
+        :param batch_idx: Index of batch.
+        :type batch_idx: int
+        :param stage: Stage of training (``"train"``, ``"val"``)
+        :type stage: Literal["train", "val"]
+        :return: Loss (or ``None`` if the batch is skipped)
+        :rtype: Optional[torch.Tensor]
+        """
+        # by default, do not skip the current batch
+        skip_flag = torch.zeros(
+            (), device=self.device, dtype=torch.bool
+        )  # NOTE: for skipping batches in a multi-device setting
+
+        try:
+            y = self.get_labels(batch)
+            y_hat = self(batch)
+
+            loss = self.compute_loss(y_hat, y)
+            self.log_metrics(loss, y_hat, y, stage, batch=batch)
+
+        except Exception as e:
+            skip_flag = torch.ones((), device=self.device, dtype=torch.bool)
+
+            if "out of memory" in str(e):
+                logger.warning(
+                    f"Ran out of memory in the forward pass. Skipping current {stage} batch with index {batch_idx}."
+                )
+                if not torch_dist.is_initialized():
+                    # NOTE: for skipping batches in a single-device setting
+                    if self.training:
+                        for p in self.trainer.model.parameters():
+                            if p.grad is not None:
+                                del p.grad  # free some memory
+                    return None
+            else:
+                if not torch_dist.is_initialized():
+                    raise e
+
+        # NOTE: for skipping batches in a multi-device setting
+        # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
+        if torch_dist.is_initialized():
+            # if any rank skips a batch, then all other ranks need to skip
+            # their batches as well so DDP can properly keep all ranks synced
+            world_size = torch_dist.get_world_size()
+            torch_dist.barrier()
+            result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
+            torch_dist.all_gather(result, skip_flag)
+            any_skipped = torch.sum(torch.stack(result)).bool().item()
+            if any_skipped:
+                if self.training:
+                    for p in self.trainer.model.parameters():
+                        if p.grad is not None:
+                            del p.grad
+                logger.warning(
+                    f"Failed to perform the forward pass for at least one rank. Skipping {stage} batches for all ranks."
+                )
+                return None
+
+        return loss["total"]
+
     def training_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         """
         Perform training step.

@@ -616,13 +697,13 @@ def training_step(
         :param batch_idx: Index of batch.
         :type batch_idx: int
         :return: Loss
-        :rtype: torch.Tensor
+        :rtype: Optional[torch.Tensor]
         """
-        return self._do_step(batch, batch_idx, "train")
+        return self._do_step_catch_oom(batch, batch_idx, "train")

     def validation_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
-    ) -> torch.Tensor:
+    ) -> Optional[torch.Tensor]:
         """
         Perform validation step.

@@ -638,9 +719,9 @@ def validation_step(
         :param batch_idx: Index of batch.
         :type batch_idx: int
         :return: Loss
-        :rtype: torch.Tensor
+        :rtype: Optional[torch.Tensor]
         """
-        return self._do_step(batch, batch_idx, "val")
+        return self._do_step_catch_oom(batch, batch_idx, "val")

     def test_step(
         self, batch: Union[Batch, ProteinBatch], batch_idx: int
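For reference (not from the repository), a condensed sketch of the rank-synchronisation idea behind `_do_step_catch_oom`, using `all_reduce` in place of the commit's `barrier` + `all_gather` combination; the function name `should_skip_batch` is illustrative:

    import torch
    import torch.distributed as dist

    def should_skip_batch(local_failed: bool, device: torch.device) -> bool:
        """Return True on every rank if any rank failed its forward pass."""
        if not dist.is_initialized():
            # Single-device setting: only the local outcome matters.
            return local_failed
        flag = torch.tensor(float(local_failed), device=device)
        # Summing the flags across ranks gives every rank a nonzero total
        # if at least one rank failed, so all ranks reach the same decision.
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return bool(flag.item())

Skipping on every rank whenever any rank skips matters because DDP synchronises gradients collectively: if one rank returned `None` while the others ran a backward pass, the collective calls would hang.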
