Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
ieee8023 committed Mar 13, 2024
1 parent f4505b8 commit e154062
Show file tree
Hide file tree
Showing 6 changed files with 903 additions and 0 deletions.
340 changes: 340 additions & 0 deletions scripts/medical_mae_example.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions torchxrayvision/baseline_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from . import emory_hiti
from . import riken
from . import xinario
from . import medical_mae
1 change: 1 addition & 0 deletions torchxrayvision/baseline_models/medical_mae/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import models_mae_cnn
356 changes: 356 additions & 0 deletions torchxrayvision/baseline_models/medical_mae/models_mae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,356 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from .pos_embed import get_2d_sincos_pos_embed


class MaskedAutoencoderViT(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""

def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, mask_strategy='random'):
super().__init__()

# --------------------------------------------------------------------------
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
requires_grad=False) # fixed sin-cos embedding

self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------

# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
requires_grad=False) # fixed sin-cos embedding

self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
for i in range(decoder_depth)])

self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch
# --------------------------------------------------------------------------

self.norm_pix_loss = norm_pix_loss

self.initialize_weights()
# if heatmap is not None:
# self.heatmap_weights = self.extract_patch_weights(heatmap, img_size, patch_size,
# weight_min=weight_range[0], weight_max=weight_range[1],
# heatmap_binary_threshold=heatmap_binary_threshold)
#

self.img_size = img_size
self.patch_size = patch_size
self.mask_strategy = mask_strategy
self.local_attention_mask = self.get_local_attention_mask(img_size, patch_size)#.cuda()

# def extract_patch_weights(self, heatmap, img_size, patch_size, weight_min=0.1, weight_max=1.0,
# heatmap_binary_threshold=None):
# heatmap = cv2.resize(heatmap, (img_size, img_size), interpolation=cv2.INTER_AREA)
# heatmap = heatmap[:, :, 0] # only need one channel for mask
# heatmap = heatmap.astype(np.float32)
#
# if heatmap_binary_threshold is not None:
# if heatmap_binary_threshold == 'mean':
# threshold = np.mean(heatmap)
# else:
# raise NotImplementedError
# heatmap = (heatmap > threshold).astype(np.float32)
#
# heatmap = torch.tensor(heatmap)
# h = w = heatmap.shape[0] // patch_size
# heatmap = heatmap.reshape(h, patch_size, w, patch_size)
# heatmap = torch.einsum('hpwq->hwpq', heatmap)
# heatmap_weights = heatmap.reshape(h * w, patch_size ** 2).sum(dim=-1)
# print('**************************')
# print(weight_min, weight_max)
# print('**************************')
# heatmap_weights = (heatmap_weights / heatmap_weights.max() * (weight_max - weight_min) + weight_min)
# return heatmap_weights

def get_local_attention_mask(self, img_size, patch_size):
h = w = img_size // patch_size
masks = []
for i in range(h):
for j in range(w):
mask = torch.zeros(h, w)

x_min = max(0, i - 1)
x_max = min(h - 1, i + 1)
y_min = max(0, j - 1)
y_max = min(w - 1, j + 1)

mask[x_min:x_max + 1, y_min:y_max + 1] = 1
# print(x_min, x_max, y_min, y_max)
masks.append(mask.flatten())
masks = torch.stack(masks, dim=0).unsqueeze(dim=0)
return masks


def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
cls_token=True)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
int(self.patch_embed.num_patches ** .5), cls_token=True)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=.02)
torch.nn.init.normal_(self.mask_token, std=.02)

# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
return x

def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1] ** .5)
assert h * w == x.shape[1]

x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs

def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))

noise = torch.rand(N, L, device=x.device) # noise in [0, 1]


# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)

# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)

return x_masked, mask, ids_restore

def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)

# add pos embed w/o cls token

x = x + self.pos_embed[:, 1:, :]

x, mask, ids_restore = self.random_masking(x, mask_ratio)

# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)

# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)

return x, mask, ids_restore

def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)

# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token

# add pos embed
x = x + self.decoder_pos_embed

# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)

# predictor projection
x = self.decoder_pred(x)

# remove cls token
x = x[:, 1:, :]

return x

def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6) ** .5

loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch

loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss

def forward(self, imgs, mask_ratio=0.75, heatmaps=None):
if heatmaps is not None:
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio, heatmaps)
else:
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask


def mae_vit_small_patch16_dec128d2b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=384, depth=12, num_heads=6,
decoder_embed_dim=128, decoder_depth=2, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


def mae_vit_small_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=384, depth=12, num_heads=6,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


def mae_vit_small_patch16_dec512d2b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=384, depth=12, num_heads=6,
decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


def mae_vit_small_patch16_dec512d2b_448(**kwargs):
model = MaskedAutoencoderViT(
img_size=448, patch_size=16, embed_dim=384, depth=12, num_heads=6,
decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


def mae_vit_base_patch16_dec512d2b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
decoder_embed_dim=512, decoder_depth=2, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


def mae_vit_base_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=14, embed_dim=1280, depth=32, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


# set recommended archs
mae_vit_small_patch16_dec512d8 = mae_vit_small_patch16_dec512d8b
mae_vit_small_patch16_dec512d2 = mae_vit_small_patch16_dec512d2b
mae_vit_small_patch16 = mae_vit_small_patch16_dec128d2b
mae_vit_base_patch16_dec512d2 = mae_vit_base_patch16_dec512d2b
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
Loading

0 comments on commit e154062

Please sign in to comment.