Skip to content

Commit

Permalink
Update randomized_smoothing.py
Browse files Browse the repository at this point in the history
Signed-off-by: Prem Kiran Laknaboina <[email protected]>
  • Loading branch information
Ashuradhipathi authored and beat-buesser committed Dec 16, 2024
1 parent 6be8850 commit 63b89ef
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def predict(self, x: np.ndarray, batch_size: int = 128, verbose: bool = False, *
# get class counts
counts_pred = self._prediction_counts(x_i, batch_size=batch_size)
top = counts_pred.argsort()[::-1]
# Conersion to int
# conversion to int
count1 = int(np.max(counts_pred))
count2 = int(counts_pred[top[1]])

# predict or abstain
smooth_prediction = np.zeros(counts_pred.shape)
#Get p value from BinomTestResult object
# get p value from BinomTestResult object
p_value = binomtest(count1, count1 + count2, p=0.5).pvalue
if (not is_abstain) or (p_value <= self.alpha):
smooth_prediction[np.argmax(counts_pred)] = 1
Expand Down

0 comments on commit 63b89ef

Please sign in to comment.