From bab5aee97fa575e07076fddbbdf86d8a7c202249 Mon Sep 17 00:00:00 2001 From: abigailt Date: Tue, 16 Apr 2024 14:13:38 +0300 Subject: [PATCH 1/3] Support sklearn models with multiple outputs (i.e., nb_classes is an array instead of an integer). Signed-off-by: abigailt --- art/estimators/classification/classifier.py | 2 +- art/utils.py | 7 ++++-- .../classification/test_scikitlearn.py | 23 +++++++++++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/art/estimators/classification/classifier.py b/art/estimators/classification/classifier.py index 191f4a784a..f879568859 100644 --- a/art/estimators/classification/classifier.py +++ b/art/estimators/classification/classifier.py @@ -114,7 +114,7 @@ def nb_classes(self, nb_classes: int): """ Set the number of output classes. """ - if nb_classes is None or nb_classes < 2: + if nb_classes is None or (isinstance(nb_classes, int) and nb_classes < 2): raise ValueError("nb_classes must be greater than or equal to 2.") self._nb_classes = nb_classes diff --git a/art/utils.py b/art/utils.py index 7c4ff28348..f813d6cdc3 100644 --- a/art/utils.py +++ b/art/utils.py @@ -792,15 +792,18 @@ def check_and_transform_label_format( labels: np.ndarray, nb_classes: Optional[int], return_one_hot: bool = True ) -> np.ndarray: """ - Check label format and transform to one-hot-encoded labels if necessary + Check label format and transform to one-hot-encoded labels if necessary. Only supports single-output classification. :param labels: An array of integer labels of shape `(nb_samples,)`, `(nb_samples, 1)` or `(nb_samples, nb_classes)`. - :param nb_classes: The number of classes. If None the number of classes is determined automatically. + :param nb_classes: The number of classes, as an integer. If None the number of classes is determined automatically. :param return_one_hot: True if returning one-hot encoded labels, False if returning index labels. :return: Labels with shape `(nb_samples, nb_classes)` (one-hot) or `(nb_samples,)` (index). """ labels_return = labels + if nb_classes is not None and not isinstance(nb_classes, int): + raise TypeError("nb_classes that is not an integer is not supported") + if len(labels.shape) == 2 and labels.shape[1] > 1: # multi-class, one-hot encoded if not return_one_hot: labels_return = np.argmax(labels, axis=1) diff --git a/tests/estimators/classification/test_scikitlearn.py b/tests/estimators/classification/test_scikitlearn.py index 14fccaf1c5..2f966f3d7b 100644 --- a/tests/estimators/classification/test_scikitlearn.py +++ b/tests/estimators/classification/test_scikitlearn.py @@ -47,6 +47,7 @@ ScikitlearnSVC, ) from art.estimators.classification.scikitlearn import SklearnClassifier +from art.utils import check_and_transform_label_format from tests.utils import TestBase, master_seed @@ -80,6 +81,28 @@ def test_save(self): def test_clone_for_refitting(self): _ = self.classifier.clone_for_refitting() + def test_multi_label(self): + x_train = self.x_train_iris + y_train = self.y_train_iris + x_test = self.x_test_iris + y_test = self.y_test_iris + + # make multi-label binary + y_train = np.column_stack((y_train, y_train, y_train)) + y_train[y_train > 1] = 1 + y_test = np.column_stack((y_test, y_test, y_test)) + y_test[y_test > 1] = 1 + + underlying_model = DecisionTreeClassifier() + underlying_model.fit(x_train, y_train) + model = ScikitlearnDecisionTreeClassifier(model=underlying_model) + + pred = model.predict(x_test) + assert (pred[0].shape[0] == x_test.shape[0]) + assert (isinstance(model.nb_classes, np.ndarray)) + with self.assertRaises(TypeError): + check_and_transform_label_format(y_train, nb_classes=model.nb_classes) + class TestScikitlearnExtraTreeClassifier(TestBase): @classmethod From 96b865892f1a8dbe27de0823a91904f358c26d29 Mon Sep 17 00:00:00 2001 From: abigailt Date: Tue, 16 Apr 2024 15:10:22 +0300 Subject: [PATCH 2/3] Fix type check to include numpy integer types Signed-off-by: abigailt --- art/estimators/classification/classifier.py | 2 +- art/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/art/estimators/classification/classifier.py b/art/estimators/classification/classifier.py index f879568859..38439071f1 100644 --- a/art/estimators/classification/classifier.py +++ b/art/estimators/classification/classifier.py @@ -114,7 +114,7 @@ def nb_classes(self, nb_classes: int): """ Set the number of output classes. """ - if nb_classes is None or (isinstance(nb_classes, int) and nb_classes < 2): + if nb_classes is None or (isinstance(nb_classes, (int, np.integer)) and nb_classes < 2): raise ValueError("nb_classes must be greater than or equal to 2.") self._nb_classes = nb_classes diff --git a/art/utils.py b/art/utils.py index f813d6cdc3..adb9ddc7b7 100644 --- a/art/utils.py +++ b/art/utils.py @@ -801,7 +801,7 @@ def check_and_transform_label_format( """ labels_return = labels - if nb_classes is not None and not isinstance(nb_classes, int): + if nb_classes is not None and not isinstance(nb_classes, (int, np.integer)): raise TypeError("nb_classes that is not an integer is not supported") if len(labels.shape) == 2 and labels.shape[1] > 1: # multi-class, one-hot encoded From 8be48a9f90d7d6a338ad9f2fe47edc4549d6b968 Mon Sep 17 00:00:00 2001 From: abigailt Date: Tue, 8 Oct 2024 09:45:25 +0300 Subject: [PATCH 3/3] Formatting Signed-off-by: abigailt --- tests/estimators/classification/test_scikitlearn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/estimators/classification/test_scikitlearn.py b/tests/estimators/classification/test_scikitlearn.py index 89290ca7fa..7fbcfe87b4 100644 --- a/tests/estimators/classification/test_scikitlearn.py +++ b/tests/estimators/classification/test_scikitlearn.py @@ -98,8 +98,8 @@ def test_multi_label(self): model = ScikitlearnDecisionTreeClassifier(model=underlying_model) pred = model.predict(x_test) - assert (pred[0].shape[0] == x_test.shape[0]) - assert (isinstance(model.nb_classes, np.ndarray)) + assert pred[0].shape[0] == x_test.shape[0] + assert isinstance(model.nb_classes, np.ndarray) with self.assertRaises(TypeError): check_and_transform_label_format(y_train, nb_classes=model.nb_classes)