Skip to content

Commit

Permalink
Update workflows for TensorFlow/Keras 2.14
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Nov 16, 2023
1 parent f45b268 commit 709adc3
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions tests/estimators/classification/test_scikitlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,38 +300,32 @@ def test_class_gradient_none_2(self):

def test_class_gradient_int_1(self):
grad_predicted = self.classifier.class_gradient(self.x_test_iris[0:1], label=1)
grad_expected = [[[-0.56322294, -0.70493763, -0.98874801, -0.67053026]]]
grad_expected = [[[-0.56317311, -0.70493763, -0.98908609, -0.67106276]]]

for i_shape in range(4):
print(grad_predicted[0, 0, i_shape])
self.assertAlmostEqual(grad_predicted[0, 0, i_shape], grad_expected[0][0][i_shape], 3)

def test_class_gradient_int_2(self):
grad_predicted = self.classifier.class_gradient(self.x_test_iris[0:2], label=1)
grad_expected = [
[[-0.56322294, -0.70427608, -0.98874801, -0.67053026]],
[[-0.50528532, -0.71700042, -0.82467848, -0.59614766]],
[[-0.56317306, -0.70493776, -0.98908573, -0.67106259]],
[[-0.50522697, -0.71762568, -0.82497531, -0.5966416]],
]
print("grad_predicted")
print(grad_predicted)
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=4)

def test_class_gradient_list_1(self):
grad_predicted = self.classifier.class_gradient(self.x_test_iris[0:1], label=[1])
grad_expected = [[[-0.56322294, -0.70427608, -0.98874801, -0.67053026]]]
grad_expected = [[[-0.56317311, -0.70493763, -0.98874801, -0.67053026]]]

for i_shape in range(4):
print(grad_predicted[0, 0, i_shape])
self.assertAlmostEqual(grad_predicted[0, 0, i_shape], grad_expected[0][0][i_shape], 3)

def test_class_gradient_list_2(self):
grad_predicted = self.classifier.class_gradient(self.x_test_iris[0:2], label=[1, 2])
grad_expected = [
[[-0.56322294, -0.70427608, -0.98874801, -0.67053026]],
[[0.70875132, 0.25104877, 1.70929277, 0.88410652]],
[[-0.56317306, -0.70493776, -0.98908573, -0.67106259]],
[[0.70866591, 0.25158876, 1.70947325, 0.88450021]],
]
print("grad_predicted")
print(grad_predicted)
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=4)

def test_class_gradient_label_wrong_type(self):
Expand All @@ -345,9 +339,7 @@ def test_class_gradient_label_wrong_type(self):

def test_loss_gradient(self):
grad_predicted = self.classifier.loss_gradient(self.x_test_iris[0:1], self.y_test_iris[0:1])
grad_expected = np.asarray([[-0.21693791, -0.08792436, -0.51507443, -0.26990796]])
print("grad_predicted")
print(grad_predicted)
grad_expected = np.asarray([[-0.21690657, -0.08809226, -0.51512082, -0.27002635]])
np.testing.assert_array_almost_equal(grad_predicted, grad_expected, decimal=4)

def test_save(self):
Expand Down

0 comments on commit 709adc3

Please sign in to comment.