From 0e27b3d93eb487b752092d323b48076aa27aaed5 Mon Sep 17 00:00:00 2001 From: Beat Buesser Date: Thu, 10 Aug 2023 14:30:25 +0200 Subject: [PATCH] Remove circular dependency in art.estimators.certification Signed-off-by: Beat Buesser --- art/estimators/certification/__init__.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/art/estimators/certification/__init__.py b/art/estimators/certification/__init__.py index 8a0a07f4ef..33a97ad7ad 100644 --- a/art/estimators/certification/__init__.py +++ b/art/estimators/certification/__init__.py @@ -2,12 +2,26 @@ This module contains certified classifiers. """ import importlib -from art.estimators.certification import randomized_smoothing -from art.estimators.certification import derandomized_smoothing +from art.estimators.certification.randomized_smoothing.randomized_smoothing import RandomizedSmoothingMixin +from art.estimators.certification.randomized_smoothing.numpy import NumpyRandomizedSmoothing +from art.estimators.certification.randomized_smoothing.tensorflow import TensorFlowV2RandomizedSmoothing +from art.estimators.certification.randomized_smoothing.pytorch import PyTorchRandomizedSmoothing +from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin +from art.estimators.certification.derandomized_smoothing.pytorch import PyTorchDeRandomizedSmoothing +from art.estimators.certification.derandomized_smoothing.tensorflow import TensorFlowV2DeRandomizedSmoothing if importlib.util.find_spec("torch") is not None: - from art.estimators.certification import deep_z - from art.estimators.certification import interval + from art.estimators.certification.deep_z.deep_z import ZonoDenseLayer + from art.estimators.certification.deep_z.deep_z import ZonoBounds + from art.estimators.certification.deep_z.deep_z import ZonoConv + from art.estimators.certification.deep_z.deep_z import ZonoReLU + from art.estimators.certification.deep_z.pytorch import PytorchDeepZ + from art.estimators.certification.interval.interval import PyTorchIntervalDense + from art.estimators.certification.interval.interval import PyTorchIntervalConv2D + from art.estimators.certification.interval.interval import PyTorchIntervalReLU + from art.estimators.certification.interval.interval import PyTorchIntervalFlatten + from art.estimators.certification.interval.interval import PyTorchIntervalBounds + from art.estimators.certification.interval.pytorch import PyTorchIBPClassifier else: import warnings