pep8 cleanup and fix tests #147

Merged
merged 11 commits into from
Dec 12, 2023
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -31,9 +31,9 @@ jobs:
strategy:
max-parallel: 2
matrix:
python-version: ['3.8']
torch-version: [1.10.0, 2.0.0]
os: [ubuntu-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest]
python-version: ['3.9']
torch-version: [2.1.1]
os: [ubuntu-latest, macos-latest, windows-latest] # only run ubuntu for now because the other ones fail for no reason, macos-latest, windows-latest]

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -1,3 +1,4 @@
-r requirements.txt
wheel
pytest
pydicom>=2.3.1
pydicom>=2.3.1
2 changes: 1 addition & 1 deletion torchxrayvision/autoencoders.py
@@ -222,7 +222,7 @@ def ResNetAE(weights=None):
"""A ResNet based autoencoder.

Possible weights for this class include:

- "101-elastic" trained on PadChest, NIH, CheXpert, and MIMIC. From the paper https://arxiv.org/abs/2102.09475

.. code-block:: python
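
The docstring's code block is truncated in this view. For reference, a minimal usage sketch of ResNetAE with the "101-elastic" weights named above; the dummy input size and value range are illustrative assumptions, not part of this diff:

import torch
import torchxrayvision as xrv

ae = xrv.autoencoders.ResNetAE(weights="101-elastic")  # downloads pretrained weights on first use
img = torch.randn(1, 1, 224, 224)  # placeholder; real inputs are xrv-normalized to [-1024, 1024]
z = ae.encode(img)                 # latent representation
recon = ae.decode(z)               # reconstructed image
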
4 changes: 2 additions & 2 deletions torchxrayvision/baseline_models/chexpert/model.py
@@ -79,7 +79,7 @@ def infer(self, x, tasks):
for task in tasks:

idx = self.task_sequence[task]
#task_prob = probs.detach().cpu().numpy()[idx]
# task_prob = probs.detach().cpu().numpy()[idx]
task_prob = probs[idx]
task2results[task] = task_prob

@@ -226,7 +226,7 @@ def infer(self, img, tasks):
else:
task2ensemble_results[task].append(individual_task2results[task])

assert all([task in task2ensemble_results for task in tasks]),\
assert all([task in task2ensemble_results for task in tasks]), \
"Not all tasks in task2ensemble_results"

task2results = {}
18 changes: 9 additions & 9 deletions torchxrayvision/baseline_models/riken/__init__.py
@@ -43,18 +43,18 @@ class AgeModel(nn.Module):
url = {https://www.nature.com/articles/s43856-022-00220-6},
year = {2022}
}

"""

targets: List[str] = ["Age"]
""""""

def __init__(self):

super(AgeModel, self).__init__()

url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/baseline_models_riken_xray_age_every_model_age_senet154_v2_tl_26_ft_7_fp32.pt"

weights_filename = os.path.basename(url)
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))
@@ -81,17 +81,17 @@ def __init__(self):
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225],
)

def forward(self, x):
x = x.repeat(1, 3, 1, 1)
x = self.upsample(x)

# expecting values between [-1024,1024]
x = (x + 1024) / 2048
# now between [0,1]

x = self.norm(x)
return self.model(x)

def __repr__(self):
return "riken-age-prediction"
23 changes: 11 additions & 12 deletions torchxrayvision/baseline_models/xinario/__init__.py
@@ -10,7 +10,7 @@

class ViewModel(nn.Module):
"""


The native resolution of the model is 320x320. Images are scaled
automatically.
@@ -26,7 +26,7 @@ class ViewModel(nn.Module):

pred = model(image)
# tensor([[17.3186, 26.7156]]), grad_fn=<AddmmBackward0>)

model.targets[pred.argmax()]
# Lateral

@@ -37,13 +37,13 @@ class ViewModel(nn.Module):

targets: List[str] = ['Frontal', 'Lateral']
""""""

def __init__(self):

super(ViewModel, self).__init__()

url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/xinario_chestViewSplit_resnet-50.pt"

weights_filename = os.path.basename(url)
weights_storage_folder = os.path.expanduser(os.path.join("~", ".torchxrayvision", "models_data"))
self.weights_filename_local = os.path.expanduser(os.path.join(weights_storage_folder, weights_filename))
@@ -54,7 +54,6 @@ def __init__(self):
pathlib.Path(weights_storage_folder).mkdir(parents=True, exist_ok=True)
xrv.utils.download(url, self.weights_filename_local)


self.model = torchvision.models.resnet.resnet50()
try:
weights = torch.load(self.weights_filename_local)
@@ -74,17 +73,17 @@ def __init__(self):
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225],
)

def forward(self, x):
x = x.repeat(1, 3, 1, 1)
x = self.upsample(x)

# expecting values between [-1024,1024]
x = (x + 1024) / 2048
# now between [0,1]

x = self.norm(x)
return self.model(x)[:,:2] # cut off the rest of the outputs
return self.model(x)[:, :2] # cut off the rest of the outputs

def __repr__(self):
return "xinario-view-prediction"
38 changes: 21 additions & 17 deletions torchxrayvision/datasets.py
@@ -128,6 +128,7 @@ class Dataset:
metadata file and for some the metadata files are packaged in the library
so only the imgpath needs to be specified.
"""

def __init__(self):
pass

@@ -262,7 +263,7 @@ def __init__(self, datasets, seed=0, label_concat=False):
print("Could not merge dataframes (.csv not available):", sys.exc_info()[0])

self.csv = self.csv.reset_index(drop=True)

def __setattr__(self, name, value):
if hasattr(self, 'labels'):
# check only if have finished init, otherwise __init__ breaks
@@ -346,6 +347,7 @@ class SubsetDataset(Dataset):
- of PC_Dataset num_samples=94825 views=['PA', 'AP'] data_aug=None

"""

def __init__(self, dataset, idxs=None):
super(SubsetDataset, self).__init__()
self.dataset = dataset
@@ -365,7 +367,7 @@ def __setattr__(self, name, value):
# check only if have finished init, otherwise __init__ breaks
if name in ['transform', 'data_aug', 'labels', 'pathologies', 'targets']:
raise NotImplementedError(f'Cannot set {name} on a subset dataset. Set the transforms directly on the dataset object. If it was to be set via this subset dataset it would have to modify the internal dataset which could have unexpected side effects')

object.__setattr__(self, name, value)

def string(self):
@@ -895,17 +897,17 @@ def __init__(self,
"216840111366964012373310883942009170084120009_00-097-074.png",
"216840111366964012819207061112010315104455352_04-024-184.png",
"216840111366964012819207061112010306085429121_04-020-102.png",
"216840111366964012989926673512011083134050913_00-168-009.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964012373310883942009152114636712_00-102-045.png", # "OSError: image file is truncated"
"216840111366964012819207061112010281134410801_00-129-131.png", # "OSError: image file is truncated"
"216840111366964012487858717522009280135853083_00-075-001.png", # "OSError: image file is truncated"
"216840111366964012989926673512011151082430686_00-157-045.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964013686042548532013208193054515_02-026-007.png", # "OSError: image file is truncated"
"216840111366964013590140476722013058110301622_02-056-111.png", # "OSError: image file is truncated"
"216840111366964013590140476722013043111952381_02-065-198.png", # "OSError: image file is truncated"
"216840111366964013829543166512013353113303615_02-092-190.png", # "OSError: image file is truncated"
"216840111366964013962490064942014134093945580_01-178-104.png", # "OSError: image file is truncated"
]
"216840111366964012989926673512011083134050913_00-168-009.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964012373310883942009152114636712_00-102-045.png", # "OSError: image file is truncated"
"216840111366964012819207061112010281134410801_00-129-131.png", # "OSError: image file is truncated"
"216840111366964012487858717522009280135853083_00-075-001.png", # "OSError: image file is truncated"
"216840111366964012989926673512011151082430686_00-157-045.png", # broken PNG file (chunk b'\x00\x00\x00\x00')
"216840111366964013686042548532013208193054515_02-026-007.png", # "OSError: image file is truncated"
"216840111366964013590140476722013058110301622_02-056-111.png", # "OSError: image file is truncated"
"216840111366964013590140476722013043111952381_02-065-198.png", # "OSError: image file is truncated"
"216840111366964013829543166512013353113303615_02-092-190.png", # "OSError: image file is truncated"
"216840111366964013962490064942014134093945580_01-178-104.png", # "OSError: image file is truncated"
]
self.csv = self.csv[~self.csv["ImageID"].isin(missing)]

if unique_patients:
@@ -920,7 +922,7 @@ def __init__(self,
mask = self.csv["Labels"].str.contains(pathology.lower())
if pathology in mapping:
for syn in mapping[pathology]:
#print("mapping", syn)
# print("mapping", syn)
mask |= self.csv["Labels"].str.contains(syn.lower())
labels.append(mask.values)
self.labels = np.asarray(labels).T
@@ -1094,7 +1096,7 @@ def __getitem__(self, idx):
sample["lab"] = self.labels[idx]

imgid = self.csv['Path'].iloc[idx]
#clean up path in csv so the user can specify the path
# clean up path in csv so the user can specify the path
imgid = imgid.replace("CheXpert-v1.0-small/", "").replace("CheXpert-v1.0/", "")
img_path = os.path.join(self.imgpath, imgid)
img = imread(img_path)
@@ -1344,7 +1346,7 @@ def __init__(self, imgpath,
mask = self.csv["labels_automatic"].str.contains(pathology.lower())
if pathology in mapping:
for syn in mapping[pathology]:
#print("mapping", syn)
# print("mapping", syn)
mask |= self.csv["labels_automatic"].str.contains(syn.lower())
labels.append(mask.values)

@@ -1994,7 +1996,7 @@ def __init__(self,
transform=None,
data_aug=None,
seed=0
):
):
super(ObjectCXR_Dataset, self).__init__()

np.random.seed(seed) # Reset the seed so all runs are the same.
@@ -2053,6 +2055,7 @@ def __call__(self, x):

class XRayResizer(object):
"""Resize an image to a specific size"""

def __init__(self, size: int, engine="skimage"):
self.size = size
self.engine = engine
@@ -2076,6 +2079,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray:

class XRayCenterCrop(object):
"""Perform a center crop on the long dimension of the input image"""

def crop_center(self, img: np.ndarray) -> np.ndarray:
_, y, x = img.shape
crop_size = np.min([y, x])
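
For context, a minimal sketch of how XRayCenterCrop and XRayResizer are typically composed into a dataset transform; the dataset class and imgpath are illustrative placeholders:

import torchvision
import torchxrayvision as xrv

transform = torchvision.transforms.Compose([
    xrv.datasets.XRayCenterCrop(),  # crop the long dimension to a square
    xrv.datasets.XRayResizer(224),  # resize to 224x224
])

d = xrv.datasets.NIH_Dataset(imgpath="path/to/NIH/images", transform=transform)  # path is a placeholder
sample = d[0]
print(sample["img"].shape, sample["lab"])
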
6 changes: 4 additions & 2 deletions torchxrayvision/models.py
@@ -80,6 +80,8 @@
}

# Just created for documentation


class Model:
"""The library is composed of core and baseline classifiers. Core
classifiers are trained specifically for this library and baseline
@@ -132,6 +134,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
pass


class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
@@ -190,7 +193,7 @@ class DenseNet(nn.Module):
:param weights: Specify a weight name to load pre-trained weights
:param op_threshs: Specify a weight name to load pre-trained weights
:param apply_sigmoid: Apply a sigmoid

"""

targets: List[str] = [
@@ -379,7 +382,6 @@ class ResNet(nn.Module):
]
""""""


def __init__(self, weights: str = None, apply_sigmoid: bool = False):
super(ResNet, self).__init__()

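
For context, a minimal sketch of loading one of these core classifiers and reading its outputs; the weight name and dummy input are illustrative, not part of this diff:

import torch
import torchxrayvision as xrv

model = xrv.models.DenseNet(weights="densenet121-res224-all")
img = torch.randn(1, 1, 224, 224)  # placeholder; real images are normalized with xrv.datasets.normalize
preds = model(img)                 # one output column per entry in model.targets
print(dict(zip(model.targets, preds[0].detach().tolist())))
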
17 changes: 9 additions & 8 deletions torchxrayvision/utils.py
@@ -80,7 +80,8 @@ def load_image(fname: str):

return img

def read_xray_dcm(path:PathLike, voi_lut:bool=False, fix_monochrome:bool=True)->ndarray:

def read_xray_dcm(path: PathLike, voi_lut: bool = False, fix_monochrome: bool = True) -> ndarray:
"""read a dicom-like file and convert to numpy array

Args:
@@ -98,26 +99,26 @@ def read_xray_dcm(path:PathLike, voi_lut:bool=False, fix_monochrome:bool=True)->

# get the pixel array
ds = pydicom.dcmread(path, force=True)
data = ds.pixel_array

# we have not tested RGB, YBR_FULL, or YBR_FULL_422 yet.
if ds.PhotometricInterpretation not in ['MONOCHROME1', 'MONOCHROME2']:
if ds.PhotometricInterpretation not in ['MONOCHROME1', 'MONOCHROME2']:
raise NotImplementedError(f'PhotometricInterpretation `{ds.PhotometricInterpretation}` is not yet supported.')
# get the max possible pixel value from DCM header
max_possible_pixel_val = (2**ds.BitsStored - 1)

data = ds.pixel_array

# LUT for human friendly view
if voi_lut:
data = pydicom.pixel_data_handlers.util.apply_voi_lut(data, ds, index=0)


# `MONOCHROME1` have an inverted view; Bones are black; background is white
# https://web.archive.org/web/20150920230923/http://www.mccauslandcenter.sc.edu/mricro/dicom/index.html
if fix_monochrome and ds.PhotometricInterpretation == "MONOCHROME1":
warnings.warn(f"Coverting MONOCHROME1 to MONOCHROME2 interpretation for file: {path}. Can be avoided by setting `fix_monochrome=False`")
data = max_possible_pixel_val - data

# normalize data to [-1024, 1024]
# normalize data to [-1024, 1024]
data = normalize(data, max_possible_pixel_val)
return data

@@ -129,13 +130,13 @@ def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4,
batch_size=threads,
num_workers=threads,
)

preds = []
with torch.inference_mode():
for i, batch in enumerate(tqdm(dl)):
output = model(batch["img"].to(device))

output = output.detach().cpu().numpy()
preds.append(output)

return np.concatenate(preds)
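
For context, a minimal sketch tying the two utilities above together; the DICOM path, model, and dataset are placeholders, and read_xray_dcm returns an array normalized to [-1024, 1024] while infer() batches a Dataset through a model and concatenates the predictions:

import torchxrayvision as xrv

img = xrv.utils.read_xray_dcm("study/image.dcm")  # placeholder path; only MONOCHROME1/2 DICOMs are supported
print(img.shape, img.min(), img.max())

# model = xrv.models.DenseNet(weights="densenet121-res224-all")
# dataset = ...  # any xrv dataset yielding {"img": ...} samples
# preds = xrv.utils.infer(model, dataset, threads=4, device="cpu")  # device argument assumed from the hunk above
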