-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create test for vae encode & decoder
- Loading branch information
1 parent
3b8cb55
commit f77bf9b
Showing
6 changed files
with
60 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
"""Test VAE decoder""" | ||
|
||
import torch | ||
from src.configs import cfg | ||
from src.models.decoder import VAEDecoder | ||
|
||
|
||
VAEDecoder_OUTPUT = VAEDecoder()(torch.randn((1, 4, | ||
cfg.LATENTS_HEIGHT, | ||
cfg.LATENTS_WIDTH))) | ||
|
||
|
||
def test_type_decoder() -> None: | ||
"""Test the VAE decoder output type""" | ||
assert isinstance(VAEDecoder_OUTPUT, torch.Tensor), \ | ||
f"The model output type {type(VAEDecoder_OUTPUT)}!={torch.Tensor}" | ||
|
||
|
||
def test_shape_decoder() -> None: | ||
"""Test the VAE decoder output shape""" | ||
target_shape = (1, 3, cfg.HEIGHT, cfg.WIDTH) | ||
shape = VAEDecoder_OUTPUT.shape | ||
assert shape == target_shape, \ | ||
f"The model output shape is {shape}!={target_shape}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""Test VAE encoder""" | ||
|
||
import torch | ||
from src.configs import cfg | ||
from src.models.encoder import VAEEncoder | ||
|
||
VAEEncoder_OUTPUT = VAEEncoder()(torch.rand((1, 3, | ||
cfg.HEIGHT, | ||
cfg.WIDTH)), | ||
torch.randn((1, 4, | ||
cfg.LATENTS_HEIGHT, | ||
cfg.LATENTS_WIDTH))) | ||
|
||
|
||
def test_type_encoder() -> None: | ||
"""Test the VAE encoder output type""" | ||
assert isinstance(VAEEncoder_OUTPUT, torch.Tensor), \ | ||
f"The model output type {type(VAEEncoder_OUTPUT)}!={torch.Tensor}" | ||
|
||
|
||
def test_shape_encoder() -> None: | ||
"""Test the VAE encoder output shape""" | ||
target_shape = (1, 4, | ||
cfg.LATENTS_HEIGHT, | ||
cfg.LATENTS_WIDTH) | ||
shape = VAEEncoder_OUTPUT.shape | ||
assert shape == target_shape, \ | ||
f"The model output shape is {shape}!={target_shape}" |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,10 @@ | ||
"""Main pytest script""" | ||
|
||
from src.tests.test_vae import test_vae | ||
from src.tests.test_encoder import test_type_encoder, test_shape_encoder | ||
from src.tests.test_decoder import test_type_decoder, test_shape_decoder | ||
|
||
test_vae() | ||
test_type_encoder() | ||
test_shape_encoder() | ||
|
||
test_type_decoder() | ||
test_shape_decoder() |