From 2a3290a97369258c80ce7dfe49159e0b878ac821 Mon Sep 17 00:00:00 2001 From: GiulioZizzo Date: Fri, 15 Dec 2023 09:12:24 +0000 Subject: [PATCH] mypy fixes Signed-off-by: GiulioZizzo --- .../certification/derandomized_smoothing/pytorch.py | 4 ++-- art/estimators/classification/tensorflow.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/art/estimators/certification/derandomized_smoothing/pytorch.py b/art/estimators/certification/derandomized_smoothing/pytorch.py index 059078d27f..792e45974b 100644 --- a/art/estimators/certification/derandomized_smoothing/pytorch.py +++ b/art/estimators/certification/derandomized_smoothing/pytorch.py @@ -438,10 +438,10 @@ def fit( # pylint: disable=W0221 training_mode: bool = True, drop_last: bool = False, scheduler: Optional[Any] = None, + verbose: Optional[Union[bool, int]] = None, update_batchnorm: bool = True, batchnorm_update_epochs: int = 1, transform: Optional["torchvision.transforms.transforms.Compose"] = None, - verbose: Optional[Union[bool, int]] = None, **kwargs, ) -> None: """ @@ -457,13 +457,13 @@ def fit( # pylint: disable=W0221 the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: ``False``) :param scheduler: Learning rate scheduler to run at the start of every epoch. + :param verbose: if to display training progress bars :param update_batchnorm: ViT specific argument. If to run the training data through the model to update any batch norm statistics prior to training. Useful on small datasets when using pre-trained ViTs. :param batchnorm_update_epochs: ViT specific argument. How many times to forward pass over the training data to pre-adjust the batchnorm statistics. :param transform: ViT specific argument. Torchvision compose of relevant augmentation transformations to apply. - :param verbose: if to display training progress bars :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch and providing it takes no effect. """ diff --git a/art/estimators/classification/tensorflow.py b/art/estimators/classification/tensorflow.py index 6bede5d971..b26b0e3ec9 100644 --- a/art/estimators/classification/tensorflow.py +++ b/art/estimators/classification/tensorflow.py @@ -376,7 +376,8 @@ def fit_generator( ) ): for _ in tqdm(range(nb_epochs), disable=not display_pb, desc="Epochs"): - for _ in tqdm(range(int(generator.size / generator.batch_size)), disable=not display_pb, desc="Batches"): # type: ignore + num_bathces = int(generator.size / generator.batch_size) + for _ in tqdm(range(num_bathces), disable=not display_pb, desc="Batches"): # type: ignore i_batch, o_batch = generator.get_batch() if self._reduce_labels: