From 7efcd37b9ae992d38e2378c00e68d9cc643268c5 Mon Sep 17 00:00:00 2001
From: GiulioZizzo
Date: Tue, 18 Apr 2023 19:05:53 +0100
Subject: [PATCH 1/3] Fix value of loss weighting when schedule is not used

Signed-off-by: GiulioZizzo
---
 art/defences/trainer/ibp_certified_trainer_pytorch.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/art/defences/trainer/ibp_certified_trainer_pytorch.py b/art/defences/trainer/ibp_certified_trainer_pytorch.py
index e807b4b377..c42cfd89ef 100644
--- a/art/defences/trainer/ibp_certified_trainer_pytorch.py
+++ b/art/defences/trainer/ibp_certified_trainer_pytorch.py
@@ -109,8 +109,8 @@ def __init__(
         classifier: "IBP_CERTIFIER_TYPE",
         nb_epochs: Optional[int] = 20,
         bound: float = 0.1,
-        loss_weighting: float = 0.1,
         batch_size: int = 32,
+        loss_weighting: Optional[float] = None,
         use_certification_schedule: bool = True,
         certification_schedule: Optional[Any] = None,
         use_loss_weighting_schedule: bool = True,
@@ -133,9 +133,9 @@ def __init__(
             * *max_iter*: The maximum number of iterations.
             * *batch_size*: Size of the batch on which adversarial samples are generated.
             * *num_random_init*: Number of random initialisations within the epsilon ball.
+        :param loss_weighting: Weighting factor for the certified loss. Required if the schedule is disabled.
         :param bound: The perturbation range for the interval. If the default certification schedule is used will be
                       the upper limit.
-        :param loss_weighting: Weighting factor for the certified loss.
         :param nb_epochs: Number of training epochs.
         :param use_certification_schedule: If to use a training schedule for the certification radius.
         :param certification_schedule: Schedule for gradually increasing the certification radius. Empirical studies
@@ -152,6 +152,11 @@ def __init__(
                 "art.estimators.certification.interval.pytorch.PyTorchIBPClassifier"
             )
 
+        if not use_loss_weighting_schedule and loss_weighting is None:
+            raise ValueError(
+                "If a loss weighting schedule is not used, a value for loss_weighting must be supplied"
+            )
+
         super().__init__(classifier=classifier)
         self.classifier: "IBP_CERTIFIER_TYPE"
         self.pgd_params: "PGDParamDict"
@@ -301,7 +306,7 @@ def fit(  # pylint: disable=W0221
                     initial_val=0.0, final_val=0.5, epochs=epochs
                 )
             else:
-                loss_weighting_k = 0.1
+                loss_weighting_k = self.loss_weighting
 
         for _ in tqdm(range(epochs)):
             if self.use_certification_schedule and self.certification_schedule is not None:

From 75a9d984db427604356610b39f06649fcdd32570 Mon Sep 17 00:00:00 2001
From: GiulioZizzo
Date: Tue, 18 Apr 2023 18:29:44 +0000
Subject: [PATCH 2/3] Add additional input checking

Signed-off-by: GiulioZizzo
---
 art/defences/trainer/ibp_certified_trainer_pytorch.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/art/defences/trainer/ibp_certified_trainer_pytorch.py b/art/defences/trainer/ibp_certified_trainer_pytorch.py
index c42cfd89ef..ac5364ddd3 100644
--- a/art/defences/trainer/ibp_certified_trainer_pytorch.py
+++ b/art/defences/trainer/ibp_certified_trainer_pytorch.py
@@ -154,7 +154,12 @@ def __init__(
 
         if not use_loss_weighting_schedule and loss_weighting is None:
             raise ValueError(
-                "If a loss weighting schedule is not used, a value for loss_weighting must be supplied"
+                "If a loss weighting schedule is not used, a value for loss_weighting must be supplied."
+            )
+
+        if use_loss_weighting_schedule and loss_weighting is not None:
+            raise ValueError(
+                "Using a loss weighting schedule is incompatible with a fixed loss_weighting."
             )
 
         super().__init__(classifier=classifier)

From e87ed916e50d76d434e382270f45385e3fbdc892 Mon Sep 17 00:00:00 2001
From: GiulioZizzo
Date: Wed, 19 Apr 2023 08:46:46 +0000
Subject: [PATCH 3/3] Fix formatting and add additional input checking

Signed-off-by: GiulioZizzo
---
 art/defences/trainer/ibp_certified_trainer_pytorch.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/art/defences/trainer/ibp_certified_trainer_pytorch.py b/art/defences/trainer/ibp_certified_trainer_pytorch.py
index ac5364ddd3..2d5e6abb0e 100644
--- a/art/defences/trainer/ibp_certified_trainer_pytorch.py
+++ b/art/defences/trainer/ibp_certified_trainer_pytorch.py
@@ -158,9 +158,7 @@ def __init__(
             )
 
         if use_loss_weighting_schedule and loss_weighting is not None:
-            raise ValueError(
-                "Using a loss weighting schedule is incompatible with a fixed loss_weighting."
-            )
+            raise ValueError("Using a loss weighting schedule is incompatible with a fixed loss_weighting.")
 
         super().__init__(classifier=classifier)
         self.classifier: "IBP_CERTIFIER_TYPE"
@@ -310,8 +308,10 @@ def fit(  # pylint: disable=W0221
                 self.loss_weighting_schedule = self.initialise_default_scheduler(
                     initial_val=0.0, final_val=0.5, epochs=epochs
                 )
-        else:
+        elif self.loss_weighting is not None:
             loss_weighting_k = self.loss_weighting
+        else:
+            raise ValueError("Unable to determine loss weighting.")
 
         for _ in tqdm(range(epochs)):
             if self.use_certification_schedule and self.certification_schedule is not None:
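
Note: a brief usage sketch of the behaviour after this series. Requiring an
explicit choice avoids the earlier silent fallback to a hard-coded 0.1. The
class name AdversarialTrainerCertifiedIBPPyTorch and the import path are
assumed from this module's location, and `classifier` is a placeholder for a
PyTorchIBPClassifier built elsewhere; treat this as illustrative, not
definitive.

    # Illustrative sketch only; names assumed as described above.
    from art.defences.trainer import AdversarialTrainerCertifiedIBPPyTorch

    classifier = ...  # a PyTorchIBPClassifier instance built elsewhere

    # Default behaviour: the loss weighting schedule drives the weighting,
    # so loss_weighting keeps its new default of None.
    trainer = AdversarialTrainerCertifiedIBPPyTorch(
        classifier=classifier,
        use_loss_weighting_schedule=True,
    )

    # Fixed weighting: disable the schedule and supply a value explicitly.
    # Before this series, this path silently ignored loss_weighting and
    # used 0.1.
    trainer = AdversarialTrainerCertifiedIBPPyTorch(
        classifier=classifier,
        use_loss_weighting_schedule=False,
        loss_weighting=0.1,
    )

    # Both of the following now raise ValueError:
    # - use_loss_weighting_schedule=False with loss_weighting left as None
    # - use_loss_weighting_schedule=True with a fixed loss_weighting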