Fixes error in get losses functions (#2362)
- [x] Tests added and passed if fixing a bug or adding a new feature
- [x] All code checks passed
- [x] Added type annotations to new arguments/methods/functions
- [x] Added an entry in the latest `docs/release_notes/index.md` file if fixing a bug or adding a new feature
- [x] If the changes are patches for a version, I have added the `on-merge: backport to 1.2.x` label

---------

Co-authored-by: cane11 <[email protected]>
Co-authored-by: Martin Kim <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Jul 30, 2024
1 parent 9137c05 commit 6c458b3
Showing 10 changed files with 116 additions and 38 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
```diff
@@ -48,6 +48,11 @@ to [Semantic Versioning]. Full commit history is available in the
   data {pr}`2756`.
 - Add support for reference mapping with {class}`mudata.MuData` models to
   {class}`scvi.model.base.ArchesMixin` {pr}`2578`.
+- Add argument `return_mean` to {meth}`scvi.model.base.VAEMixin.get_reconstruction_error`
+  and {meth}`scvi.model.base.VAEMixin.get_elbo` to allow computation
+  without averaging across cells {pr}`2362`.
+- Add support for setting `weights="importance"` in
+  {meth}`scvi.model.SCANVI.differential_expression` {pr}`2362`.
 
 #### Changed
 
@@ -94,6 +99,7 @@ to [Semantic Versioning]. Full commit history is available in the
   Previously this raised a None error {pr}`2914`.
 - {meth}`~scvi.model.SCVI.get_normalized_expression` fixed for Poisson distribution and
   Negative Binomial with latent_library_size {pr}`2915`.
+- Fix {meth}`scvi.module.VAE.marginal_ll` when `n_mc_samples_per_pass=1` {pr}`2362`.
 
 #### Removed
```
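Taken together, the two `return_mean` entries mean per-cell metrics are now one keyword away on any `VAEMixin`-based model. A minimal usage sketch (the toy AnnData below is purely illustrative, not part of the commit):

```python
import numpy as np
from anndata import AnnData
from scvi.model import SCVI

# toy data purely for illustration
adata = AnnData(np.random.poisson(1.0, size=(128, 64)).astype(np.float32))
SCVI.setup_anndata(adata)
model = SCVI(adata, n_latent=5)
model.train(max_epochs=1)

elbo_mean = model.get_elbo()                       # 0-dim tensor: average over cells
elbo_per_cell = model.get_elbo(return_mean=False)  # tensor of shape (adata.n_obs,)
```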
2 changes: 1 addition & 1 deletion src/scvi/external/scar/_module.py
```diff
@@ -341,7 +341,7 @@ def loss(
                 generative_outputs["pl"],
             ).sum(dim=1)
         else:
-            kl_divergence_l = 0.0
+            kl_divergence_l = torch.zeros_like(kl_divergence_z)
 
         # need to add the ambient rate and scale to the distribution for the loss
         px = generative_outputs["px"]
```
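The `0.0` → `torch.zeros_like(...)` substitution here recurs in `_autozivae.py`, `_totalvae.py`, and `_vae.py` below. A scalar placeholder is fine while losses are summed immediately, but per-cell reporting (`return_mean=False`) needs every KL term to stay a `(n_cells,)` tensor. A standalone illustration with made-up shapes:

```python
import torch

kl_z = torch.rand(128)         # per-cell KL for the latent z, shape (n_cells,)
kl_l = torch.zeros_like(kl_z)  # new behavior: per-cell zeros of the same shape

per_cell_kl = kl_z + kl_l      # still shape (n_cells,), can be returned per cell

# Per-term bookkeeping such as torch.stack requires real tensors;
# torch.stack([kl_z, 0.0]) would raise a TypeError, hence the change.
stacked = torch.stack([kl_z, kl_l])  # shape (2, n_cells)
print(per_cell_kl.shape, stacked.shape)
```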
57 changes: 43 additions & 14 deletions src/scvi/model/base/_log_likelihood.py
```diff
@@ -3,6 +3,7 @@
 from collections.abc import Iterator
 from typing import Any, Callable
 
+import torch
 from torch import Tensor
 
 from scvi.module.base import LossOutput
@@ -11,6 +12,7 @@
 def compute_elbo(
     module: Callable[[dict[str, Tensor | None], dict], tuple[Any, Any, LossOutput]],
     dataloader: Iterator[dict[str, Tensor | None]],
+    return_mean: bool = True,
     **kwargs,
 ) -> float:
     """Compute the evidence lower bound (ELBO) on the data.
@@ -33,22 +35,37 @@ the ``forward`` method of ``module``.
         the ``forward`` method of ``module``.
     **kwargs
         Additional keyword arguments to pass into ``module``.
+    return_mean
+        If ``True``, return the mean ELBO across the dataset. If ``False``, return the ELBO for
+        each cell individually.
     Returns
     -------
     The evidence lower bound (ELBO) of the data.
     """
-    elbo = 0.0
+    elbo = []
     for tensors in dataloader:
-        _, _, loss_output = module(tensors, **kwargs)
-        elbo += (loss_output.reconstruction_loss_sum + loss_output.kl_local_sum).item()
-
-    return (elbo + loss_output.kl_global_sum) / len(dataloader.dataset)
+        _, _, losses = module(tensors, **kwargs)
+        if isinstance(losses.reconstruction_loss, dict):
+            reconstruction_loss = torch.stack(list(losses.reconstruction_loss.values())).sum(dim=0)
+        else:
+            reconstruction_loss = losses.reconstruction_loss
+        if isinstance(losses.kl_local, dict):
+            kl_local = torch.stack(list(losses.kl_local.values())).sum(dim=0)
+        else:
+            kl_local = losses.kl_local
+        elbo.append(reconstruction_loss + kl_local)
+
+    elbo = torch.cat(elbo, dim=0)
+    if return_mean:
+        elbo = elbo.mean()
+    return elbo
 
 
 def compute_reconstruction_error(
     module: Callable[[dict[str, Tensor | None], dict], tuple[Any, Any, LossOutput]],
     dataloader: Iterator[dict[str, Tensor | None]],
+    return_mean: bool = True,
     **kwargs,
 ) -> dict[str, float]:
     """Compute the reconstruction error on the data.
@@ -67,21 +84,33 @@ def compute_reconstruction_error(
         An iterator over minibatches of data on which to compute the metric. The minibatches
         should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by
         the ``forward`` method of ``module``.
+    return_mean
+        If ``True``, return the mean reconstruction error across the dataset. If ``False``,
+        return the reconstruction error for each cell individually.
     **kwargs
         Additional keyword arguments to pass into ``module``.
     Returns
     -------
     A dictionary of the reconstruction error of the data.
     """
-    log_likelihoods = {}
+    # Iterate once over the data and computes the reconstruction error
+    log_lkl = {}
     for tensors in dataloader:
         _, _, loss_output = module(tensors, loss_kwargs={"kl_weight": 1}, **kwargs)
-        rec_losses: dict[str, Tensor] | Tensor = loss_output.reconstruction_loss
-        if not isinstance(rec_losses, dict):
-            rec_losses = {"reconstruction_loss": rec_losses}
-
-        for key, value in rec_losses.items():
-            log_likelihoods[key] = log_likelihoods.get(key, 0.0) + value.sum().item()
-
-    return {key: -(value / len(dataloader.dataset)) for key, value in log_likelihoods.items()}
+        if not isinstance(loss_output.reconstruction_loss, dict):
+            rec_loss_dict = {"reconstruction_loss": loss_output.reconstruction_loss}
+        else:
+            rec_loss_dict = loss_output.reconstruction_loss
+        for key, value in rec_loss_dict.items():
+            if key in log_lkl:
+                log_lkl[key].append(value)
+            else:
+                log_lkl[key] = [value]
+
+    for key, _ in log_lkl.items():
+        log_lkl[key] = torch.cat(log_lkl[key], dim=0)
+        if return_mean:
+            log_lkl[key] = torch.mean(log_lkl[key])
+
+    return log_lkl
```
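The rewritten `compute_elbo` keeps per-cell tensors all the way through, and it must cope with modules whose `LossOutput.reconstruction_loss` or `kl_local` is a dict of per-cell terms (e.g. multi-modality models). A self-contained sketch of that reduction, with invented keys and shapes:

```python
import torch

# a dict-valued reconstruction loss: one per-cell tensor per term
rec_loss = {
    "reconstruction_loss_gene": torch.rand(128),
    "reconstruction_loss_protein": torch.rand(128),
}

# stack to (n_terms, n_cells), sum over terms -> per-cell total, shape (128,)
per_cell = torch.stack(list(rec_loss.values())).sum(dim=0)

return_mean = False
result = per_cell.mean() if return_mean else per_cell
print(result.shape)  # torch.Size([128])
```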
14 changes: 12 additions & 2 deletions src/scvi/model/base/_vaemixin.py
```diff
@@ -24,6 +24,8 @@ def get_elbo(
         indices: Sequence[int] | None = None,
         batch_size: int | None = None,
         dataloader: Iterator[dict[str, Tensor | None]] = None,
+        return_mean: bool = True,
+        **kwargs,
     ) -> float:
         """Compute the evidence lower bound (ELBO) on the data.
@@ -49,6 +51,8 @@ def get_elbo(
             An iterator over minibatches of data on which to compute the metric. The minibatches
             should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by
             the model. If ``None``, a dataloader is created from ``adata``.
+        return_mean
+            Whether to return the mean of the ELBO or the ELBO for each observation.
         **kwargs
             Additional keyword arguments to pass into the forward method of the module.
@@ -70,7 +74,7 @@ def get_elbo(
             adata=adata, indices=indices, batch_size=batch_size
         )
 
-        return -compute_elbo(self.module, dataloader)
+        return -compute_elbo(self.module, dataloader, return_mean=return_mean, **kwargs)
 
     @torch.inference_mode()
     @unsupported_if_adata_minified
@@ -158,6 +162,7 @@ def get_reconstruction_error(
         indices: Sequence[int] | None = None,
         batch_size: int | None = None,
         dataloader: Iterator[dict[str, Tensor | None]] = None,
+        return_mean: bool = True,
         **kwargs,
     ) -> dict[str, float]:
         r"""Compute the reconstruction error on the data.
@@ -183,6 +188,9 @@ def get_reconstruction_error(
             An iterator over minibatches of data on which to compute the metric. The minibatches
             should be formatted as a dictionary of :class:`~torch.Tensor` with keys as expected by
             the model. If ``None``, a dataloader is created from ``adata``.
+        return_mean
+            Whether to return the mean reconstruction loss or the reconstruction loss
+            for each observation.
         **kwargs
             Additional keyword arguments to pass into the forward method of the module.
@@ -205,7 +213,9 @@ def get_reconstruction_error(
             adata=adata, indices=indices, batch_size=batch_size
         )
 
-        return compute_reconstruction_error(self.module, dataloader, **kwargs)
+        return compute_reconstruction_error(
+            self.module, dataloader, return_mean=return_mean, **kwargs
+        )
 
     @torch.inference_mode()
     def get_latent_representation(
```
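On the mixin side the new keyword is simply forwarded, so every `VAEMixin`-based model gains the per-observation variants. Continuing the toy SCVI example from the CHANGELOG note above, the shapes match the new tests at the end of this commit:

```python
errors = model.get_reconstruction_error()                     # dict of 0-dim tensors
per_cell = model.get_reconstruction_error(return_mean=False)  # dict of (n_obs,) tensors

assert errors["reconstruction_loss"].ndim == 0
assert per_cell["reconstruction_loss"].shape == (adata.n_obs,)
```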
2 changes: 1 addition & 1 deletion src/scvi/module/_autozivae.py
```diff
@@ -367,7 +367,7 @@ def loss(
                 Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
             ).sum(dim=1)
         else:
-            kl_divergence_l = 0.0
+            kl_divergence_l = torch.zeros_like(kl_divergence_z)
 
         # KL divergence wrt Bernoulli parameters
         kl_divergence_bernoulli = self.compute_global_kl_divergence()
```
27 changes: 14 additions & 13 deletions src/scvi/module/_scanvae.py
```diff
@@ -305,19 +305,6 @@ def loss(
         kl_divergence_z2 = kl(qz2, Normal(mean, scale)).sum(dim=-1)
         loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1)
         loss_z1_weight = qz1.log_prob(z1).sum(dim=-1)
-        if not self.use_observed_lib_size:
-            ql = inference_outputs["ql"]
-            (
-                local_library_log_means,
-                local_library_log_vars,
-            ) = self._compute_local_library_params(batch_index)
-
-            kl_divergence_l = kl(
-                ql,
-                Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
-            ).sum(dim=1)
-        else:
-            kl_divergence_l = 0.0
 
         probs = self.classifier(z1)
         if self.classifier.logits:
@@ -344,6 +331,20 @@
             ),
         )
 
+        if not self.use_observed_lib_size:
+            ql = inference_outputs["ql"]
+            (
+                local_library_log_means,
+                local_library_log_vars,
+            ) = self._compute_local_library_params(batch_index)
+
+            kl_divergence_l = kl(
+                ql,
+                Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
+            ).sum(dim=1)
+        else:
+            kl_divergence_l = torch.zeros_like(kl_divergence)
+
         kl_divergence += kl_divergence_l
 
         loss = torch.mean(reconst_loss + kl_divergence * kl_weight)
```
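Note the SCANVI block is moved, not just edited: `torch.zeros_like` needs a reference tensor of the right shape, and `kl_divergence` is only assembled after the classifier term, so the library-size KL now runs once that tensor exists. The ordering constraint in miniature (names illustrative, not the module's real code):

```python
import torch

def loss_sketch(use_observed_lib_size: bool) -> torch.Tensor:
    kl_divergence = torch.rand(128)        # must be assembled first, shape (n_cells,)
    if not use_observed_lib_size:
        kl_divergence_l = torch.rand(128)  # stands in for the real library-size KL
    else:
        # valid only because kl_divergence already exists -- hence the move
        kl_divergence_l = torch.zeros_like(kl_divergence)
    return kl_divergence + kl_divergence_l

print(loss_sketch(True).shape)  # torch.Size([128])
```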
2 changes: 1 addition & 1 deletion src/scvi/module/_totalvae.py
```diff
@@ -620,7 +620,7 @@ def loss(
                 Normal(local_library_log_means, torch.sqrt(local_library_log_vars)),
             ).sum(dim=1)
         else:
-            kl_div_l_gene = 0.0
+            kl_div_l_gene = torch.zeros_like(kl_div_z)
 
         kl_div_back_pro_full = kl(
             Normal(py_["back_alpha"], py_["back_beta"]), self.back_mean_prior
```
4 changes: 3 additions & 1 deletion src/scvi/module/_vae.py
```diff
@@ -555,7 +555,7 @@ def loss(
                 inference_outputs[MODULE_KEYS.QL_KEY], generative_outputs[MODULE_KEYS.PL_KEY]
             ).sum(dim=1)
         else:
-            kl_divergence_l = torch.tensor(0.0, device=x.device)
+            kl_divergence_l = torch.zeros_like(kl_divergence_z)
 
         reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(x).sum(-1)
 
@@ -697,6 +697,8 @@ def marginal_ll(
                 q_l_x = ql.log_prob(library).sum(dim=-1)
 
                 log_prob_sum += p_l - q_l_x
+            if n_mc_samples_per_pass == 1:
+                log_prob_sum = log_prob_sum.unsqueeze(0)
 
             to_sum.append(log_prob_sum)
         to_sum = torch.cat(to_sum, dim=0)
```
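The `marginal_ll` fix addresses a shape edge case: each pass normally yields `log_prob_sum` of shape `(n_mc_samples_per_pass, n_cells)`, but with a single sample per pass the leading Monte Carlo dimension is dropped, so `torch.cat(to_sum, dim=0)` would concatenate cells instead of samples. A standalone illustration (shapes invented):

```python
import torch

n_cells = 128

# two passes with n_mc_samples_per_pass=1: each yields a (n_cells,) tensor
pass_a = torch.rand(n_cells)
pass_b = torch.rand(n_cells)

wrong = torch.cat([pass_a, pass_b], dim=0)  # shape (256,): cells run together

# the fix restores the MC dimension before concatenating -> shape (2, 128)
fixed = torch.cat([pass_a.unsqueeze(0), pass_b.unsqueeze(0)], dim=0)

# downstream, a logsumexp over dim=0 averages over MC samples per cell
per_cell = torch.logsumexp(fixed, dim=0) - torch.log(torch.tensor(2.0))
print(wrong.shape, fixed.shape, per_cell.shape)
```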
19 changes: 19 additions & 0 deletions tests/model/test_scanvi.py
```diff
@@ -341,6 +341,25 @@ def test_multiple_covariates_scanvi():
     m.get_latent_representation()
     m.get_elbo()
     m.get_marginal_ll(n_mc_samples=3)
+    m.get_marginal_ll(adata, return_mean=True, n_mc_samples=6, n_mc_samples_per_pass=1)
+    m.get_marginal_ll(adata, return_mean=True, n_mc_samples=6, n_mc_samples_per_pass=6)
+    m.differential_expression(
+        idx1=np.arange(50), idx2=51 + np.arange(50), mode="vanilla", weights="uniform"
+    )
+    m.differential_expression(
+        idx1=np.arange(50),
+        idx2=51 + np.arange(50),
+        mode="vanilla",
+        weights="importance",
+        importance_weighting_kwargs={"n_mc_samples": 10, "n_mc_samples_per_pass": 1},
+    )
+    m.differential_expression(
+        idx1=np.arange(50),
+        idx2=51 + np.arange(50),
+        mode="vanilla",
+        weights="importance",
+        importance_weighting_kwargs={"n_mc_samples": 10, "n_mc_samples_per_pass": 10},
+    )
     m.get_reconstruction_error()
     m.get_normalized_expression(n_samples=1)
     m.get_normalized_expression(n_samples=2)
```
21 changes: 16 additions & 5 deletions tests/model/test_scvi.py
```diff
@@ -191,11 +191,19 @@ def test_scvi(gene_likelihood: str, n_latent: int = 5):
     )
     model = SCVI(adata, n_latent=n_latent)
     model.train(1, check_val_every_n_epoch=1, train_size=0.5)
-    model.get_elbo()
-    model.get_marginal_ll(n_mc_samples=3)
-    model.get_reconstruction_error()
-    model.get_normalized_expression(transform_batch="batch_1")
-    model.get_normalized_expression(n_samples=2)
+    assert model.get_elbo().ndim == 0
+    assert model.get_elbo(return_mean=False).shape == (adata.n_obs,)
+    assert model.get_marginal_ll(n_mc_samples=3).ndim == 0
+    assert model.get_marginal_ll(n_mc_samples=3, return_mean=False).shape == (adata.n_obs,)
+    assert model.get_reconstruction_error()["reconstruction_loss"].ndim == 0
+    assert model.get_reconstruction_error(return_mean=False)["reconstruction_loss"].shape == (
+        adata.n_obs,
+    )
+    assert model.get_normalized_expression(transform_batch="batch_1").shape == (
+        adata.n_obs,
+        adata.n_vars,
+    )
+    assert model.get_normalized_expression(n_samples=2).shape == (adata.n_obs, adata.n_vars)
 
     # Test without observed lib size.
     model = SCVI(adata, n_latent=n_latent, var_activation=Softplus(), use_observed_lib_size=False)
@@ -213,8 +221,11 @@ def test_scvi(gene_likelihood: str, n_latent: int = 5):
     assert z.shape == (adata.shape[0], n_latent)
     assert len(model.history["elbo_train"]) == 2
     model.get_elbo()
+    model.get_elbo(return_mean=False)
     model.get_marginal_ll(n_mc_samples=3)
+    model.get_marginal_ll(n_mc_samples=3, return_mean=False)
     model.get_reconstruction_error()
+    model.get_reconstruction_error(return_mean=False)
     model.get_normalized_expression(transform_batch="batch_1")
     model.get_normalized_expression(n_samples=2)
```
