Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 9, 2025
1 parent 7e1a133 commit 34c1cb5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/scvi/model/base/_de_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def ravel_idx(my_idx, obs_df):
raise ValueError("One of idx1 or idx2 has size zero.")
return obs_col, group1, group2


def _subset_group(
subset_idx: list[bool] | np.ndarray | str,
groupby: list[bool] | np.ndarray | str,
Expand Down Expand Up @@ -98,7 +99,7 @@ def ravel_idx(my_idx, obs_df):
obs_col = obs_df[groupby].astype(str)
mask = np.ones_like(obs_col, dtype=bool)
mask[subset_idx] = False
obs_col[mask] = 'other'
obs_col[mask] = "other"
return obs_col


Expand Down
6 changes: 3 additions & 3 deletions src/scvi/model/base/_differential.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,14 +364,14 @@ def m1_domain_fn(samples):
proba_m2 = np.mean(is_de_minus, 0)
if test_mode == "two":
proba_de = proba_m1 + proba_m2
sign = 1.
sign = 1.0
else:
proba_de = np.maximum(proba_m1, proba_m2)
sign = np.sign(proba_m1 - proba_m2)
change_distribution_props = describe_continuous_distrib(
samples=change_fn(scales_1, scales_2, 1e-3*pseudocounts),
samples=change_fn(scales_1, scales_2, 1e-3 * pseudocounts),
credible_intervals_levels=cred_interval_lvls,
) # reduced pseudocounts to correctly estimate lfc
) # reduced pseudocounts to correctly estimate lfc
change_distribution_props = {
"lfc_" + key: val for (key, val) in change_distribution_props.items()
}
Expand Down
4 changes: 2 additions & 2 deletions src/scvi/model/base/_rnamixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _get_importance_weights(
return_mean=False,
n_mc_samples=n_mc_samples,
n_mc_samples_per_pass=n_mc_samples_per_pass,
) # n_anchors
) # n_anchors
mask = torch.tensor(anchor_cells)
qz_anchor = subset_distribution(qz, mask, 0) # n_anchors, n_latent
log_qz = qz_anchor.log_prob(zs.unsqueeze(-2)).sum(dim=-1) # n_samples, n_cells, n_anchors
Expand All @@ -129,7 +129,7 @@ def _get_importance_weights(
log_px_z.append(
distributions_px.log_prob(x_anchor).sum(dim=-1)[..., None].cpu()
) # n_samples, n_cells, 1
log_px_z = torch.cat(log_px_z, dim=-1) # n_samples, n_cells, n_anchors
log_px_z = torch.cat(log_px_z, dim=-1) # n_samples, n_cells, n_anchors

log_pz = log_pz.reshape(-1, 1)
log_px_z = log_px_z.reshape(-1, len(anchor_cells))
Expand Down

0 comments on commit 34c1cb5

Please sign in to comment.