Skip to content

Commit

Permalink
Revert "Remove mean and logprob (inconsistent values due to numerical…
Browse files Browse the repository at this point in the history
… approximations)"

This reverts commit f75333b.
  • Loading branch information
louisPoulain committed Sep 6, 2024
1 parent f75333b commit cb25eb4
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions mlpp_lib/probabilistic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,36 @@ def _sample_n(self, n, seed=None):

return chosen_samples

def _mean(self):
"""
Original: X ~ N(mu, sigma)
Censored: Y = X if 0 <= X <= 1 else 0 if X < 0 else 1
Law of total expectations: E[Y] = E[Y | X > 1] * P(X > 1) + E[Y | X < 0] * P(X < 0) + E[Y | 0 <= X <= 1] * P(0 <= X <= 1)
= P(X > 1) * 1 + P(X < 0) * 0 + E[X | 0 <= X <= 1] * P(0 <= X <= 1)
= 1 - Phi((1 - mu) / sigma) + E[Z ~ TruncNormal(mu, sigma, 0, 1)] * (Phi((1 - mu) / sigma) - Phi(-mu / sigma))
Ref for TruncatedNormal mean: https://en.wikipedia.org/wiki/Truncated_normal_distribution
"""
original_mean = self.normal.mean()
low_bound_standard = (0 - original_mean) / self.normal.stddev()
high_bound_standard = (1 - original_mean) / self.normal.stddev()

self.low_bound_cdf = self.normal.cdf(low_bound_standard)
self.high_bound_cdf = self.normal.cdf(high_bound_standard)

self.low_bound_pdf = self.normal.prob(low_bound_standard)
self.high_bound_pdf = self.normal.prob(high_bound_standard)

return original_mean + self.normal.stddev() * (
self.low_bound_pdf - self.high_bound_pdf
) / (self.high_bound_cdf - self.low_bound_cdf + 1e-3)

def _log_prob(self, value):
original_log_prob = self.normal.log_prob(value)

return original_log_prob - tf.math.log(
self.high_bound_cdf - self.low_bound_cdf + 1e-3
)

return independent_lib.Independent(
CustomCensored(normal_dist),
Expand Down

0 comments on commit cb25eb4

Please sign in to comment.