Commit
Merge branch 'dev_1.14.1' into development_issue_2116
beat-buesser authored Apr 20, 2023
2 parents 731c1d3 + 63f3501 commit cb01479
Showing 1 changed file with 13 additions and 3 deletions.
art/defences/trainer/ibp_certified_trainer_pytorch.py (16 changes: 13 additions & 3 deletions)
@@ -109,8 +109,8 @@ def __init__(
         classifier: "IBP_CERTIFIER_TYPE",
         nb_epochs: Optional[int] = 20,
         bound: float = 0.1,
-        loss_weighting: float = 0.1,
         batch_size: int = 32,
+        loss_weighting: Optional[float] = None,
         use_certification_schedule: bool = True,
         certification_schedule: Optional[Any] = None,
         use_loss_weighting_schedule: bool = True,
@@ -133,9 +133,9 @@ def __init__(
         * *max_iter*: The maximum number of iterations.
         * *batch_size*: Size of the batch on which adversarial samples are generated.
         * *num_random_init*: Number of random initialisations within the epsilon ball.
-        :param loss_weighting: Weighting factor for the certified loss.
         :param bound: The perturbation range for the interval. If the default certification schedule is used
                       will be the upper limit.
+        :param loss_weighting: Weighting factor for the certified loss.
         :param nb_epochs: Number of training epochs.
         :param use_certification_schedule: If to use a training schedule for the certification radius.
         :param certification_schedule: Schedule for gradually increasing the certification radius. Empirical studies
@@ -152,6 +152,14 @@ def __init__(
             "art.estimators.certification.interval.pytorch.PyTorchIBPClassifier"
         )

+        if not use_loss_weighting_schedule and loss_weighting is None:
+            raise ValueError(
+                "If a loss weighting schedule is not used then a value for loss_weighting should be supplied."
+            )
+
+        if use_loss_weighting_schedule and loss_weighting is not None:
+            raise ValueError("Using a loss weighting schedule is incompatible with a fixed loss_weighting.")
+
         super().__init__(classifier=classifier)
         self.classifier: "IBP_CERTIFIER_TYPE"
         self.pgd_params: "PGDParamDict"
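
The two checks above make the loss weighting schedule and a fixed loss_weighting mutually exclusive, replacing the old silent default. A minimal standalone sketch of the new validation (the free function check_loss_weighting is hypothetical; in ART the checks sit inline in AdversarialTrainerCertifiedIBPPyTorch.__init__):

from typing import Optional

def check_loss_weighting(use_loss_weighting_schedule: bool, loss_weighting: Optional[float]) -> None:
    # Exactly one source for the certified-loss weight must be active:
    # either the schedule ramps it during training, or a fixed value is given.
    if not use_loss_weighting_schedule and loss_weighting is None:
        raise ValueError(
            "If a loss weighting schedule is not used then a value for loss_weighting should be supplied."
        )
    if use_loss_weighting_schedule and loss_weighting is not None:
        raise ValueError("Using a loss weighting schedule is incompatible with a fixed loss_weighting.")

check_loss_weighting(True, None)   # valid: the schedule supplies the weight
check_loss_weighting(False, 0.1)   # valid: fixed weight, matching the old default of 0.1
# check_loss_weighting(False, None) and check_loss_weighting(True, 0.1) both raise ValueError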
@@ -300,8 +308,10 @@ def fit(  # pylint: disable=W0221
             self.loss_weighting_schedule = self.initialise_default_scheduler(
                 initial_val=0.0, final_val=0.5, epochs=epochs
             )
+        elif self.loss_weighting is not None:
+            loss_weighting_k = self.loss_weighting
         else:
-            loss_weighting_k = 0.1
+            raise ValueError("Unable to determine loss weighting.")

         for _ in tqdm(range(epochs)):
             if self.use_certification_schedule and self.certification_schedule is not None:
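With this change, fit() resolves the certified-loss weight for each epoch from exactly one of the two sources; the silent fallback to 0.1 becomes a hard error. A sketch of the per-epoch resolution, assuming the default schedule from initialise_default_scheduler is a linear ramp from initial_val to final_val (the ramp shape and the helper name loss_weight_for_epoch are assumptions, not taken from the diff):

from typing import Optional

def loss_weight_for_epoch(
    epoch: int,
    epochs: int,
    use_loss_weighting_schedule: bool,
    loss_weighting: Optional[float],
) -> float:
    # Mirrors the new control flow in fit(): schedule first, explicit value next,
    # and an error instead of the old hidden 0.1 default.
    if use_loss_weighting_schedule:
        # The diff shows the default schedule running from 0.0 to 0.5 over training;
        # a linear ramp is assumed here purely for illustration.
        initial_val, final_val = 0.0, 0.5
        return initial_val + (final_val - initial_val) * epoch / max(epochs - 1, 1)
    if loss_weighting is not None:
        return loss_weighting
    raise ValueError("Unable to determine loss weighting.")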
