Merge pull request #2505 from abigailgold/sklearn_nbclasses
Support sklearn models with multiple outputs
beat-buesser authored Dec 17, 2024
2 parents 0b4bb68 + bd58b1a commit cf11263
Showing 3 changed files with 29 additions and 3 deletions.
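
For context, a minimal sketch (not part of this commit) of the kind of multi-output scikit-learn model this change targets: fitting a classifier on a 2-D label array makes the per-output class counts come back as a sequence rather than a single int, which is what the changes below accommodate.

# Sketch: a multi-output scikit-learn classifier (assumes scikit-learn and numpy are installed)
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.arange(80, dtype=float).reshape(20, 4)
# Two binary outputs per sample.
y = np.column_stack((np.arange(20) % 2, (np.arange(20) // 2) % 2))

clf = DecisionTreeClassifier().fit(X, y)
print(clf.n_outputs_)  # 2
print(clf.n_classes_)  # per-output class counts, e.g. [2 2], not a single int
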
2 changes: 1 addition & 1 deletion art/estimators/classification/classifier.py
@@ -116,7 +116,7 @@ def nb_classes(self, nb_classes: int):
         """
         Set the number of output classes.
         """
-        if nb_classes is None or nb_classes < 2:
+        if nb_classes is None or (isinstance(nb_classes, (int, np.integer)) and nb_classes < 2):
             raise ValueError("nb_classes must be greater than or equal to 2.")
 
         self._nb_classes = nb_classes
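
Why the guard: a sketch (not from this commit) of what the old comparison does when nb_classes is array-valued, as it is for multi-output sklearn models: the element-wise result cannot be used as a truth value, so the setter would have raised on such models before this change.

# Sketch: array-valued nb_classes vs. the old "nb_classes < 2" check
import numpy as np

nb_classes = np.array([2, 2])  # per-output class counts of a multi-output model
print(nb_classes < 2)          # [False False], element-wise rather than a single bool
try:
    if nb_classes is None or nb_classes < 2:  # old check
        pass
except ValueError as err:
    print(err)  # truth value of an array with more than one element is ambiguous
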
7 changes: 5 additions & 2 deletions art/utils.py
@@ -799,15 +799,18 @@ def check_and_transform_label_format(
     labels: np.ndarray, nb_classes: int | None, return_one_hot: bool = True
 ) -> np.ndarray:
     """
-    Check label format and transform to one-hot-encoded labels if necessary
+    Check label format and transform to one-hot-encoded labels if necessary. Only supports single-output classification.
 
     :param labels: An array of integer labels of shape `(nb_samples,)`, `(nb_samples, 1)` or `(nb_samples, nb_classes)`.
-    :param nb_classes: The number of classes. If None the number of classes is determined automatically.
+    :param nb_classes: The number of classes, as an integer. If None the number of classes is determined automatically.
     :param return_one_hot: True if returning one-hot encoded labels, False if returning index labels.
     :return: Labels with shape `(nb_samples, nb_classes)` (one-hot) or `(nb_samples,)` (index).
     """
     labels_return = labels
 
+    if nb_classes is not None and not isinstance(nb_classes, (int, np.integer)):
+        raise TypeError("nb_classes that is not an integer is not supported")
+
     if len(labels.shape) == 2 and labels.shape[1] > 1: # multi-class, one-hot encoded
         if not return_one_hot:
             labels_return = np.argmax(labels, axis=1)
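
A hedged usage sketch of the new behaviour (assumes a version of ART that includes this change): an integer nb_classes is handled as before, while an array-valued nb_classes, as reported by a multi-output model, is now rejected explicitly instead of failing later.

import numpy as np
from art.utils import check_and_transform_label_format

y = np.array([0, 2, 1, 2])
one_hot = check_and_transform_label_format(y, nb_classes=3)  # int: converted to one-hot
print(one_hot.shape)                                         # expected (4, 3)

try:
    check_and_transform_label_format(y, nb_classes=np.array([2, 2]))
except TypeError as err:
    print(err)  # nb_classes that is not an integer is not supported
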
23 changes: 23 additions & 0 deletions tests/estimators/classification/test_scikitlearn.py
@@ -47,6 +47,7 @@
     ScikitlearnSVC,
 )
 from art.estimators.classification.scikitlearn import SklearnClassifier
+from art.utils import check_and_transform_label_format
 
 from tests.utils import TestBase, master_seed
 
@@ -80,6 +81,28 @@ def test_save(self):
     def test_clone_for_refitting(self):
         _ = self.classifier.clone_for_refitting()
 
+    def test_multi_label(self):
+        x_train = self.x_train_iris
+        y_train = self.y_train_iris
+        x_test = self.x_test_iris
+        y_test = self.y_test_iris
+
+        # make multi-label binary
+        y_train = np.column_stack((y_train, y_train, y_train))
+        y_train[y_train > 1] = 1
+        y_test = np.column_stack((y_test, y_test, y_test))
+        y_test[y_test > 1] = 1
+
+        underlying_model = DecisionTreeClassifier()
+        underlying_model.fit(x_train, y_train)
+        model = ScikitlearnDecisionTreeClassifier(model=underlying_model)
+
+        pred = model.predict(x_test)
+        assert pred[0].shape[0] == x_test.shape[0]
+        assert isinstance(model.nb_classes, np.ndarray)
+        with self.assertRaises(TypeError):
+            check_and_transform_label_format(y_train, nb_classes=model.nb_classes)
+
 
 class TestScikitlearnExtraTreeClassifier(TestBase):
     @classmethod
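
For the shape assertion in test_multi_label, a sketch (not from this commit) at the plain scikit-learn level: a multi-output classifier's predict_proba returns one array per output, so the first element has n_samples rows; the ART wrapper is assumed here to forward to the underlying estimator.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.arange(80, dtype=float).reshape(20, 4)
y = np.column_stack((np.arange(20) % 2, (np.arange(20) // 2) % 2))

clf = DecisionTreeClassifier().fit(X, y)
proba = clf.predict_proba(X)
print(len(proba))         # 2, one entry per output
print(proba[0].shape[0])  # 20, n_samples, matching the test's pred[0].shape[0] check
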
