diff --git a/src/scvi/model/base/_de_core.py b/src/scvi/model/base/_de_core.py index 2a96f5e25c..9d52b02dff 100644 --- a/src/scvi/model/base/_de_core.py +++ b/src/scvi/model/base/_de_core.py @@ -61,6 +61,7 @@ def ravel_idx(my_idx, obs_df): raise ValueError("One of idx1 or idx2 has size zero.") return obs_col, group1, group2 + def _subset_group( subset_idx: list[bool] | np.ndarray | str, groupby: list[bool] | np.ndarray | str, @@ -98,7 +99,7 @@ def ravel_idx(my_idx, obs_df): obs_col = obs_df[groupby].astype(str) mask = np.ones_like(obs_col, dtype=bool) mask[subset_idx] = False - obs_col[mask] = 'other' + obs_col[mask] = "other" return obs_col diff --git a/src/scvi/model/base/_differential.py b/src/scvi/model/base/_differential.py index c4fd47ffdf..8ecdab9c8c 100644 --- a/src/scvi/model/base/_differential.py +++ b/src/scvi/model/base/_differential.py @@ -364,14 +364,14 @@ def m1_domain_fn(samples): proba_m2 = np.mean(is_de_minus, 0) if test_mode == "two": proba_de = proba_m1 + proba_m2 - sign = 1. + sign = 1.0 else: proba_de = np.maximum(proba_m1, proba_m2) sign = np.sign(proba_m1 - proba_m2) change_distribution_props = describe_continuous_distrib( - samples=change_fn(scales_1, scales_2, 1e-3*pseudocounts), + samples=change_fn(scales_1, scales_2, 1e-3 * pseudocounts), credible_intervals_levels=cred_interval_lvls, - ) # reduced pseudocounts to correctly estimate lfc + ) # reduced pseudocounts to correctly estimate lfc change_distribution_props = { "lfc_" + key: val for (key, val) in change_distribution_props.items() } diff --git a/src/scvi/model/base/_rnamixin.py b/src/scvi/model/base/_rnamixin.py index 8c5a0fb65a..8cbbc8fe91 100644 --- a/src/scvi/model/base/_rnamixin.py +++ b/src/scvi/model/base/_rnamixin.py @@ -112,7 +112,7 @@ def _get_importance_weights( return_mean=False, n_mc_samples=n_mc_samples, n_mc_samples_per_pass=n_mc_samples_per_pass, - ) # n_anchors + ) # n_anchors mask = torch.tensor(anchor_cells) qz_anchor = subset_distribution(qz, mask, 0) # n_anchors, n_latent log_qz = qz_anchor.log_prob(zs.unsqueeze(-2)).sum(dim=-1) # n_samples, n_cells, n_anchors @@ -129,7 +129,7 @@ def _get_importance_weights( log_px_z.append( distributions_px.log_prob(x_anchor).sum(dim=-1)[..., None].cpu() ) # n_samples, n_cells, 1 - log_px_z = torch.cat(log_px_z, dim=-1) # n_samples, n_cells, n_anchors + log_px_z = torch.cat(log_px_z, dim=-1) # n_samples, n_cells, n_anchors log_pz = log_pz.reshape(-1, 1) log_px_z = log_px_z.reshape(-1, len(anchor_cells))