Skip to content

Commit

Permalink
Update workflows for TensorFlow/Keras 2.14
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Nov 16, 2023
1 parent 709adc3 commit abf4727
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/estimators/classification/test_scikitlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_type(self):
def test_predict(self):
y_predicted = self.classifier.predict(self.x_test_iris[0:1])
y_expected = np.asarray([[0.07997696, 0.36272544, 0.5572976]])
np.testing.assert_array_almost_equal(y_predicted, y_expected, decimal=4)
np.testing.assert_array_almost_equal(y_predicted, y_expected, decimal=3)

def test_class_gradient_none_1(self):
grad_predicted = self.classifier.class_gradient(self.x_test_iris[0:1], label=None)
Expand All @@ -280,7 +280,7 @@ def test_class_gradient_none_1(self):
[0.6508137, 0.26377308, 1.54522324, 0.80972391],
]
]
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=4)
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=3)

def test_class_gradient_none_2(self):
grad_predicted = self.classifier.class_gradient(self.x_test_iris[0:2], label=None)
Expand All @@ -296,7 +296,7 @@ def test_class_gradient_none_2(self):
[0.70875132, 0.25104877, 1.70929277, 0.88410652],
],
]
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=4)
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=3)

def test_class_gradient_int_1(self):
grad_predicted = self.classifier.class_gradient(self.x_test_iris[0:1], label=1)
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_class_gradient_list_2(self):
[[-0.56317306, -0.70493776, -0.98908573, -0.67106259]],
[[0.70866591, 0.25158876, 1.70947325, 0.88450021]],
]
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=4)
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=3)

def test_class_gradient_label_wrong_type(self):

Expand Down

0 comments on commit abf4727

Please sign in to comment.