From 6a9f31119b0a633f59193760346af886a07d9767 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Mon, 19 Aug 2024 15:55:18 +0300 Subject: [PATCH] Added PoissonVI Region Factors (#2940) This will close Issue [2529](https://github.com/scverse/scvi-tools/issues/2529) --- CHANGELOG.md | 1 + src/scvi/external/poissonvi/_model.py | 8 ++++++++ tests/external/poissonvi/test_poissonvi.py | 1 + 3 files changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f528945a47..c2cee00877 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Added +- Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`. - {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in {class}`scvi.dataloaders.AnnDataLoader` {pr}`2924`. - Add option for using external indexes in data splitting classes that are under `scvi.dataloaders` diff --git a/src/scvi/external/poissonvi/_model.py b/src/scvi/external/poissonvi/_model.py index 4a8729fae6..5d18feb03c 100644 --- a/src/scvi/external/poissonvi/_model.py +++ b/src/scvi/external/poissonvi/_model.py @@ -229,6 +229,14 @@ def get_accessibility_estimates( self.module.decoder.px_scale_decoder[-2].bias = torch.nn.Parameter(region_factors) return accs + @torch.inference_mode() + def get_region_factors(self): + """Return region-specific factors.""" + region_factors = self.module.decoder.px_scale_decoder[-2].bias.numpy() + if region_factors is None: + raise RuntimeError("region factors were not included in this model") + return region_factors + def get_normalized_expression( self, ): diff --git a/tests/external/poissonvi/test_poissonvi.py b/tests/external/poissonvi/test_poissonvi.py index 53de237668..27862b345f 100644 --- a/tests/external/poissonvi/test_poissonvi.py +++ b/tests/external/poissonvi/test_poissonvi.py @@ -11,6 +11,7 @@ def test_poissonvi(): model.train(max_epochs=1) model.get_latent_representation() model.get_accessibility_estimates() + model.get_region_factors() def test_poissonvi_default_params():