Skip to content

Commit

Permalink
mypy fixes
Browse files Browse the repository at this point in the history
Signed-off-by: GiulioZizzo <[email protected]>
  • Loading branch information
GiulioZizzo committed Dec 15, 2023
1 parent 5fb24d5 commit 2a3290a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,10 @@ def fit( # pylint: disable=W0221
training_mode: bool = True,
drop_last: bool = False,
scheduler: Optional[Any] = None,
verbose: Optional[Union[bool, int]] = None,
update_batchnorm: bool = True,
batchnorm_update_epochs: int = 1,
transform: Optional["torchvision.transforms.transforms.Compose"] = None,
verbose: Optional[Union[bool, int]] = None,
**kwargs,
) -> None:
"""
Expand All @@ -457,13 +457,13 @@ def fit( # pylint: disable=W0221
the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
the last batch will be smaller. (default: ``False``)
:param scheduler: Learning rate scheduler to run at the start of every epoch.
:param verbose: if to display training progress bars
:param update_batchnorm: ViT specific argument.
If to run the training data through the model to update any batch norm statistics prior
to training. Useful on small datasets when using pre-trained ViTs.
:param batchnorm_update_epochs: ViT specific argument. How many times to forward pass over the training data
to pre-adjust the batchnorm statistics.
:param transform: ViT specific argument. Torchvision compose of relevant augmentation transformations to apply.
:param verbose: if to display training progress bars
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
and providing it takes no effect.
"""
Expand Down
3 changes: 2 additions & 1 deletion art/estimators/classification/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def fit_generator(
)
):
for _ in tqdm(range(nb_epochs), disable=not display_pb, desc="Epochs"):
for _ in tqdm(range(int(generator.size / generator.batch_size)), disable=not display_pb, desc="Batches"): # type: ignore
num_batches = int(generator.size / generator.batch_size)
for _ in tqdm(range(num_batches), disable=not display_pb, desc="Batches"):  # type: ignore
i_batch, o_batch = generator.get_batch()

if self._reduce_labels:
Expand Down

0 comments on commit 2a3290a

Please sign in to comment.