Simplify head of Regressor and update for Clay v1.5
srmsoumya authored and yellowcap committed Nov 26, 2024
1 parent d175a95 commit e8d9ca6
Showing 4 changed files with 110 additions and 142 deletions.
23 changes: 9 additions & 14 deletions configs/regression_biomasters.yaml
@@ -2,24 +2,18 @@
seed_everything: 42
data:
metadata_path: configs/metadata.yaml
batch_size: 10
batch_size: 25
num_workers: 8
train_chip_dir: data/biomasters/train_cube
train_label_dir: data/biomasters/train_agbm
val_chip_dir: data/biomasters/test_cube
val_label_dir: data/biomasters/test_agbm
model:
ckpt_path: checkpoints/clay-v1-base.ckpt
lr: 1e-3
ckpt_path: checkpoints/clay_v1.5.ckpt
lr: 1e-2
wd: 0.05
b1: 0.9
b2: 0.95
feature_maps:
- 2
- 5
- 7
- 9
- 11
trainer:
accelerator: auto
strategy: ddp
@@ -33,13 +27,14 @@ trainer:
num_sanity_val_steps: 0
# limit_train_batches: 0.25
# limit_val_batches: 0.25
accumulate_grad_batches: 4
accumulate_grad_batches: 1
logger:
- class_path: lightning.pytorch.loggers.WandbLogger
init_args:
entity: developmentseed
project: clay-regression
log_model: false
group: v1.5-test
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
@@ -55,9 +50,9 @@ trainer:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: step
- class_path: src.callbacks.LayerwiseFinetuning
init_args:
phase: 10
train_bn: True
# - class_path: src.callbacks.LayerwiseFinetuning
# init_args:
# phase: 10
# train_bn: True
plugins:
- class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO
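For reference, the model block above maps one-to-one onto the simplified constructor in biomasters_model.py. A minimal sketch (hypothetical usage; the import path is assumed from the file layout in this commit) of instantiating the fine-tuning module with the updated Clay v1.5 hyperparameters:

```python
# Hypothetical usage sketch: wire the updated config values into the LightningModule.
from finetune.regression.biomasters_model import BioMastersClassifier

model = BioMastersClassifier(
    ckpt_path="checkpoints/clay_v1.5.ckpt",  # Clay v1.5 weights from the config
    lr=1e-2,   # raised from 1e-3; the feature_maps argument is gone
    wd=0.05,
    b1=0.9,
    b2=0.95,
)
```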
56 changes: 38 additions & 18 deletions finetune/regression/biomasters_inference.ipynb

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions finetune/regression/biomasters_model.py
@@ -42,13 +42,11 @@ class BioMastersClassifier(L.LightningModule):
b2 (float): Beta2 parameter for the Adam optimizer.
"""

def __init__(self, ckpt_path, feature_maps, lr, wd, b1, b2): # noqa: PLR0913
def __init__(self, ckpt_path, lr, wd, b1, b2): # noqa: PLR0913
super().__init__()
self.save_hyperparameters()
# self.model = Classifier(num_classes=1, ckpt_path=ckpt_path)
self.model = Regressor(
num_classes=1, feature_maps=feature_maps, ckpt_path=ckpt_path
)
self.model = Regressor(num_classes=1, ckpt_path=ckpt_path)
self.loss_fn = NoNaNRMSE()
self.score_fn = MeanSquaredError()

@@ -110,7 +108,7 @@ def configure_optimizers(self):
weight_decay=self.hparams.wd,
betas=(self.hparams.b1, self.hparams.b2),
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
return {
"optimizer": optimizer,
"lr_scheduler": {
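The scheduler change above shortens the decay interval from 8 to 5 epochs. A small self-contained illustration (toy parameters, not repo code) of what the new StepLR schedule does to the 1e-2 base learning rate:

```python
import torch
from torch import optim

# Toy parameter; optimizer/scheduler settings mirror this diff.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.AdamW(params, lr=1e-2, weight_decay=0.05, betas=(0.9, 0.95))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

for epoch in range(15):
    print(epoch, optimizer.param_groups[0]["lr"])  # 1e-2 for 0-4, 5e-3 for 5-9, 2.5e-3 for 10-14
    optimizer.step()
    scheduler.step()
```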
165 changes: 60 additions & 105 deletions finetune/regression/factory.py
@@ -1,10 +1,8 @@
"""
Clay Segmentor for semantic segmentation tasks.
Clay Regressor for semantic regression tasks using PixelShuffle.
Attribution:
Decoder from Segformer: Simple and Efficient Design for Semantic Segmentation
with Transformers
Paper URL: https://arxiv.org/abs/2105.15203
Decoder inspired by PixelShuffle-based upsampling.
"""

import re
@@ -17,18 +15,15 @@
from src.model import Encoder


class SegmentEncoder(Encoder):
class RegressionEncoder(Encoder):
"""
Encoder class for segmentation tasks, incorporating a feature pyramid
network (FPN).
Encoder class for regression tasks.
Attributes:
feature_maps (list): Indices of layers to be used for generating
feature maps.
ckpt_path (str): Path to the clay checkpoint file.
"""

def __init__( # noqa: PLR0913
def __init__(
self,
mask_ratio,
patch_size,
@@ -38,7 +33,6 @@ def __init__(
heads,
dim_head,
mlp_ratio,
feature_maps,
ckpt_path=None,
):
super().__init__(
@@ -51,30 +45,6 @@ def __init__(
dim_head,
mlp_ratio,
)
self.feature_maps = feature_maps

# Define Feature Pyramid Network (FPN) layers
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
nn.BatchNorm2d(dim),
nn.GELU(),
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
)

self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
)

self.fpn3 = nn.Identity()

self.fpn4 = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
)

self.fpn5 = nn.Sequential(
nn.MaxPool2d(kernel_size=4, stride=4),
)

# Set device
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -119,14 +89,14 @@ def load_from_ckpt(self, ckpt_path):

def forward(self, datacube):
"""
Forward pass of the SegmentEncoder.
Forward pass of the RegressionEncoder.
Args:
datacube (dict): A dictionary containing the input datacube and
meta information like time, latlon, gsd & wavelengths.
Returns:
list: A list of feature maps extracted from the datacube.
torch.Tensor: The embeddings from the final layer.
"""
cube, time, latlon, gsd, waves = (
datacube["pixels"], # [B C H W]
@@ -146,84 +116,56 @@ def forward(self, datacube):
cls_tokens = repeat(self.cls_token, "1 1 D -> B 1 D", B=B) # [B 1 D]
patches = torch.cat((cls_tokens, patches), dim=1) # [B (1 + L) D]

features = []
for idx, (attn, ff) in enumerate(self.transformer.layers):
patches = attn(patches) + patches
patches = ff(patches) + patches
if idx in self.feature_maps:
_cube = rearrange(
patches[:, 1:, :], "B (H W) D -> B D H W", H=H // 8, W=W // 8
)
features.append(_cube)
# patches = self.transformer.norm(patches)
# _cube = rearrange(patches[:, 1:, :], "B (H W) D -> B D H W", H=H//8, W=W//8)
# features.append(_cube)

# Apply FPN layers
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4, self.fpn5]
for i in range(len(features)):
features[i] = ops[i](features[i])

return features


class FusionBlock(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1)
self.bn = nn.BatchNorm2d(output_dim)

def forward(self, x):
x = F.relu(self.bn(self.conv(x)))
return x

# Transformer encoder
patches = self.transformer(patches)

class SegmentationHead(nn.Module):
def __init__(self, input_dim, num_classes):
super().__init__()
self.conv1 = nn.Conv2d(input_dim, input_dim // 2, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(
input_dim // 2, num_classes, kernel_size=1
) # final conv to num_classes
self.bn1 = nn.BatchNorm2d(input_dim // 2)
# Remove class token
patches = patches[:, 1:, :] # [B, L, D]

def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = self.conv2(x) # No activation before final layer
return x
return patches


class Regressor(nn.Module):
"""
Clay Regressor class that combines the Encoder with FPN layers for semantic
regression.
Clay Regressor class that combines the Encoder with PixelShuffle for regression.
Attributes:
num_classes (int): Number of output classes for segmentation.
feature_maps (list): Indices of layers to be used for generating feature maps.
num_classes (int): Number of output classes for regression.
ckpt_path (str): Path to the checkpoint file.
"""

def __init__(self, num_classes, feature_maps, ckpt_path):
def __init__(self, num_classes, ckpt_path):
super().__init__()
# Default values are for the clay mae base model.
self.encoder = SegmentEncoder(
# Initialize the encoder
self.encoder = RegressionEncoder(
mask_ratio=0.0,
patch_size=8,
shuffle=False,
dim=768,
depth=12,
heads=12,
dim=1024,
depth=24,
heads=16,
dim_head=64,
mlp_ratio=4.0,
feature_maps=feature_maps,
ckpt_path=ckpt_path,
)
self.upsamples = [nn.Upsample(scale_factor=2**i) for i in range(5)]
self.fusion = FusionBlock(self.encoder.dim, self.encoder.dim // 4)
self.seg_head = nn.Conv2d(
self.encoder.dim // 4, num_classes, kernel_size=3, padding=1
)

# Freeze the encoder parameters
for param in self.encoder.parameters():
param.requires_grad = False

# Define layers after the encoder
D = self.encoder.dim # embedding dimension
hidden_dim = 512
C_out = 64
r = self.encoder.patch_size # upscale factor (patch_size)

self.conv1 = nn.Conv2d(D, hidden_dim, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(hidden_dim)
self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(hidden_dim)
self.conv_ps = nn.Conv2d(hidden_dim, C_out * r * r, kernel_size=3, padding=1)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=r)
self.conv_out = nn.Conv2d(C_out, num_classes, kernel_size=3, padding=1)

def forward(self, datacube):
"""
@@ -234,15 +176,28 @@ def forward(self, datacube):
meta information like time, latlon, gsd & wavelengths.
Returns:
torch.Tensor: The segmentation logits.
torch.Tensor: The regression output.
"""
features = self.encoder(datacube)
for i in range(len(features)):
features[i] = self.upsamples[i](features[i])
cube = datacube["pixels"] # [B C H_in W_in]
B, C, H_in, W_in = cube.shape

# fused = torch.cat(features, dim=1)
fused = torch.sum(torch.stack(features), dim=0)
fused = self.fusion(fused)
# Get embeddings from the encoder
patches = self.encoder(datacube) # [B, L, D]

logits = self.seg_head(fused)
return logits
# Reshape embeddings to [B, D, H', W']
H_patches = H_in // self.encoder.patch_size
W_patches = W_in // self.encoder.patch_size
x = rearrange(patches, "B (H W) D -> B D H W", H=H_patches, W=W_patches)

# Pass through convolutional layers
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = self.conv_ps(x) # [B, C_out * r^2, H', W']

# Upsample using PixelShuffle
x = self.pixel_shuffle(x) # [B, C_out, H_in, W_in]

# Final convolution to get desired output channels
x = self.conv_out(x) # [B, num_outputs, H_in, W_in]

return x
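
To make the new head easier to follow outside the diff, here is a standalone shape walk-through (illustrative sizes only; dim=1024 and patch_size=8 follow the Clay v1.5 defaults used above) of how PixelShuffle takes the [B, L, D] patch embeddings back to full resolution:

```python
import torch
from torch import nn
from einops import rearrange

B, H_in, W_in, D, r = 2, 256, 256, 1024, 8      # r = patch size, D = embedding dim
num_patches = (H_in // r) * (W_in // r)
patches = torch.randn(B, num_patches, D)        # stand-in for the encoder output [B, L, D]

# Tokens -> feature map, two refining convs, PixelShuffle upsample, output conv.
x = rearrange(patches, "B (H W) D -> B D H W", H=H_in // r, W=W_in // r)
head = nn.Sequential(
    nn.Conv2d(D, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
    nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
    nn.Conv2d(512, 64 * r * r, kernel_size=3, padding=1),  # C_out * r^2 channels
    nn.PixelShuffle(upscale_factor=r),                      # -> [B, 64, H_in, W_in]
    nn.Conv2d(64, 1, kernel_size=3, padding=1),             # num_classes = 1
)
print(head(x).shape)  # torch.Size([2, 1, 256, 256])
```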
