feat: continuous, embedded covariates re-injected during training #3032

Open · wants to merge 6 commits into base: main

Changes from all commits
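Every hunk in this diff follows the same pattern: encoder/decoder `forward` methods gain a `cont_covs` argument (default `None`) placed just before the variadic `*cat_list`, the corresponding constructors gain an `n_cont` count, and each call site fills the new slot, either with the continuous-covariate tensor or with a literal `None` where the module has none. A minimal sketch of the idea; the class, layer sizes, and concatenation strategy below are illustrative, not the actual scvi-tools implementation:

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Toy module mirroring the signature pattern introduced in this PR."""

    def __init__(self, n_in: int, n_out: int, n_cont: int = 0):
        super().__init__()
        # Continuous covariates are re-injected by widening the input layer
        # and concatenating them to the features on every forward pass.
        self.net = nn.Sequential(
            nn.Linear(n_in + n_cont, 64),
            nn.ReLU(),
            nn.Linear(64, n_out),
        )

    def forward(self, x: torch.Tensor, cont_covs: torch.Tensor | None = None, *cat_list: int):
        # cont_covs sits before *cat_list, so positional callers always have
        # to fill this slot, possibly with None.
        if cont_covs is not None:
            x = torch.cat([x, cont_covs], dim=-1)
        return self.net(x)


dec_cov = ToyDecoder(n_in=10, n_out=50, n_cont=2)  # set up with 2 continuous covariates
dec_plain = ToyDecoder(n_in=10, n_out=50)          # no continuous covariates

x = torch.randn(8, 10)
covs = torch.randn(8, 2)
out_with_covs = dec_cov(x, covs)  # covariates re-injected during training
out_plain = dec_plain(x, None)    # modules without them pass None, as in the hunks below
```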
7 changes: 4 additions & 3 deletions src/scvi/external/contrastivevi/_module.py
@@ -255,12 +255,12 @@ def _generic_inference(
library = torch.log(x.sum(1)).unsqueeze(1)
x_ = torch.log(1 + x_)

- qz_m, qz_v, z = self.z_encoder(x_, batch_index)
- qs_m, qs_v, s = self.s_encoder(x_, batch_index)
+ qz_m, qz_v, z = self.z_encoder(x_, None, batch_index)
+ qs_m, qs_v, s = self.s_encoder(x_, None, batch_index)

ql_m, ql_v = None, None
if not self.use_observed_lib_size:
- ql_m, ql_v, library_encoded = self.l_encoder(x_, batch_index)
+ ql_m, ql_v, library_encoded = self.l_encoder(x_, None, batch_index)
library = library_encoded

if n_samples > 1:
@@ -333,6 +333,7 @@ def _generic_generative(
self.dispersion,
latent,
library,
+ None,
batch_index,
)
px_r = torch.exp(self.px_r)
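For readability, the updated decoder call in `_generic_generative` can be annotated as follows; the comments are illustrative, and the substantive change is only the `None` inserted for the new continuous-covariate slot (contrastiveVI itself does not model continuous covariates):

```python
# Reading aid for the call above, not new code.
outputs = self.decoder(
    self.dispersion,  # dispersion mode
    latent,           # latent variables passed to the decoder
    library,          # observed or encoded library size
    None,             # cont_covs slot added in this PR
    batch_index,      # first entry of *cat_list
)
```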
6 changes: 5 additions & 1 deletion src/scvi/external/methylvi/_base_components.py
@@ -78,6 +78,7 @@ def forward(
self,
dispersion: str,
z: torch.Tensor,
+ cont_covs: torch.Tensor | None = None,
*cat_list: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""The forward computation for a single sample.
@@ -99,14 +100,17 @@
library size
cat_list
list of category membership(s) for this sample
+ cont_covs
+     continuous covariates for this sample,
+     tensor of values with shape ``(n_cont,)``

Returns
-------
2-tuple of :py:class:`torch.Tensor`
parameters for the Beta distribution of mean methylation values

"""
- px = self.px_decoder(z, *cat_list)
+ px = self.px_decoder(z, cont_covs, *cat_list)
px_mu = self.px_mu_decoder(px)
px_gamma = self.px_gamma_decoder(px) if dispersion == "region-cell" else None
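A short usage sketch of the updated call, assuming a decoder instance from this file and batched tensors whose trailing dimensions match the docstring above (variable names are illustrative):

```python
# Hedged usage sketch; `decoder`, `z`, `cont_covs`, and `batch_index` are
# assumed to exist with compatible shapes.
px_mu, px_gamma = decoder(
    "region-cell",  # dispersion mode; px_gamma is only produced in this mode
    z,              # latent representation
    cont_covs,      # continuous covariates matching n_cont, or None
    batch_index,    # first entry of *cat_list
)
```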

4 changes: 2 additions & 2 deletions src/scvi/external/methylvi/_module.py
@@ -163,7 +163,7 @@ def inference(self, mc, cov, batch_index, cat_covs=None, n_samples=1):
else:
categorical_input = ()

- qz, z = self.z_encoder(methylation_input, batch_index, *categorical_input)
+ qz, z = self.z_encoder(methylation_input, None, batch_index, *categorical_input)
Member (review comment): Can you add the None as an argument in line 148. It makes it cleaner where it comes from.

Member (review comment): This is throughout this PR.

if n_samples > 1:
z = qz.sample((n_samples,))
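A sketch of the refactor the comments above seem to ask for: bind the value to a name where the other encoder inputs are assembled, so each call site documents where it comes from. The placement and the `cat_covs` split are reconstructed from the usual scvi-tools pattern, since the referenced line is not visible in this hunk:

```python
cont_covs = None  # methylVI does not currently model continuous covariates

if cat_covs is not None:
    categorical_input = torch.split(cat_covs, 1, dim=1)
else:
    categorical_input = ()

qz, z = self.z_encoder(methylation_input, cont_covs, batch_index, *categorical_input)
```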

@@ -182,7 +182,7 @@ def generative(self, z, batch_index, cat_covs=None):

for context in self.contexts:
px_mu[context], px_gamma[context] = self.decoders[context](
- self.dispersion, z, batch_index, *categorical_input
+ self.dispersion, z, None, batch_index, *categorical_input
)

pz = Normal(torch.zeros_like(z), torch.ones_like(z))
4 changes: 2 additions & 2 deletions src/scvi/module/_mrdeconv.py
@@ -227,7 +227,7 @@ def generative(self, x, ind_x):
enum_label = (
torch.arange(0, self.n_labels).repeat(m).view((-1, 1))
) # minibatch_size * n_labels, 1
- h = self.decoder(gamma_reshape, enum_label.to(x.device))
+ h = self.decoder(gamma_reshape, None, enum_label.to(x.device))
px_rate = self.px_decoder(h).reshape(
(m, self.n_labels, -1)
) # (minibatch, n_labels, n_genes)
@@ -394,7 +394,7 @@ def get_ct_specific_expression(
gamma_select = gamma_ind[
:, y_torch, torch.arange(ind_x.shape[0])
].T # minibatch_size, n_latent
- h = self.decoder(gamma_select, y_torch.unsqueeze(1))
+ h = self.decoder(gamma_select, None, y_torch.unsqueeze(1))
px_scale = self.px_decoder(h) # (minibatch, n_genes)
px_ct = torch.exp(self.px_o).unsqueeze(0) * beta.unsqueeze(0) * px_scale
return px_ct # shape (minibatch, genes)
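The literal `None` in these two calls is forced by the argument order: since `cont_covs` is declared before the variadic `*cat_list`, a positional caller cannot omit it without the first categorical tensor being captured as `cont_covs`. A tiny self-contained illustration of the pitfall:

```python
def forward(x, cont_covs=None, *cat_list):
    """Toy function with the same argument layout as the updated decoders."""
    return x, cont_covs, cat_list

# Correct: explicitly fill the cont_covs slot.
print(forward("x", None, "batch_index"))  # ('x', None, ('batch_index',))

# Pitfall: without the placeholder, batch_index is silently bound to cont_covs.
print(forward("x", "batch_index"))        # ('x', 'batch_index', ())
```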
62 changes: 44 additions & 18 deletions src/scvi/module/_multivae.py
@@ -28,6 +28,7 @@ def __init__(
self,
n_input: int,
n_cat_list: Iterable[int] = None,
+ n_cont: int = 0,
n_layers: int = 2,
n_hidden: int = 128,
use_batch_norm: bool = False,
@@ -40,6 +41,7 @@
n_in=n_input,
n_out=n_hidden,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=0,
@@ -51,9 +53,9 @@
)
self.output = torch.nn.Sequential(torch.nn.Linear(n_hidden, 1), torch.nn.LeakyReLU())

- def forward(self, x: torch.Tensor, *cat_list: int):
+ def forward(self, x: torch.Tensor, cont_covs: torch.Tensor | None = None, *cat_list: int):
Member (review comment): Here it's correct.

"""Forward pass."""
- return self.output(self.px_decoder(x, *cat_list))
+ return self.output(self.px_decoder(x, cont_covs, *cat_list))


class DecoderADT(torch.nn.Module):
@@ -64,6 +66,7 @@ def __init__(
n_input: int,
n_output_proteins: int,
n_cat_list: Iterable[int] = None,
+ n_cont: int = 0,
n_layers: int = 2,
n_hidden: int = 128,
dropout_rate: float = 0.1,
@@ -86,6 +89,7 @@
n_in=n_input,
n_out=n_hidden,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
@@ -96,6 +100,7 @@
n_in=n_hidden + n_input,
n_out=n_output_proteins,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
n_layers=1,
use_activation=True,
use_batch_norm=False,
@@ -108,6 +113,7 @@
n_in=n_hidden + n_input,
n_out=n_output_proteins,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
**linear_args,
)

@@ -116,6 +122,7 @@
n_in=n_input,
n_out=n_hidden,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
@@ -128,12 +135,15 @@
n_in=n_hidden + n_input,
n_out=n_output_proteins,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
inject_covariates=deep_inject_covariates,
**linear_args,
)
self.py_back_mean_log_beta = FCLayers(
n_in=n_hidden + n_input,
n_out=n_output_proteins,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
**linear_args,
)

@@ -142,34 +152,39 @@
n_in=n_input,
n_out=n_hidden,
n_cat_list=n_cat_list,
+ n_cont=n_cont,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=dropout_rate,
use_batch_norm=use_batch_norm,
use_layer_norm=use_layer_norm,
)

- def forward(self, z: torch.Tensor, *cat_list: int):
+ def forward(self, z: torch.Tensor, cont_covs: torch.Tensor | None = None, *cat_list: int):
"""Forward pass."""
# z is the latent repr
py_ = {}

- py_back = self.py_back_decoder(z, *cat_list)
+ py_back = self.py_back_decoder(z, cont_covs, *cat_list)
py_back_cat_z = torch.cat([py_back, z], dim=-1)

- py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, *cat_list)
- py_["back_beta"] = torch.exp(self.py_back_mean_log_beta(py_back_cat_z, *cat_list))
+ py_["back_alpha"] = self.py_back_mean_log_alpha(py_back_cat_z, cont_covs, *cat_list)
+ py_["back_beta"] = torch.exp(
+     self.py_back_mean_log_beta(py_back_cat_z, cont_covs, *cat_list)
+ )
log_pro_back_mean = Normal(py_["back_alpha"], py_["back_beta"]).rsample()
py_["rate_back"] = torch.exp(log_pro_back_mean)

- py_fore = self.py_fore_decoder(z, *cat_list)
+ py_fore = self.py_fore_decoder(z, cont_covs, *cat_list)
py_fore_cat_z = torch.cat([py_fore, z], dim=-1)
- py_["fore_scale"] = self.py_fore_scale_decoder(py_fore_cat_z, *cat_list) + 1 + 1e-8
+ py_["fore_scale"] = (
+     self.py_fore_scale_decoder(py_fore_cat_z, cont_covs, *cat_list) + 1 + 1e-8
+ )
py_["rate_fore"] = py_["rate_back"] * py_["fore_scale"]

- p_mixing = self.sigmoid_decoder(z, *cat_list)
+ p_mixing = self.sigmoid_decoder(z, cont_covs, *cat_list)
p_mixing_cat_z = torch.cat([p_mixing, z], dim=-1)
- py_["mixing"] = self.py_background_decoder(p_mixing_cat_z, *cat_list)
+ py_["mixing"] = self.py_background_decoder(p_mixing_cat_z, cont_covs, *cat_list)

protein_mixing = 1 / (1 + torch.exp(-py_["mixing"]))
py_["scale"] = torch.nn.functional.normalize(
@@ -356,6 +371,7 @@ def __init__(
n_input=n_input_encoder_exp,
n_output=self.n_latent,
n_cat_list=encoder_cat_list,
+ n_cont=n_continuous_cov,
n_layers=self.n_layers_encoder,
n_hidden=self.n_hidden,
dropout_rate=self.dropout_rate,
@@ -372,6 +388,7 @@
self.l_encoder_expression = LibrarySizeEncoder(
n_input_encoder_exp,
n_cat_list=encoder_cat_list,
+ n_cont=n_continuous_cov,
n_layers=self.n_layers_encoder,
n_hidden=self.n_hidden,
use_batch_norm=self.use_batch_norm_encoder,
@@ -385,6 +402,7 @@
n_input_decoder,
n_input_genes,
n_cat_list=cat_list,
+ n_cont=n_continuous_cov,
n_layers=n_layers_decoder,
n_hidden=self.n_hidden,
inject_covariates=self.deeply_inject_covariates,
@@ -406,6 +424,7 @@
n_output=self.n_latent,
n_hidden=self.n_hidden,
n_cat_list=encoder_cat_list,
+ n_cont=n_continuous_cov,
dropout_rate=self.dropout_rate,
activation_fn=torch.nn.LeakyReLU,
distribution=self.latent_distribution,
@@ -426,6 +445,7 @@
n_output=n_input_regions,
n_hidden=self.n_hidden,
n_cat_list=cat_list,
+ n_cont=n_continuous_cov,
n_layers=self.n_layers_decoder,
use_batch_norm=self.use_batch_norm_decoder,
use_layer_norm=self.use_layer_norm_decoder,
@@ -438,6 +458,7 @@
n_output=1,
n_hidden=self.n_hidden,
n_cat_list=encoder_cat_list,
+ n_cont=n_continuous_cov,
n_layers=self.n_layers_encoder,
use_batch_norm=self.use_batch_norm_encoder,
use_layer_norm=self.use_layer_norm_encoder,
@@ -486,6 +507,7 @@ def __init__(
n_output=self.n_latent,
n_hidden=self.n_hidden,
n_cat_list=encoder_cat_list,
+ n_cont=n_continuous_cov,
dropout_rate=self.dropout_rate,
activation_fn=torch.nn.LeakyReLU,
distribution=self.latent_distribution,
@@ -501,6 +523,7 @@
n_output_proteins=n_input_proteins,
n_hidden=self.n_hidden,
n_cat_list=cat_list,
+ n_cont=n_continuous_cov,
n_layers=self.n_layers_decoder,
use_batch_norm=self.use_batch_norm_decoder,
use_layer_norm=self.use_layer_norm_decoder,
@@ -581,7 +604,7 @@ def inference(
mask_acc = x_chr.sum(dim=1) > 0
mask_pro = y.sum(dim=1) > 0

- if cont_covs is not None and self.encode_covariates:
+ if cont_covs is not None:
Member (review comment): We shouldn't remove the self.encode_covariates check.

encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1)
encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1)
encoder_input_protein = torch.cat((y, cont_covs), dim=-1)
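On the `encode_covariates` comment above: the guard from the removed line can be kept alongside the rest of this PR. A sketch of the guarded form, with the `else` branch reconstructed from the usual scvi-tools pattern (it is not shown in this hunk):

```python
if cont_covs is not None and self.encode_covariates:
    encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1)
    encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1)
    encoder_input_protein = torch.cat((y, cont_covs), dim=-1)
else:
    encoder_input_expression = x_rna
    encoder_input_accessibility = x_chr
    encoder_input_protein = y
```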
@@ -597,21 +620,21 @@

# Z Encoders
qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
- encoder_input_accessibility, batch_index, *categorical_input
+ encoder_input_accessibility, None, batch_index, *categorical_input
Member (review comment): why is this hard-coded?

)
qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(
- encoder_input_expression, batch_index, *categorical_input
+ encoder_input_expression, None, batch_index, *categorical_input
)
qzm_pro, qzv_pro, z_pro = self.z_encoder_protein(
- encoder_input_protein, batch_index, *categorical_input
+ encoder_input_protein, None, batch_index, *categorical_input
)

# L encoders
libsize_expr = self.l_encoder_expression(
- encoder_input_expression, batch_index, *categorical_input
+ encoder_input_expression, None, batch_index, *categorical_input
)
libsize_acc = self.l_encoder_accessibility(
- encoder_input_accessibility, batch_index, *categorical_input
+ encoder_input_accessibility, None, batch_index, *categorical_input
)

# mix representations
@@ -733,7 +756,7 @@ def generative(
decoder_input = torch.cat([latent, cont_covs], dim=-1)

# Accessibility Decoder
- p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input)
+ p = self.z_decoder_accessibility(decoder_input, None, batch_index, *categorical_input)

# Expression Decoder
if not self.use_size_factor_key:
@@ -742,6 +765,7 @@
self.gene_dispersion,
decoder_input,
size_factor,
+ None,
batch_index,
*categorical_input,
label,
@@ -758,7 +782,9 @@
px_r = torch.exp(px_r)

# Protein Decoder
- py_, log_pro_back_mean = self.z_decoder_pro(decoder_input, batch_index, *categorical_input)
+ py_, log_pro_back_mean = self.z_decoder_pro(
+     decoder_input, None, batch_index, *categorical_input
+ )
# Protein Dispersion
if self.protein_dispersion == "protein-label":
# py_r gets transposed - last dimension is n_proteins
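For completeness, this is roughly how continuous covariates reach these modules from the user-facing side: they are registered in `setup_anndata` and handed to the module as `cont_covs` during training. The snippet below uses SCVI and the bundled synthetic dataset only as a stand-in, and the derived `obs` column is a placeholder, not a recommended covariate:

```python
import numpy as np
import scvi

adata = scvi.data.synthetic_iid()  # small synthetic AnnData shipped with scvi-tools
# Placeholder continuous covariate for illustration only.
adata.obs["total_counts"] = np.asarray(adata.X.sum(axis=1)).ravel()

scvi.model.SCVI.setup_anndata(
    adata,
    batch_key="batch",
    continuous_covariate_keys=["total_counts"],  # becomes cont_covs inside the module
)
model = scvi.model.SCVI(adata)
model.train(max_epochs=1)
```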