Skip to content

Commit

Permalink
create test for vae encode & decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
eljandoubi committed Aug 26, 2024
1 parent 3b8cb55 commit f77bf9b
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ pylint:
pylint **/*.py

check:
make pylint
make pytest
make pylint

run:
echo "The code is not completed yet."
Expand Down
20 changes: 0 additions & 20 deletions src/models/vae.py

This file was deleted.

24 changes: 24 additions & 0 deletions src/tests/test_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Test VAE decoder"""

import torch
from src.configs import cfg
from src.models.decoder import VAEDecoder


VAEDecoder_OUTPUT = VAEDecoder()(torch.randn((1, 4,
cfg.LATENTS_HEIGHT,
cfg.LATENTS_WIDTH)))


def test_type_decoder() -> None:
"""Test the VAE decoder output type"""
assert isinstance(VAEDecoder_OUTPUT, torch.Tensor), \
f"The model output type {type(VAEDecoder_OUTPUT)}!={torch.Tensor}"


def test_shape_decoder() -> None:
"""Test the VAE decoder output shape"""
target_shape = (1, 3, cfg.HEIGHT, cfg.WIDTH)
shape = VAEDecoder_OUTPUT.shape
assert shape == target_shape, \
f"The model output shape is {shape}!={target_shape}"
28 changes: 28 additions & 0 deletions src/tests/test_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Test VAE encoder"""

import torch
from src.configs import cfg
from src.models.encoder import VAEEncoder

VAEEncoder_OUTPUT = VAEEncoder()(torch.rand((1, 3,
cfg.HEIGHT,
cfg.WIDTH)),
torch.randn((1, 4,
cfg.LATENTS_HEIGHT,
cfg.LATENTS_WIDTH)))


def test_type_encoder() -> None:
"""Test the VAE encoder output type"""
assert isinstance(VAEEncoder_OUTPUT, torch.Tensor), \
f"The model output type {type(VAEEncoder_OUTPUT)}!={torch.Tensor}"


def test_shape_encoder() -> None:
"""Test the VAE encoder output shape"""
target_shape = (1, 4,
cfg.LATENTS_HEIGHT,
cfg.LATENTS_WIDTH)
shape = VAEEncoder_OUTPUT.shape
assert shape == target_shape, \
f"The model output shape is {shape}!={target_shape}"
33 changes: 0 additions & 33 deletions src/tests/test_vae.py

This file was deleted.

9 changes: 7 additions & 2 deletions test_code.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Main pytest script"""

from src.tests.test_vae import test_vae
from src.tests.test_encoder import test_type_encoder, test_shape_encoder
from src.tests.test_decoder import test_type_decoder, test_shape_decoder

test_vae()
test_type_encoder()
test_shape_encoder()

test_type_decoder()
test_shape_decoder()

0 comments on commit f77bf9b

Please sign in to comment.