Skip to content

Commit

Permalink
feat(external): adding batch_key and labels_key (#3045)
Browse files Browse the repository at this point in the history
Adds batch_key and labels_key per documentation. Also corrected the
reference to scvi.external.

Co-authored-by: Ori Kronfeld <[email protected]>
  • Loading branch information
mys721tx and ori-kron-wis authored Nov 19, 2024
1 parent c8f242c commit 03acd60
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ to [Semantic Versioning]. Full commit history is available in the
`datasplitter_kwargs={"drop_last": False}` and `train_size = None` by moving them into
validation set, if available.
{pr}`3036`.
- Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata` {pr}`3045`.

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
to correctly compute the maximum log-density across in-sample cells rather than the
aggregated posterior log-density {pr}`3007`.
- Fix references to `scvi.external` in the `scvi.external.SCAR` class docstring {pr}`3045`.

#### Changed

Expand Down
18 changes: 11 additions & 7 deletions src/scvi/external/scar/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class SCAR(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi_external.SCAR.setup_anndata`.
AnnData object that has been registered via :meth:`~scvi.external.SCAR.setup_anndata`.
ambient_profile
The probability of occurrence of each ambient transcript.\
If None, averaging cells to estimate the ambient profile, by default None.
Expand Down Expand Up @@ -70,15 +70,15 @@ class SCAR(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
the sparsity should be low; on the other hand, it should be set high
in the case of unfiltered genes.
**model_kwargs
Keyword args for :class:`~scvi_external.SCAR`
Keyword args for :class:`~scvi.external.SCAR`
Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> raw_adata = anndata.read_h5ad(path_to_raw_anndata)
>>> scvi_external.SCAR.setup_anndata(adata, batch_key="batch")
>>> scvi_external.SCAR.get_ambient_profile(adata=adata, raw_adata=raw_adata, prob=0.995)
>>> vae = scvi_external.SCAR(adata)
>>> scvi.external.SCAR.setup_anndata(adata, batch_key="batch")
>>> scvi.external.SCAR.get_ambient_profile(adata=adata, raw_adata=raw_adata, prob=0.995)
>>> vae = scvi.external.SCAR(adata)
>>> vae.train()
>>> adata.obsm["X_scAR"] = vae.get_latent_representation()
>>> adata.layers['denoised'] = vae.get_denoised_counts()
Expand Down Expand Up @@ -152,6 +152,8 @@ def __init__(
def setup_anndata(
cls,
adata: AnnData,
batch_key: str | None = None,
labels_key: str | None = None,
layer: str | None = None,
size_factor_key: str | None = None,
**kwargs,
Expand All @@ -161,14 +163,16 @@ def setup_anndata(
Parameters
----------
%(param_adata)s
%(param_batch_key)s
%(param_labels_key)s
%(param_layer)s
%(param_size_factor_key)s
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None),
CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None),
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
]
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
Expand Down
5 changes: 4 additions & 1 deletion tests/external/scar/test_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
def test_scar():
n_latent = 5
adata = synthetic_iid()

adata.obs["batch"] = adata.obs["batch"].cat.rename_categories(["batch_2", "batch_3"])

adata.X = scipy.sparse.csr_matrix(adata.X)
SCAR.setup_anndata(adata)
SCAR.setup_anndata(adata, batch_key="batch", labels_key="labels")

_ = SCAR.get_ambient_profile(adata, adata, prob=0.0, iterations=1, sample=100)
model = SCAR(adata, ambient_profile=None, n_latent=n_latent)
Expand Down

0 comments on commit 03acd60

Please sign in to comment.