Skip to content

Commit

Permalink
Fixed pdb dataset download
Browse files Browse the repository at this point in the history
  • Loading branch information
kierandidi committed Mar 13, 2024
1 parent 22d6379 commit 2df0aad
Showing 1 changed file with 47 additions and 14 deletions.
61 changes: 47 additions & 14 deletions proteinworkshop/datasets/pdb_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Callable, Iterable, List, Optional
from typing import Callable, Iterable, List, Optional, Dict

import hydra
import omegaconf
import os
import pandas as pd
import pathlib
from graphein.ml.datasets import PDBManager
from loguru import logger as log
from torch_geometric.data import Dataset
Expand All @@ -11,9 +13,8 @@
from proteinworkshop.datasets.base import ProteinDataModule, ProteinDataset
from proteinworkshop.datasets.utils import download_pdb_mmtf


class PDBData:
def __init__(
def __init__(
self,
fraction: float,
min_length: int,
Expand Down Expand Up @@ -127,24 +128,31 @@ def create_dataset(self):


class PDBDataModule(ProteinDataModule):
def __init__(
def __init__(
self,
path: Optional[str] = None,
structure_dir: Optional[str] = None,
pdb_dataset: Optional[PDBData] = None,
transforms: Optional[Iterable[Callable]] = None,
in_memory: bool = False,
batch_size: int = 32,
num_workers: int = 0,
pin_memory: bool = False,
structure_format: str = "mmtf.gz",
overwrite: bool = False,
):
super().__init__()
super().__init__()
self.root = path
self.dataset = pdb_dataset
self.dataset.path = path
self.format = "mmtf.gz"
self.format = structure_format
self.overwrite = overwrite

if structure_dir is not None:
self.structure_dir = pathlib.Path(structure_dir)
else:
self.structure_dir = pathlib.Path(self.root) / "raw"

self.in_memory = in_memory

if transforms is not None:
Expand All @@ -157,19 +165,45 @@ def __init__(
self.num_workers = num_workers
self.pin_memory = pin_memory
self.batch_size = batch_size

def parse_dataset(self) -> pd.DataFrame:
return self.dataset.create_dataset()


def parse_dataset(self) -> Dict[str, pd.DataFrame]:
    """Return the dataset splits, building and caching them on first call.

    The splits produced by ``self.dataset.create_dataset()`` are cached on
    ``self.splits``. If :meth:`exclude_pdbs` returns a collection of
    ``"<pdb>_<chains>"`` ids, those chains are dropped from every split
    before caching.

    :return: Mapping of split name (e.g. ``"train"``) to its DataFrame.
    """
    if hasattr(self, "splits"):
        return getattr(self, "splits")

    splits = self.dataset.create_dataset()
    ids_to_exclude = self.exclude_pdbs()

    if ids_to_exclude is not None:
        for k, v in splits.items():
            log.info(f"Split {k} has {len(v)} chains before excluding failing PDB")
            # Build "<pdb>_<chains>" ids to match against the exclusion list.
            v["id"] = v["pdb"] + "_" + v["chain"].str.join("")
            # Keep only rows whose id is NOT in the exclusion list.
            splits[k] = v.loc[~v.id.isin(ids_to_exclude)]
            log.info(
                f"Split {k} has {len(splits[k])} chains after excluding failing PDB"
            )
    self.splits = splits
    return splits

def exclude_pdbs(self):
    """Hook returning ids of chains to drop from the dataset splits.

    The default implementation excludes nothing (returns ``None``);
    subclasses may override this to filter out PDBs known to fail
    download or parsing.
    """
    return None

def download(self):
    """Download any missing PDB structures for every dataset split.

    For each split returned by :meth:`parse_dataset`, structures already
    present in ``self.structure_dir`` (as ``<pdb>.<self.format>``) are
    skipped; the remaining ids are fetched with ``download_pdb_mmtf``.
    """
    pdbs = self.parse_dataset()

    for k, v in pdbs.items():
        log.info(f"Downloading {k} PDBs to {self.structure_dir}")
        # Skip structures already on disk to avoid re-downloading;
        # use pathlib's Path.exists() rather than os.path on a Path.
        pdblist = [
            pdb
            for pdb in v.pdb.tolist()
            if not (self.structure_dir / f"{pdb}.{self.format}").exists()
        ]
        download_pdb_mmtf(self.structure_dir, pdblist)

def parse_labels(self):
raise NotImplementedError
Expand Down Expand Up @@ -223,7 +257,6 @@ def test_dataset(self) -> Dataset:


if __name__ == "__main__":
import pathlib

from proteinworkshop import constants

Expand All @@ -234,4 +267,4 @@ def test_dataset(self) -> Dataset:
print(cfg)
ds = hydra.utils.instantiate(cfg)["datamodule"]
print(ds)
ds.val_dataset()
ds.val_dataset()

0 comments on commit 2df0aad

Please sign in to comment.