From c2e3714750f181c1fbc5f9a784d4b8ebb8cd6686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ramon=20Vi=C3=B1as?= Date: Wed, 20 Nov 2024 10:36:48 +0100 Subject: [PATCH] feat: Add variance to ZINB model (#3044) Implemented variance of ZINB distribution --------- Co-authored-by: Ori Kronfeld --- CHANGELOG.md | 4 ++-- src/scvi/distributions/_negative_binomial.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 315aae5ce7..bec14bc8a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,9 +25,9 @@ to [Semantic Versioning]. Full commit history is available in the - Added adaptive handling for last training minibatch of 1-2 cells in case of `datasplitter_kwargs={"drop_last": False}` and `train_size = None` by moving them into - validation set, if available. - {pr}`3036`. + validation set, if available. {pr}`3036`. - Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`. +- Implemented variance of ZINB distribution. {pr}`3044`. #### Fixed diff --git a/src/scvi/distributions/_negative_binomial.py b/src/scvi/distributions/_negative_binomial.py index 27f8f90524..c7115c56b2 100644 --- a/src/scvi/distributions/_negative_binomial.py +++ b/src/scvi/distributions/_negative_binomial.py @@ -502,7 +502,8 @@ def mean(self) -> torch.Tensor: @property def variance(self) -> None: - raise NotImplementedError + pi = self.zi_probs + return (1 - pi) * self.mu * (self.mu + self.theta + pi * self.mu * self.theta) / self.theta @lazy_property def zi_logits(self) -> torch.Tensor: