From 3cecfb8b017cff812510a91cf75bd01b9e817ffa Mon Sep 17 00:00:00 2001
From: Arjan Gijsberts
Date: Tue, 18 Jun 2024 23:34:01 +0200
Subject: [PATCH 1/6] Simplified class initializers to no longer use an
 implicit config element passed via the keyword arguments. This takes into
 account recent changes to huggingface-hub mixins for serialization.
---
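As an illustration of the pattern this patch removes versus the explicit form it moves to, a minimal sketch follows (ToyModel and its argument are placeholders, not the real DGMR signature):

    import torch
    from huggingface_hub import PyTorchModelHubMixin

    # Old pattern (removed below): __init__ rebuilt a config dict from
    # locals() and then re-read its own arguments back out of that dict.
    class ToyModelOld(torch.nn.Module):
        def __init__(self, input_channels: int = 1, **kwargs):
            super().__init__()
            config = locals()
            config.pop("__class__")
            config.pop("self")
            self.config = kwargs.get("config", config)
            input_channels = self.config["input_channels"]

    # New pattern: plain keyword arguments. Recent huggingface_hub releases
    # capture the __init__ kwargs of a PyTorchModelHubMixin subclass and
    # serialize them to config.json, so no hand-rolled config dict is needed.
    class ToyModel(torch.nn.Module, PyTorchModelHubMixin):
        def __init__(self, input_channels: int = 1):
            super().__init__()
            self.input_channels = input_channels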
 dgmr/common.py         | 21 +----------
 dgmr/dgmr.py           | 21 +----------
 dgmr/discriminators.py | 33 +++----------------
 dgmr/generators.py     | 13 ++-----
 dgmr/hub.py            |  6 +--
 requirements.txt       |  3 +-
 tests/test_model.py    | 83 ++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 101 insertions(+), 79 deletions(-)

diff --git a/dgmr/common.py b/dgmr/common.py
index 2cf5e70..e2b1d23 100644
--- a/dgmr/common.py
+++ b/dgmr/common.py
@@ -290,8 +290,7 @@ def __init__(
         input_channels: int = 1,
         output_channels: int = 768,
         num_context_steps: int = 4,
-        conv_type: str = "standard",
-        **kwargs
+        conv_type: str = "standard"
     ):
         """
         Conditioning Stack using the context images from Skillful Nowcasting, , see https://arxiv.org/pdf/2104.00954.pdf
@@ -302,14 +301,6 @@ def __init__(
             conv_type: Type of 2D convolution to use, see satflow/models/utils.py for options
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        output_channels = self.config["output_channels"]
-        num_context_steps = self.config["num_context_steps"]
-        conv_type = self.config["conv_type"]
 
         conv2d = get_conv_layer(conv_type)
         self.space2depth = PixelUnshuffle(downscale_factor=2)
@@ -416,8 +407,7 @@ def __init__(
         self,
         shape: (int, int, int) = (8, 8, 8),
         output_channels: int = 768,
-        use_attention: bool = True,
-        **kwargs
+        use_attention: bool = True
     ):
         """
         Latent conditioning stack from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -428,13 +418,6 @@ def __init__(
             use_attention: Whether to have a self-attention block or not
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        shape = self.config["shape"]
-        output_channels = self.config["output_channels"]
-        use_attention = self.config["use_attention"]
 
         self.shape = shape
         self.use_attention = use_attention

diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py
index d1a71af..c3dab8b 100644
--- a/dgmr/dgmr.py
+++ b/dgmr/dgmr.py
@@ -6,6 +6,7 @@
 from dgmr.discriminators import Discriminator
 from dgmr.generators import Generator, Sampler
 from dgmr.hub import NowcastingModelHubMixin
+
 from dgmr.losses import (
     GridCellLoss,
     NowcastingLoss,
@@ -33,8 +34,7 @@ def __init__(
         beta2: float = 0.999,
         latent_channels: int = 768,
         context_channels: int = 384,
-        generation_steps: int = 6,
-        **kwargs,
+        generation_steps: int = 6
     ):
         """
         Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954
@@ -59,23 +59,6 @@ def __init__(
             pretrained:
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        forecast_steps = self.config["forecast_steps"]
-        output_shape = self.config["output_shape"]
-        gen_lr = self.config["gen_lr"]
-        disc_lr = self.config["disc_lr"]
-        conv_type = self.config["conv_type"]
-        num_samples = self.config["num_samples"]
-        grid_lambda = self.config["grid_lambda"]
-        beta1 = self.config["beta1"]
-        beta2 = self.config["beta2"]
-        latent_channels = self.config["latent_channels"]
-        context_channels = self.config["context_channels"]
-        visualize = self.config["visualize"]
         self.gen_lr = gen_lr
         self.disc_lr = disc_lr
         self.beta1 = beta1

diff --git a/dgmr/discriminators.py b/dgmr/discriminators.py
index a2711e3..20444c6 100644
--- a/dgmr/discriminators.py
+++ b/dgmr/discriminators.py
@@ -12,17 +12,9 @@ def __init__(
         self,
         input_channels: int = 12,
         num_spatial_frames: int = 8,
-        conv_type: str = "standard",
-        **kwargs
+        conv_type: str = "standard"
     ):
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        num_spatial_frames = self.config["num_spatial_frames"]
-        conv_type = self.config["conv_type"]
 
         self.spatial_discriminator = SpatialDiscriminator(
             input_channels=input_channels, num_timesteps=num_spatial_frames, conv_type=conv_type
@@ -40,7 +32,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class TemporalDiscriminator(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(
-        self, input_channels: int = 12, num_layers: int = 3, conv_type: str = "standard", **kwargs
+        self,
+        input_channels: int = 12,
+        num_layers: int = 3,
+        conv_type: str = "standard"
     ):
         """
         Temporal Discriminator from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -52,13 +47,6 @@ def __init__(
             conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        num_layers = self.config["num_layers"]
-        conv_type = self.config["conv_type"]
 
         self.downsample = torch.nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
         self.space2depth = PixelUnshuffle(downscale_factor=2)
@@ -138,8 +126,7 @@ def __init__(
         input_channels: int = 12,
         num_timesteps: int = 8,
         num_layers: int = 4,
-        conv_type: str = "standard",
-        **kwargs
+        conv_type: str = "standard"
     ):
         """
         Spatial discriminator from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -151,14 +138,6 @@ def __init__(
             conv_type: Type of 2d convolutions to use, see satflow/models/utils.py for options
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        input_channels = self.config["input_channels"]
-        num_timesteps = self.config["num_timesteps"]
-        num_layers = self.config["num_layers"]
-        conv_type = self.config["conv_type"]
         # Randomly, uniformly, select 8 timesteps to do this on from the input
         self.num_timesteps = num_timesteps
         # First step is mean pooling 2x2 to reduce input by half

diff --git a/dgmr/generators.py b/dgmr/generators.py
index c4b744e..d78746c 100644
--- a/dgmr/generators.py
+++ b/dgmr/generators.py
@@ -21,8 +21,7 @@ def __init__(
         forecast_steps: int = 18,
         latent_channels: int = 768,
         context_channels: int = 384,
-        output_channels: int = 1,
-        **kwargs
+        output_channels: int = 1
     ):
         """
         Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -35,14 +34,8 @@ def __init__(
             latent_channels: Number of input channels to the lowest ConvGRU layer
         """
         super().__init__()
-        config = locals()
-        config.pop("__class__")
-        config.pop("self")
-        self.config = kwargs.get("config", config)
-        self.forecast_steps = self.config["forecast_steps"]
-        latent_channels = self.config["latent_channels"]
-        context_channels = self.config["context_channels"]
-        output_channels = self.config["output_channels"]
+
+        self.forecast_steps = forecast_steps
 
         self.convGRU1 = ConvGRU(
             input_channels=latent_channels + context_channels,

diff --git a/dgmr/hub.py b/dgmr/hub.py
index 821fbf2..871b31e 100644
--- a/dgmr/hub.py
+++ b/dgmr/hub.py
@@ -129,7 +129,7 @@ def _from_pretrained(
         proxies,
         resume_download,
         local_files_only,
-        use_auth_token=False,
+        token=False,
         map_location="cpu",
         strict=False,
         **model_kwargs,
@@ -148,10 +148,10 @@ def _from_pretrained(
             force_download=force_download,
             proxies=proxies,
             resume_download=resume_download,
-            token=use_auth_token,
+            token=token,
             local_files_only=local_files_only,
         )
-        model = cls(**model_kwargs["config"])
+        model = cls(**model_kwargs)
 
         state_dict = torch.load(model_file, map_location=map_location)
         model.load_state_dict(state_dict, strict=strict)
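The hub.py hunk above also flattens how a saved config reaches the constructor: model = cls(**model_kwargs) expands the loaded values directly into the now-explicit signatures, where the old code nested them under a "config" key. A rough sketch of the resulting call (saved_config stands in for a loaded config.json):

    from dgmr.discriminators import Discriminator

    saved_config = {"input_channels": 1, "num_spatial_frames": 8, "conv_type": "standard"}

    # Before: model = cls(**model_kwargs["config"]), with each __init__
    # re-reading its arguments out of the nested dict.
    # After: the flat config maps one-to-one onto keyword arguments.
    model = Discriminator(**saved_config)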
diff --git a/requirements.txt b/requirements.txt
index bcf1180..fa713f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ numpy
 torchvision>=0.11.0
 pytorch_lightning
 einops
-huggingface_hub==0.21.4
+huggingface_hub>=0.23.3
+safetensors

diff --git a/tests/test_model.py b/tests/test_model.py
index 34b8e3e..6cf9cbb 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -16,6 +16,14 @@
 import einops
 import pytest
 from pytorch_lightning import Trainer
+from torch.testing import assert_close
+
+
+def assert_model_equal(actual, expected):
+    assert(actual.state_dict().keys() == expected.state_dict().keys())
+
+    for x, y in zip(actual.state_dict().values(), expected.state_dict().values()):
+        assert_close(x, y)
 
 
 def test_dblock():
@@ -328,3 +336,78 @@ def __getitem__(self, idx):
 
     model = DGMR(forecast_steps=forecast_steps)
     trainer.fit(model, train_loader, val_loader)
+
+
+def test_model_serialization(tmp_path):
+    model = DGMR(
+        forecast_steps=1,
+        input_channels=1,
+        output_shape=128,
+        gen_lr=1e-5,
+        disc_lr=1e-4,
+        visualize=True,
+        conv_type="standard",
+        num_samples=1,
+        grid_lambda=16.0,
+        beta1=1.0,
+        beta2=0.995,
+        latent_channels=512,
+        context_channels=256,
+        generation_steps=1
+    )
+
+    model.save_pretrained(tmp_path / "dgmr")
+    modelcopy = DGMR.from_pretrained(tmp_path / "dgmr")
+    assert(model.hparams == modelcopy.hparams)
+    assert_model_equal(model, modelcopy)
+
+
+def test_discriminator_serialization(tmp_path):
+    discriminator = Discriminator(
+        input_channels=1,
+        num_spatial_frames=1,
+        conv_type="standard"
+    )
+
+    discriminator.save_pretrained(tmp_path / "discriminator")
+    discriminatorcopy = Discriminator.from_pretrained(tmp_path / "discriminator")
+    assert_model_equal(discriminator, discriminatorcopy)
+
+
+def test_sampler_serialization(tmp_path):
+    sampler = Sampler(
+        forecast_steps=1,
+        latent_channels=256,
+        context_channels=256,
+        output_channels=1
+    )
+
+    sampler.save_pretrained(tmp_path / "sampler")
+    samplercopy = Sampler.from_pretrained(tmp_path / "sampler")
+    assert_model_equal(sampler, samplercopy)
+
+
+def test_context_conditioning_stack_serialization(tmp_path):
+    ctz = ContextConditioningStack(
+        input_channels=2,
+        output_channels=256,
+        num_context_steps=1,
+        conv_type="standard"
+    )
+
+    ctz.save_pretrained(tmp_path / "context-conditioning-stack")
+    ctzcopy = ContextConditioningStack.from_pretrained(tmp_path / "context-conditioning-stack")
+    assert_model_equal(ctz, ctzcopy)
+
+
+def test_latent_conditioning_stack_serialization(tmp_path):
+    lat = LatentConditioningStack(
+        shape=(4, 4, 4),
+        output_channels=256,
+        use_attention=True
+    )
+
+    lat.save_pretrained(tmp_path / "latent-conditioning-stack")
+    latcopy = LatentConditioningStack.from_pretrained(tmp_path / "latent-conditioning-stack")
+    assert_model_equal(lat, latcopy)
+
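The requirements bump tracks the serialization change: newer huggingface_hub releases write mixin weights via safetensors, which is the reason for the added dependency. A sketch of the round trip the new tests exercise, assuming huggingface_hub's current on-disk layout:

    from dgmr.common import LatentConditioningStack

    model = LatentConditioningStack(shape=(4, 4, 4), output_channels=256)
    model.save_pretrained("checkpoint")
    # Expected contents of ./checkpoint:
    #   config.json       - the JSON-serialized __init__ kwargs
    #   model.safetensors - the weights, written via safetensors
    restored = LatentConditioningStack.from_pretrained("checkpoint")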
"latent-conditioning-stack") + assert_model_equal(lat, latcopy) + From b57aa074e56f88ea2c9b7e5294ca883ea43fd91c Mon Sep 17 00:00:00 2001 From: Arjan Gijsberts Date: Wed, 31 Jul 2024 11:16:43 +0200 Subject: [PATCH 2/6] Renamed some variable names in snake_case for style consistency. --- tests/test_model.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 6cf9cbb..18198a1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -357,9 +357,9 @@ def test_model_serialization(tmp_path): ) model.save_pretrained(tmp_path / "dgmr") - modelcopy = DGMR.from_pretrained(tmp_path / "dgmr") - assert(model.hparams == modelcopy.hparams) - assert_model_equal(model, modelcopy) + model_copy = DGMR.from_pretrained(tmp_path / "dgmr") + assert(model.hparams == model_copy.hparams) + assert_model_equal(model, model_copy) def test_discriminator_serialization(tmp_path): @@ -370,8 +370,8 @@ def test_discriminator_serialization(tmp_path): ) discriminator.save_pretrained(tmp_path / "discriminator") - discriminatorcopy = Discriminator.from_pretrained(tmp_path / "discriminator") - assert_model_equal(discriminator, discriminatorcopy) + discriminator_copy = Discriminator.from_pretrained(tmp_path / "discriminator") + assert_model_equal(discriminator, discriminator_copy) def test_sampler_serialization(tmp_path): @@ -383,8 +383,8 @@ def test_sampler_serialization(tmp_path): ) sampler.save_pretrained(tmp_path / "sampler") - samplercopy = Sampler.from_pretrained(tmp_path / "sampler") - assert_model_equal(sampler, samplercopy) + sampler_copy = Sampler.from_pretrained(tmp_path / "sampler") + assert_model_equal(sampler, sampler_copy) def test_context_conditioning_stack_serialization(tmp_path): @@ -396,8 +396,8 @@ def test_context_conditioning_stack_serialization(tmp_path): ) ctz.save_pretrained(tmp_path / "context-conditioning-stack") - ctzcopy = ContextConditioningStack.from_pretrained(tmp_path / "context-conditioning-stack") - assert_model_equal(ctz, ctzcopy) + ctz_copy = ContextConditioningStack.from_pretrained(tmp_path / "context-conditioning-stack") + assert_model_equal(ctz, ctz_copy) def test_latent_conditioning_stack_serialization(tmp_path): @@ -408,6 +408,6 @@ def test_latent_conditioning_stack_serialization(tmp_path): ) lat.save_pretrained(tmp_path / "latent-conditioning-stack") - latcopy = LatentConditioningStack.from_pretrained(tmp_path / "latent-conditioning-stack") - assert_model_equal(lat, latcopy) + lat_copy = LatentConditioningStack.from_pretrained(tmp_path / "latent-conditioning-stack") + assert_model_equal(lat, lat_copy) From 094c6fa5e983b7cffb21704d65926357441d2169 Mon Sep 17 00:00:00 2001 From: Arjan Gijsberts Date: Wed, 31 Jul 2024 14:36:50 +0200 Subject: [PATCH 3/6] Migrated from custom NowcastingModelHubMixin to huggingface_hub.PyTorchModelHubMixin. 
 dgmr/dgmr.py |  10 +++-
 dgmr/hub.py  | 160 ----------------------------------------------------------
 2 files changed, 8 insertions(+), 162 deletions(-)
 delete mode 100644 dgmr/hub.py

diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py
index c3dab8b..1ffd4c8 100644
--- a/dgmr/dgmr.py
+++ b/dgmr/dgmr.py
@@ -1,11 +1,11 @@
 import pytorch_lightning as pl
 import torch
 import torchvision
+from huggingface_hub import PyTorchModelHubMixin
 
 from dgmr.common import ContextConditioningStack, LatentConditioningStack
 from dgmr.discriminators import Discriminator
 from dgmr.generators import Generator, Sampler
-from dgmr.hub import NowcastingModelHubMixin
 
 from dgmr.losses import (
     GridCellLoss,
@@ -16,7 +16,13 @@
 )
 
 
-class DGMR(pl.LightningModule, NowcastingModelHubMixin):
+class DGMR(
+    pl.LightningModule,
+    PyTorchModelHubMixin,
+    library_name="DGMR",
+    tags=["nowcasting", "forecasting", "timeseries", "remote-sensing", "gan"],
+    repo_url="https://github.com/openclimatefix/skillful_nowcasting"
+):
     """Deep Generative Model of Radar"""
 
     def __init__(

diff --git a/dgmr/hub.py b/dgmr/hub.py
deleted file mode 100644
index 871b31e..0000000
--- a/dgmr/hub.py
+++ /dev/null
@@ -1,160 +0,0 @@
-"""
-Originally Taken from https://github.com/rwightman/
-
-https://github.com/rwightman/pytorch-image-models/
-blob/acd6c687fd1c0507128f0ce091829b233c8560b9/timm/models/hub.py
-"""
-
-import json
-import logging
-import os
-from functools import partial
-
-import torch
-
-try:
-    from huggingface_hub import cached_download, hf_hub_url
-
-    cached_download = partial(cached_download, library_name="dgmr")
-except ImportError:
-    hf_hub_url = None
-    cached_download = None
-
-from huggingface_hub import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, ModelHubMixin, hf_hub_download
-
-MODEL_CARD_MARKDOWN = """---
-license: mit
-tags:
-- nowcasting
-- forecasting
-- timeseries
-- remote-sensing
-- gan
----
-
-# {model_name}
-
-## Model description
-
-[More information needed]
-
-## Intended uses & limitations
-
-[More information needed]
-
-## How to use
-
-[More information needed]
-
-## Limitations and bias
-
-[More information needed]
-
-## Training data
-
-[More information needed]
-
-## Training procedure
-
-[More information needed]
-
-## Evaluation results
-
-[More information needed]
-
-"""
-
-_logger = logging.getLogger(__name__)
-
-
-class NowcastingModelHubMixin(ModelHubMixin):
-    """
-    HuggingFace ModelHubMixin containing specific adaptions for Nowcasting models
-    """
-
-    def __init__(self, *args, **kwargs):
-        """
-        Mixin for pl.LightningModule and Hugging Face
-
-        Mix this class with your pl.LightningModule class to easily push / download
-        the model via the Hugging Face Hub
-
-        Example::
-
-            >>> from dgmr.hub import NowcastingModelHubMixin
-
-            >>> class MyModel(nn.Module, NowcastingModelHubMixin):
-            ...    def __init__(self, **kwargs):
-            ...        super().__init__()
-            ...        self.layer = ...
-            ...    def forward(self, ...)
-            ...        return ...
-
-            >>> model = MyModel()
-            >>> model.push_to_hub("mymodel") # Pushing model-weights to hf-hub
-
-            >>> # Downloading weights from hf-hub & model will be initialized from those weights
-            >>> model = MyModel.from_pretrained("username/mymodel")
-        """
-
-    def _create_model_card(self, path):
-        model_card = MODEL_CARD_MARKDOWN.format(model_name=type(self).__name__)
-        with open(os.path.join(path, "README.md"), "w") as f:
-            f.write(model_card)
-
-    def _save_config(self, module, save_directory):
-        config = dict(module.hparams)
-        path = os.path.join(save_directory, CONFIG_NAME)
-        with open(path, "w") as f:
-            json.dump(config, f)
-
-    def _save_pretrained(self, save_directory: str, save_config: bool = True):
-        # Save model weights
-        path = os.path.join(save_directory, PYTORCH_WEIGHTS_NAME)
-        model_to_save = self.module if hasattr(self, "module") else self
-        torch.save(model_to_save.state_dict(), path)
-        # Save model config
-        if save_config and model_to_save.hparams:
-            self._save_config(model_to_save, save_directory)
-        # Save model card
-        self._create_model_card(save_directory)
-
-    @classmethod
-    def _from_pretrained(
-        cls,
-        model_id,
-        revision,
-        cache_dir,
-        force_download,
-        proxies,
-        resume_download,
-        local_files_only,
-        token=False,
-        map_location="cpu",
-        strict=False,
-        **model_kwargs,
-    ):
-        map_location = torch.device(map_location)
-
-        if os.path.isdir(model_id):
-            print("Loading weights from local directory")
-            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
-        else:
-            model_file = hf_hub_download(
-                repo_id=model_id,
-                filename=PYTORCH_WEIGHTS_NAME,
-                revision=revision,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                token=token,
-                local_files_only=local_files_only,
-            )
-        model = cls(**model_kwargs)
-
-        state_dict = torch.load(model_file, map_location=map_location)
-        model.load_state_dict(state_dict, strict=strict)
-        model.eval()
-
-        return model

From c5715403c894adc78cbc8370c12d9790cc6cbeb4 Mon Sep 17 00:00:00 2001
From: Arjan Gijsberts
Date: Wed, 31 Jul 2024 14:43:15 +0200
Subject: [PATCH 4/6] Removed spurious empty line in imports.
---
 dgmr/dgmr.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py
index 1ffd4c8..74dc373 100644
--- a/dgmr/dgmr.py
+++ b/dgmr/dgmr.py
@@ -6,7 +6,6 @@
 from dgmr.common import ContextConditioningStack, LatentConditioningStack
 from dgmr.discriminators import Discriminator
 from dgmr.generators import Generator, Sampler
-
 from dgmr.losses import (
     GridCellLoss,
     NowcastingLoss,

From 6dffc2675080ba1da30e90bb31108e889ef5c683 Mon Sep 17 00:00:00 2001
From: Arjan Gijsberts
Date: Wed, 31 Jul 2024 16:06:17 +0200
Subject: [PATCH 5/6] Minor fixes to coding style.
---
 dgmr/common.py         | 4 ++--
 dgmr/dgmr.py           | 4 ++--
 dgmr/discriminators.py | 6 +++---
 dgmr/generators.py     | 2 +-
 tests/test_model.py    | 4 ++--
 5 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/dgmr/common.py b/dgmr/common.py
index e2b1d23..13a66f9 100644
--- a/dgmr/common.py
+++ b/dgmr/common.py
@@ -290,7 +290,7 @@ def __init__(
         input_channels: int = 1,
         output_channels: int = 768,
         num_context_steps: int = 4,
-        conv_type: str = "standard"
+        conv_type: str = "standard",
     ):
         """
         Conditioning Stack using the context images from Skillful Nowcasting, , see https://arxiv.org/pdf/2104.00954.pdf
@@ -407,7 +407,7 @@ def __init__(
         self,
         shape: (int, int, int) = (8, 8, 8),
         output_channels: int = 768,
-        use_attention: bool = True
+        use_attention: bool = True,
     ):
         """
         Latent conditioning stack from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf

diff --git a/dgmr/dgmr.py b/dgmr/dgmr.py
index 74dc373..def45e3 100644
--- a/dgmr/dgmr.py
+++ b/dgmr/dgmr.py
@@ -20,7 +20,7 @@ class DGMR(
     PyTorchModelHubMixin,
     library_name="DGMR",
     tags=["nowcasting", "forecasting", "timeseries", "remote-sensing", "gan"],
-    repo_url="https://github.com/openclimatefix/skillful_nowcasting"
+    repo_url="https://github.com/openclimatefix/skillful_nowcasting",
 ):
     """Deep Generative Model of Radar"""
 
@@ -39,7 +39,7 @@ def __init__(
         beta2: float = 0.999,
         latent_channels: int = 768,
         context_channels: int = 384,
-        generation_steps: int = 6
+        generation_steps: int = 6,
     ):
         """
         Nowcasting GAN is an attempt to recreate DeepMind's Skillful Nowcasting GAN from https://arxiv.org/abs/2104.00954

diff --git a/dgmr/discriminators.py b/dgmr/discriminators.py
index 20444c6..cb18b3e 100644
--- a/dgmr/discriminators.py
+++ b/dgmr/discriminators.py
@@ -12,7 +12,7 @@ def __init__(
         self,
         input_channels: int = 12,
         num_spatial_frames: int = 8,
-        conv_type: str = "standard"
+        conv_type: str = "standard",
     ):
         super().__init__()
 
@@ -35,7 +35,7 @@ def __init__(
         self,
         input_channels: int = 12,
         num_layers: int = 3,
-        conv_type: str = "standard"
+        conv_type: str = "standard",
     ):
         """
         Temporal Discriminator from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
@@ -126,7 +126,7 @@ def __init__(
         input_channels: int = 12,
         num_timesteps: int = 8,
         num_layers: int = 4,
-        conv_type: str = "standard"
+        conv_type: str = "standard",
     ):
         """
         Spatial discriminator from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf

diff --git a/dgmr/generators.py b/dgmr/generators.py
index d78746c..cc92a12 100644
--- a/dgmr/generators.py
+++ b/dgmr/generators.py
@@ -21,7 +21,7 @@ def __init__(
         forecast_steps: int = 18,
         latent_channels: int = 768,
         context_channels: int = 384,
-        output_channels: int = 1
+        output_channels: int = 1,
     ):
         """
         Sampler from the Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf

diff --git a/tests/test_model.py b/tests/test_model.py
index 18198a1..f1d7e2e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -20,7 +20,7 @@
 
 
 def assert_model_equal(actual, expected):
-    assert(actual.state_dict().keys() == expected.state_dict().keys())
+    assert actual.state_dict().keys() == expected.state_dict().keys()
 
     for x, y in zip(actual.state_dict().values(), expected.state_dict().values()):
         assert_close(x, y)
@@ -358,7 +358,7 @@ def test_model_serialization(tmp_path):
 
     model.save_pretrained(tmp_path / "dgmr")
     model_copy = DGMR.from_pretrained(tmp_path / "dgmr")
-    assert(model.hparams == model_copy.hparams)
+    assert model.hparams == model_copy.hparams
     assert_model_equal(model, model_copy)

From 85af19535b56ef0d2a74475f8152614d1c43f577 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 31 Jul 2024 14:27:53 +0000
Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/test_model.py | 25 +++++--------------------
 1 file changed, 5 insertions(+), 20 deletions(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index f1d7e2e..b7c956e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -353,7 +353,7 @@ def test_model_serialization(tmp_path):
         beta2=0.995,
         latent_channels=512,
         context_channels=256,
-        generation_steps=1
+        generation_steps=1,
     )
 
     model.save_pretrained(tmp_path / "dgmr")
@@ -363,11 +363,7 @@ def test_model_serialization(tmp_path):
 
 
 def test_discriminator_serialization(tmp_path):
-    discriminator = Discriminator(
-        input_channels=1,
-        num_spatial_frames=1,
-        conv_type="standard"
-    )
+    discriminator = Discriminator(input_channels=1, num_spatial_frames=1, conv_type="standard")
 
     discriminator.save_pretrained(tmp_path / "discriminator")
     discriminator_copy = Discriminator.from_pretrained(tmp_path / "discriminator")
@@ -376,10 +372,7 @@ def test_sampler_serialization(tmp_path):
     sampler = Sampler(
-        forecast_steps=1,
-        latent_channels=256,
-        context_channels=256,
-        output_channels=1
+        forecast_steps=1, latent_channels=256, context_channels=256, output_channels=1
     )
 
     sampler.save_pretrained(tmp_path / "sampler")
@@ -389,10 +382,7 @@ def test_context_conditioning_stack_serialization(tmp_path):
     ctz = ContextConditioningStack(
-        input_channels=2,
-        output_channels=256,
-        num_context_steps=1,
-        conv_type="standard"
+        input_channels=2, output_channels=256, num_context_steps=1, conv_type="standard"
     )
 
     ctz.save_pretrained(tmp_path / "context-conditioning-stack")
@@ -401,13 +391,8 @@ def test_latent_conditioning_stack_serialization(tmp_path):
-    lat = LatentConditioningStack(
-        shape=(4, 4, 4),
-        output_channels=256,
-        use_attention=True
-    )
+    lat = LatentConditioningStack(shape=(4, 4, 4), output_channels=256, use_attention=True)
 
     lat.save_pretrained(tmp_path / "latent-conditioning-stack")
     lat_copy = LatentConditioningStack.from_pretrained(tmp_path / "latent-conditioning-stack")
     assert_model_equal(lat, lat_copy)
-
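With the series applied, saving, loading, and pushing should reduce to the stock mixin API; a minimal sketch, assuming the package still exports DGMR at the top level:

    from dgmr import DGMR

    model = DGMR(forecast_steps=4)
    model.save_pretrained("local-dgmr")            # writes config.json plus weights
    restored = DGMR.from_pretrained("local-dgmr")  # rebuilds DGMR(**config)
    # model.push_to_hub("username/my-dgmr") is likewise inherited from the mixin.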