diff --git a/mlpp_lib/probabilistic_layers.py b/mlpp_lib/probabilistic_layers.py index 6a005d1..3928b47 100644 --- a/mlpp_lib/probabilistic_layers.py +++ b/mlpp_lib/probabilistic_layers.py @@ -361,6 +361,36 @@ def _sample_n(self, n, seed=None): return chosen_samples + def _mean(self): + """ + Original: X ~ N(mu, sigma) + Censored: Y = X if 0 <= X <= 1 else 0 if X < 0 else 1 + + Law of total expectations: E[Y] = E[Y | X > 1] * P(X > 1) + E[Y | X < 0] * P(X < 0) + E[Y | 0 <= X <= 1] * P(0 <= X <= 1) + = P(X > 1) * 1 + P(X < 0) * 0 + E[X | 0 <= X <= 1] * P(0 <= X <= 1) + = 1 - Phi((1 - mu) / sigma) + E[Z ~ TruncNormal(mu, sigma, 0, 1)] * (Phi((1 - mu) / sigma) - Phi(-mu / sigma)) + Ref for TruncatedNormal mean: https://en.wikipedia.org/wiki/Truncated_normal_distribution + """ + original_mean = self.normal.mean() + low_bound_standard = (0 - original_mean) / self.normal.stddev() + high_bound_standard = (1 - original_mean) / self.normal.stddev() + + self.low_bound_cdf = self.normal.cdf(low_bound_standard) + self.high_bound_cdf = self.normal.cdf(high_bound_standard) + + self.low_bound_pdf = self.normal.prob(low_bound_standard) + self.high_bound_pdf = self.normal.prob(high_bound_standard) + + return original_mean + self.normal.stddev() * ( + self.low_bound_pdf - self.high_bound_pdf + ) / (self.high_bound_cdf - self.low_bound_cdf + 1e-3) + + def _log_prob(self, value): + original_log_prob = self.normal.log_prob(value) + + return original_log_prob - tf.math.log( + self.high_bound_cdf - self.low_bound_cdf + 1e-3 + ) return independent_lib.Independent( CustomCensored(normal_dist),