Skip to content

Commit

Permalink
Fix mypy warning on typing
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Jan 16, 2025
1 parent 0aa91b5 commit 911884b
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,9 @@ def generate( # type: ignore
logger.info("Setting labels to estimator classification predictions.")
y = to_categorical(np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes)

y = check_and_transform_label_format(labels=y, nb_classes=self.estimator.nb_classes)
y_array: np.ndarray = y

y = check_and_transform_label_format(labels=y_array, nb_classes=self.estimator.nb_classes)

# check if logits or probabilities
y_pred = self.estimator.predict(x=x[[0]])
Expand Down

0 comments on commit 911884b

Please sign in to comment.