Skip to content

Commit

Permalink
[CI, testing] add SVM C and nu parameter check prevents seg fault skl…
Browse files Browse the repository at this point in the history
…earn < 1.2 (#1930) (#1947)

* Update svc.py

* Update svr.py

* Update nusvc.py

* Update nusvc.py

* Update nusvr.py

* Update svc.py

* match sklearn error messages

* add comments

* formatting

* Update sklearnex/svm/svc.py

Co-authored-by: ethanglaser <[email protected]>

* Update sklearnex/svm/nusvr.py

Co-authored-by: ethanglaser <[email protected]>

---------

Co-authored-by: ethanglaser <[email protected]>
(cherry picked from commit 04dbc5e)

Co-authored-by: Ian Faust <[email protected]>
  • Loading branch information
mergify[bot] and icfaust authored Jul 18, 2024
1 parent 3a89a48 commit 52f27a5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 0 deletions.
11 changes: 11 additions & 0 deletions sklearnex/svm/nusvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ def __init__(
def fit(self, X, y, sample_weight=None):
if sklearn_check_version("1.2"):
self._validate_params()
elif self.nu <= 0 or self.nu > 1:
# else if added to correct issues with
# sklearn tests:
# svm/tests/test_sparse.py::test_error
# svm/tests/test_svm.py::test_bad_input
# for sklearn versions < 1.2 (i.e. without
# validate_params parameter checking)
# Without this, a segmentation fault with
# Windows fatal exception: access violation
# occurs
raise ValueError("nu <= 0 or nu > 1")
if sklearn_check_version("1.0"):
self._check_feature_names(X, reset=True)
dispatch(
Expand Down
11 changes: 11 additions & 0 deletions sklearnex/svm/nusvr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ def __init__(
def fit(self, X, y, sample_weight=None):
if sklearn_check_version("1.2"):
self._validate_params()
elif self.nu <= 0 or self.nu > 1:
# else if added to correct issues with
# sklearn tests:
# svm/tests/test_sparse.py::test_error
# svm/tests/test_svm.py::test_bad_input
# for sklearn versions < 1.2 (i.e. without
# validate_params parameter checking)
# Without this, a segmentation fault with
# Windows fatal exception: access violation
# occurs
raise ValueError("nu <= 0 or nu > 1")
if sklearn_check_version("1.0"):
self._check_feature_names(X, reset=True)
dispatch(
Expand Down
11 changes: 11 additions & 0 deletions sklearnex/svm/svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ def __init__(
def fit(self, X, y, sample_weight=None):
if sklearn_check_version("1.2"):
self._validate_params()
elif self.C <= 0:
# else if added to correct issues with
# sklearn tests:
# svm/tests/test_sparse.py::test_error
# svm/tests/test_svm.py::test_bad_input
# for sklearn versions < 1.2 (i.e. without
# validate_params parameter checking)
# Without this, a segmentation fault with
# Windows fatal exception: access violation
# occurs
raise ValueError("C <= 0")
if sklearn_check_version("1.0"):
self._check_feature_names(X, reset=True)
dispatch(
Expand Down
11 changes: 11 additions & 0 deletions sklearnex/svm/svr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ def __init__(
def fit(self, X, y, sample_weight=None):
if sklearn_check_version("1.2"):
self._validate_params()
elif self.C <= 0:
# else if added to correct issues with
# sklearn tests:
# svm/tests/test_sparse.py::test_error
# svm/tests/test_svm.py::test_bad_input
# for sklearn versions < 1.2 (i.e. without
# validate_params parameter checking)
# Without this, a segmentation fault with
# Windows fatal exception: access violation
# occurs
raise ValueError("C <= 0")
if sklearn_check_version("1.0"):
self._check_feature_names(X, reset=True)
dispatch(
Expand Down

0 comments on commit 52f27a5

Please sign in to comment.