build VAE and its test
eljandoubi committed Aug 26, 2024
1 parent 8131a33 commit 2ea67cc
Showing 13 changed files with 549 additions and 5 deletions.
49 changes: 45 additions & 4 deletions .github/workflows/python-package-conda.yml
@@ -1,6 +1,12 @@
name: Python Package using Conda
name: CI Pipeline

on: [push]
on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
build-linux-conda:
@@ -21,6 +27,41 @@ jobs:
- name: Build
run: |
make build
- name: Clean
pylint-test:
runs-on: ubuntu-latest
strategy:
max-parallel: 5

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Check lint
run: |
make pylint
pytest:
runs-on: ubuntu-latest
strategy:
max-parallel: 5

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Test models
run: |
make clean
make pytest
10 changes: 10 additions & 0 deletions Makefile
@@ -13,5 +13,15 @@ build:
conda activate $(ENV_NAME) && \
pip install -r requirements.txt

pytest:
. $(CONDA_BASE)/etc/profile.d/conda.sh && \
conda activate $(ENV_NAME) && \
pytest test_models.py

pylint:
. $(CONDA_BASE)/etc/profile.d/conda.sh && \
conda activate $(ENV_NAME) && \
pylint **/*.py

clean:
conda remove --name $(ENV_NAME) --all -y
4 changes: 3 additions & 1 deletion requirements.txt
@@ -2,4 +2,6 @@ torch==2.4.0
torchvision==0.19.0
tqdm==4.66.4
pillow==10.4.0
numpy==1.26.4
numpy==1.26.4
pylint==3.2.6
pytest==8.3.2
Binary file added samples/EiffelTower.jpg
9 changes: 9 additions & 0 deletions src/configs/cfg.py
@@ -0,0 +1,9 @@
"""It holds some constants"""

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

ALLOW_CUDA = False
ALLOW_MPS = False
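
Note: a minimal, hypothetical usage sketch (not part of this commit) showing how these constants could drive device selection and the latent tensor shape; the device-selection logic below is an assumption for illustration, not code from the repository.

import torch
from src.configs.cfg import ALLOW_CUDA, ALLOW_MPS, LATENTS_HEIGHT, LATENTS_WIDTH

# Pick a device, falling back to CPU when the corresponding flag is off
# or the backend is unavailable (illustrative only).
device = "cpu"
if ALLOW_CUDA and torch.cuda.is_available():
    device = "cuda"
elif ALLOW_MPS and torch.backends.mps.is_available():
    device = "mps"

# One latent sample: 4 channels at 1/8 spatial resolution (64x64 for 512x512 images).
latents = torch.randn(1, 4, LATENTS_HEIGHT, LATENTS_WIDTH, device=device)
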
77 changes: 77 additions & 0 deletions src/models/attention.py
@@ -0,0 +1,77 @@
"""Attention mechanisms"""

import math
import torch
from torch import nn
from torch.nn import functional as F


class SelfAttention(nn.Module):
"""Self Attention"""

def __init__(self, n_heads: int,
d_embed: int,
in_proj_bias: bool = True,
out_proj_bias: bool = True
):
super().__init__()
# This combines the Wq, Wk and Wv matrices into one matrix
self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
# This one represents the Wo matrix
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads
self.d_head = d_embed // n_heads

def forward(self, x: torch.Tensor,
causal_mask: bool = False) -> torch.Tensor:
"""Foward method"""
# (Batch_Size, Seq_Len, Dim)
input_shape = x.shape

# (Batch_Size, Seq_Len, Dim)
batch_size, sequence_length, _ = input_shape

# (Batch_Size, Seq_Len, H, Dim / H)
inter_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3)
qkv: torch.Tensor = self.in_proj(x)

# -> 3 tensors of shape (Batch_Size, Seq_Len, Dim)
q, k, v = qkv.chunk(3, dim=-1)

# (Batch_Size, Seq_Len, Dim) ->
# (Batch_Size, Seq_Len, H, Dim / H) ->
# (Batch_Size, H, Seq_Len, Dim / H)
q = q.view(inter_shape).transpose(1, 2)
k = k.view(inter_shape).transpose(1, 2)
v = v.view(inter_shape).transpose(1, 2)

# (Batch_Size, H, Seq_Len, Dim / H) @
# (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight = q @ k.transpose(-1, -2) / math.sqrt(self.d_head)

if causal_mask:
# Mask where the upper triangle (above the principal diagonal) is 1
mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
# Fill the upper triangle with -inf
weight.masked_fill_(mask, -torch.inf)

# (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight = F.softmax(weight, dim=-1)

# (Batch_Size, H, Seq_Len, Seq_Len) @
# (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
output = weight @ v

# (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
output = output.transpose(1, 2)

# (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
output = output.reshape(input_shape)

# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
output = self.out_proj(output)

# (Batch_Size, Seq_Len, Dim)
return output
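
Note: a short usage sketch (not in this diff) for SelfAttention; the shapes are illustrative and the tensors are random.

import torch
from src.models.attention import SelfAttention

attn = SelfAttention(n_heads=8, d_embed=512)
x = torch.randn(2, 64, 512)         # (Batch_Size, Seq_Len, Dim)
out = attn(x, causal_mask=True)     # causal mask hides positions after the current one
assert out.shape == x.shape         # output keeps the (2, 64, 512) shape
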
107 changes: 107 additions & 0 deletions src/models/decoder.py
@@ -0,0 +1,107 @@
"""VAE Decoder Pytorch Module"""

import torch
from torch import nn
from src.models.vae_blocks import VAEResidualBlock, VAEAttentionBlock


class VAEDecoder(nn.Module):
"""VAE Decoder"""

def __init__(self):
super().__init__()
self.layers = nn.ModuleList(
[
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
nn.Conv2d(4, 4, kernel_size=1, padding=0),

# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
nn.Conv2d(4, 512, kernel_size=3, padding=1),

# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAEResidualBlock(512, 512),

# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAEAttentionBlock(512),

# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAEResidualBlock(512, 512),

# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAEResidualBlock(512, 512),

# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAEResidualBlock(512, 512),

# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAEResidualBlock(512, 512),

# Repeats the rows and columns of the data by scale_factor.
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
nn.Upsample(scale_factor=2),

# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
nn.Conv2d(512, 512, kernel_size=3, padding=1),

# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAEResidualBlock(512, 512),

# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAEResidualBlock(512, 512),

# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAEResidualBlock(512, 512),

# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
nn.Upsample(scale_factor=2),

# (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 512, Height / 2, Width / 2)
nn.Conv2d(512, 512, kernel_size=3, padding=1),

# (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAEResidualBlock(512, 256),

# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAEResidualBlock(256, 256),

# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAEResidualBlock(256, 256),

# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height, Width)
nn.Upsample(scale_factor=2),

# (Batch_Size, 256, Height, Width) -> (Batch_Size, 256, Height, Width)
nn.Conv2d(256, 256, kernel_size=3, padding=1),

# (Batch_Size, 256, Height, Width) -> (Batch_Size, 128, Height, Width)
VAEResidualBlock(256, 128),

# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
VAEResidualBlock(128, 128),

# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
VAEResidualBlock(128, 128),

# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
nn.GroupNorm(32, 128),

# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
nn.SiLU(),

# (Batch_Size, 128, Height, Width) -> (Batch_Size, 3, Height, Width)
nn.Conv2d(128, 3, kernel_size=3, padding=1),
]
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward method"""
# x: (Batch_Size, 4, Height / 8, Width / 8)
# Remove the scaling added by the Encoder.
x = x / 0.18215

for layer in self.layers:

x = layer(x)

# (Batch_Size, 3, Height, Width)
return x
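
Note: a short usage sketch (not in this diff) for VAEDecoder, assuming the src.models.vae_blocks module imported at the top of decoder.py is available; that file is not shown in this view.

import torch
from src.models.decoder import VAEDecoder

decoder = VAEDecoder()
latents = torch.randn(1, 4, 64, 64)      # (Batch_Size, 4, Height / 8, Width / 8)
with torch.no_grad():
    image = decoder(latents)             # three Upsample stages: 64 -> 128 -> 256 -> 512
assert image.shape == (1, 3, 512, 512)   # (Batch_Size, 3, Height, Width)
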