From ee96e71b0a34cb1236479a0406e82f0aad86dd71 Mon Sep 17 00:00:00 2001 From: Jayoung Ryu Date: Sat, 25 Nov 2023 17:51:15 -0500 Subject: [PATCH] Initial run with condition --- bean/framework/ReporterScreen.py | 6 ++- bean/model/model.py | 73 +++++++++++++++++++-------- bean/model/utils.py | 28 ++++++---- bean/preprocessing/data_class.py | 43 ++++++++++++++-- bean/preprocessing/utils.py | 8 +-- bin/bean-run | 20 ++++++-- notebooks/sample_quality_report.ipynb | 17 ++++--- 7 files changed, 144 insertions(+), 51 deletions(-) diff --git a/bean/framework/ReporterScreen.py b/bean/framework/ReporterScreen.py index 5d012c6..37631d5 100644 --- a/bean/framework/ReporterScreen.py +++ b/bean/framework/ReporterScreen.py @@ -99,6 +99,8 @@ def __init__( self.layers["X_bcmatch"] = X_bcmatch for k, df in self.uns.items(): if not isinstance(df, pd.DataFrame): + if k == "sample_covariates" and not isinstance(df, list): + self.uns[k] = df.tolist() continue if "guide" in df.columns and len(df) > 0: if ( @@ -325,13 +327,13 @@ def __getitem__(self, index): if k.startswith("repguide_mask"): if "sample_covariates" in adata.uns: adata.var["_rc"] = adata.var[ - ["rep"] + adata.uns["sample_covariates"] + ["rep"] + list(adata.uns["sample_covariates"]) ].values.tolist() adata.var["_rc"] = adata.var["_rc"].map( lambda slist: ".".join(slist) ) new_uns[k] = df.loc[guides_include, adata.var._rc.unique()] - adata.var.pop("_rc") + #adata.var.pop("_rc") else: new_uns[k] = df.loc[guides_include, adata.var.rep.unique()] if not isinstance(df, pd.DataFrame): diff --git a/bean/model/model.py b/bean/model/model.py index c6be6e9..91514f1 100644 --- a/bean/model/model.py +++ b/bean/model/model.py @@ -45,26 +45,51 @@ def NormalModel( sd = sd_alleles sd = torch.repeat_interleave(sd, data.target_lengths, dim=0) assert sd.shape == (data.n_guides, 1) - + if data.sample_covariates is not None: + with pyro.plate("cov_place", data.n_sample_covariates): + mu_cov = pyro.sample("mu_cov", dist.Normal(0, 1)) + assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape with replicate_plate: with bin_plate as b: uq = data.upper_bounds[b] lq = data.lower_bounds[b] assert uq.shape == lq.shape == (data.n_condits,) - # with guide_plate, poutine.mask(mask=(data.allele_counts.sum(axis=-1) == 0)): with guide_plate: + mu = mu.unsqueeze(0).unsqueeze(0).expand( + (data.n_reps, data.n_condits, -1, -1) + ) + (data.rep_by_cov * mu_cov)[:, 0].unsqueeze(-1).unsqueeze( + -1 + ).unsqueeze( + -1 + ).expand( + (-1, data.n_condits, data.n_guides, 1) + ) + sd = torch.sqrt( + ( + sd.unsqueeze(0) + .unsqueeze(0) + .expand((data.n_reps, data.n_condits, -1, -1)) + ) + ) alleles_p_bin = get_std_normal_prob( - uq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)), - lq.unsqueeze(-1).unsqueeze(-1).expand((-1, data.n_guides, 1)), - mu.unsqueeze(0).expand((data.n_condits, -1, -1)), - sd.unsqueeze(0).expand((data.n_condits, -1, -1)), + uq.unsqueeze(0) + .unsqueeze(-1) + .unsqueeze(-1) + .expand((data.n_reps, -1, data.n_guides, 1)), + lq.unsqueeze(0) + .unsqueeze(-1) + .unsqueeze(-1) + .expand((data.n_reps, -1, data.n_guides, 1)), + mu, + sd, ) - assert alleles_p_bin.shape == (data.n_condits, data.n_guides, 1) - - expected_allele_p = alleles_p_bin.unsqueeze(0).expand( - data.n_reps, -1, -1, -1 - ) - expected_guide_p = expected_allele_p.sum(axis=-1) + assert alleles_p_bin.shape == ( + data.n_reps, + data.n_condits, + data.n_guides, + 1, + ) + expected_guide_p = alleles_p_bin.sum(axis=-1) assert expected_guide_p.shape == ( data.n_reps, data.n_condits, @@ -158,14 +183,10 @@ def ControlNormalModel(data, mask_thres=10, use_bcmatch=True): with pyro.plate("guide_plate3", data.n_guides, dim=-1): a = get_alpha(expected_guide_p, data.size_factor, data.sample_mask, data.a0) - assert ( - data.X.shape - == data.X_bcmatch.shape - == ( - data.n_reps, - data.n_condits, - data.n_guides, - ) + assert data.X.shape == ( + data.n_reps, + data.n_condits, + data.n_guides, ) with poutine.mask( mask=torch.logical_and( @@ -490,6 +511,18 @@ def NormalGuide(data): constraint=constraints.positive, ) pyro.sample("sd_alleles", dist.LogNormal(sd_loc, sd_scale)) + if data.sample_covariates is not None: + with pyro.plate("cov_place", data.n_sample_covariates): + mu_cov_loc = pyro.param( + "mu_cov_loc", torch.zeros((data.n_sample_covariates,)) + ) + mu_cov_scale = pyro.param( + "mu_cov_scale", + torch.ones((data.n_sample_covariates,)), + constraint=constraints.positive, + ) + mu_cov = pyro.sample("mu_cov", dist.Normal(mu_cov_loc, mu_cov_scale)) + assert mu_cov.shape == (data.n_sample_covariates,), mu_cov.shape def MixtureNormalGuide( diff --git a/bean/model/utils.py b/bean/model/utils.py index 7767232..7dfba0a 100644 --- a/bean/model/utils.py +++ b/bean/model/utils.py @@ -8,17 +8,23 @@ def get_alpha( expected_guide_p, size_factor, sample_mask, a0, epsilon=1e-5, normalize_by_a0=True ): - p = ( - expected_guide_p.permute(0, 2, 1) * size_factor[:, None, :] - ) # (n_reps, n_guides, n_bins) - if normalize_by_a0: - a = ( - (p + epsilon / p.shape[-1]) - / (p.sum(axis=-1)[:, :, None] + epsilon) - * a0[None, :, None] - ) - a = (a * sample_mask[:, None, :]).clamp(min=epsilon) - return a + try: + p = ( + expected_guide_p.permute(0, 2, 1) * size_factor[:, None, :] + ) # (n_reps, n_guides, n_bins) + + if normalize_by_a0: + a = ( + (p + epsilon / p.shape[-1]) + / (p.sum(axis=-1)[:, :, None] + epsilon) + * a0[None, :, None] + ) + a = (a * sample_mask[:, None, :]).clamp(min=epsilon) + return a + except: + print(size_factor.shape) + print(expected_guide_p.shape) + print(a0.shape) a = (p * sample_mask[:, None, :]).clamp(min=epsilon) return a diff --git a/bean/preprocessing/data_class.py b/bean/preprocessing/data_class.py index 0a16039..9f01947 100644 --- a/bean/preprocessing/data_class.py +++ b/bean/preprocessing/data_class.py @@ -60,17 +60,34 @@ def __init__( self.device = device screen.samples["size_factor"] = self.get_size_factor(screen.X) if not ( - "rep" in screen.samples.columns + replicate_column in screen.samples.columns and condition_column in screen.samples.columns ): - screen.samples["rep"], screen.samples[condition_column] = zip( + screen.samples[replicate_column], screen.samples[condition_column] = zip( *screen.samples.index.map(lambda s: s.rsplit("_", 1)) ) if condition_column not in screen.samples.columns: screen.samples[condition_column] = screen.samples["index"].map( lambda s: s.split("_")[-1] ) - + if "sample_covariates" in screen.uns: + self.sample_covariates = screen.uns["sample_covariates"] + self.n_sample_covariates = len(self.sample_covariates) + screen.samples["_rc"] = screen.samples[ + [replicate_column] + self.sample_covariates + ].values.tolist() + screen.samples["_rc"] = screen.samples["_rc"].map( + lambda slist: ".".join(slist) + ) + self.rep_by_cov = torch.as_tensor( + ( + screen.samples[["_rc"] + self.sample_covariates] + .drop_duplicates() + .set_index("_rc") + .values.astype(int) + ) + ) + replicate_column = "_rc" self.screen = screen if not control_can_be_selected: self.screen_selected = screen[ @@ -146,7 +163,7 @@ def _post_init( ).all() assert ( self.screen_selected.uns[self.repguide_mask].columns - == self.screen_selected.samples.rep.unique() + == self.screen_selected.samples[self.replicate_column].unique() ).all() self.repguide_mask = ( torch.as_tensor(self.screen_selected.uns[self.repguide_mask].values.T) @@ -182,6 +199,7 @@ def __getitem__(self, guide_idx): ndata.X_masked = ndata.X_masked[:, :, guide_idx] ndata.X_control = ndata.X_control[:, :, guide_idx] ndata.repguide_mask = ndata.repguide_mask[:, guide_idx] + ndata.a0 = ndata.a0[guide_idx] return ndata def transform_data(self, X, n_bins=None): @@ -905,9 +923,20 @@ def _pre_init( self.screen.samples.loc[ self.screen_selected.samples.index, f"{self.condition_column}_id" ] = self.screen_selected.samples[f"{self.condition_column}_id"] + print(self.screen.samples.columns) self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.condition_column ) + print(self.screen.samples.columns) + if self.sample_covariates is not None: + self.rep_by_cov = torch.as_tensor( + ( + self.screen.samples[["_rc"] + self.sample_covariates] + .drop_duplicates() + .set_index("_rc") + .values.astype(int) + ) + ) self.screen_selected = _assign_rep_ids_and_sort( self.screen_selected, self.replicate_column, self.condition_column ) @@ -986,8 +1015,12 @@ def _post_init( self.screen = _assign_rep_ids_and_sort( self.screen, self.replicate_column, self.time_column ) + if self.sample_covariates is not None: + self.rep_by_cov = self.screen.samples.groupby(self.replicate_column)[ + self.sample_covariates + ].values self.screen_selected = _assign_rep_ids_and_sort( - self.screen_selected, self.replicate_column, self.time_column + self.screen_selected, self.replicate_column, self.condition_column ) self.screen_control = _assign_rep_ids_and_sort( self.screen_control, diff --git a/bean/preprocessing/utils.py b/bean/preprocessing/utils.py index 8596782..44c6431 100644 --- a/bean/preprocessing/utils.py +++ b/bean/preprocessing/utils.py @@ -219,10 +219,10 @@ def _assign_rep_ids_and_sort( sort_key = f"{rep_col}_id" else: sort_key = [f"{rep_col}_id", f"{condition_column}_id"] - screen = screen[ - :, - screen.samples.sort_values(sort_key).index, - ] + screen = screen[ + :, + screen.samples.sort_values(sort_key).index, + ] return screen diff --git a/bin/bean-run b/bin/bean-run index 669da0b..04d364e 100644 --- a/bin/bean-run +++ b/bin/bean-run @@ -2,6 +2,8 @@ import os import sys import logging +import warnings +from functools import partial from copy import deepcopy import numpy as np import pandas as pd @@ -43,6 +45,11 @@ warn = logging.warning debug = logging.debug info = logging.info pyro.set_rng_seed(101) +warnings.filterwarnings( + "ignore", + category=FutureWarning, + message=r".*is_categorical_dtype is deprecated and will be removed in a future version.*", +) def main(args, bdata): @@ -127,8 +134,15 @@ def main(args, bdata): run_inference(model, guide, ndata, num_steps=args.n_iter) ) if args.fit_negctrl: - negctrl_model = m.ControlNormalModel - negctrl_guide = m.ControlNormalGuide + negctrl_model = partial( + m.ControlNormalModel, + use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), + ) + print((not args.ignore_bcmatch and "X_bcmatch" in bdata.layers)) + negctrl_guide = partial( + m.ControlNormalGuide, + use_bcmatch=(not args.ignore_bcmatch and "X_bcmatch" in bdata.layers), + ) negctrl_idx = np.where( guide_info_df[args.negctrl_col].map(lambda s: s.lower()) == args.negctrl_col_value.lower() @@ -137,7 +151,7 @@ def main(args, bdata): print(negctrl_idx.shape) ndata_negctrl = ndata[negctrl_idx] param_history_dict["negctrl"] = run_inference( - negctrl_model, negctrl_guide, ndata_negctrl + negctrl_model, negctrl_guide, ndata_negctrl, num_steps=args.n_iter ) outfile_path = ( diff --git a/notebooks/sample_quality_report.ipynb b/notebooks/sample_quality_report.ipynb index 205d3a3..bf632b3 100644 --- a/notebooks/sample_quality_report.ipynb +++ b/notebooks/sample_quality_report.ipynb @@ -76,9 +76,10 @@ "outputs": [], "source": [ "if tiling is not None:\n", - " bdata.uns['tiling'] = tiling\n", + " bdata.uns[\"tiling\"] = tiling\n", "if not isinstance(replicate_label, str):\n", - " bdata.uns['sample_covariates'] = replicate_label[1:]" + " bdata.uns[\"sample_covariates\"] = replicate_label[1:]\n", + "bdata.samples[replicate_label] = bdata.samples[replicate_label].astype(str)" ] }, { @@ -352,11 +353,15 @@ "metadata": {}, "outputs": [], "source": [ - "bdata.samples['mask'] = 1\n", - "bdata.samples.loc[bdata.samples.median_corr_X < corr_X_thres, 'mask'] = 0\n", + "bdata.samples[\"mask\"] = 1\n", + "bdata.samples.loc[\n", + " bdata.samples.median_corr_X.isnull() | (bdata.samples.median_corr_X < corr_X_thres), \"mask\"\n", + "] = 0\n", "if \"median_editing_rate\" in bdata.samples.columns.tolist():\n", - " bdata.samples.loc[bdata.samples.median_editing_rate < edit_rate_thres, 'mask'] = 0\n", - "bdata_filtered = bdata[:, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres]" + " bdata.samples.loc[bdata.samples.median_editing_rate < edit_rate_thres, \"mask\"] = 0\n", + "bdata_filtered = bdata[\n", + " :, bdata.samples[f\"median_lfc_corr.{comp_cond1}_{comp_cond2}\"] > lfc_thres\n", + "]" ] }, {