Skip to content

Commit

Permalink
Added PoissonVI Region Factors (#2940)
Browse files Browse the repository at this point in the history
This will close Issue
[2529](#2529)
  • Loading branch information
ori-kron-wis authored Aug 19, 2024
1 parent b2f9f3e commit 6a9f311
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ to [Semantic Versioning]. Full commit history is available in the

#### Added

- Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`.
- {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in
{class}`scvi.dataloaders.AnnDataLoader` {pr}`2924`.
- Add option for using external indexes in data splitting classes that are under `scvi.dataloaders`
Expand Down
8 changes: 8 additions & 0 deletions src/scvi/external/poissonvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ def get_accessibility_estimates(
self.module.decoder.px_scale_decoder[-2].bias = torch.nn.Parameter(region_factors)
return accs

@torch.inference_mode()
def get_region_factors(self):
"""Return region-specific factors."""
region_factors = self.module.decoder.px_scale_decoder[-2].bias.numpy()
if region_factors is None:
raise RuntimeError("region factors were not included in this model")
return region_factors

def get_normalized_expression(
self,
):
Expand Down
1 change: 1 addition & 0 deletions tests/external/poissonvi/test_poissonvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def test_poissonvi():
model.train(max_epochs=1)
model.get_latent_representation()
model.get_accessibility_estimates()
model.get_region_factors()


def test_poissonvi_default_params():
Expand Down

0 comments on commit 6a9f311

Please sign in to comment.