From 68f8863c6f2d7131debeee4dea53358a5e451280 Mon Sep 17 00:00:00 2001 From: Martin Kim <46072231+martinkim0@users.noreply.github.com> Date: Mon, 9 Oct 2023 15:29:54 -0700 Subject: [PATCH] Backport PR #2271: Add support for AnnData 0.10.0 (#2278) --- docs/release_notes/index.md | 6 ++++++ scvi/data/_anntorchdataset.py | 10 +++++++++- scvi/data/_built_in_data/_synthetic.py | 6 +++--- scvi/data/_utils.py | 9 ++++++++- scvi/model/base/_archesmixin.py | 2 +- 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/docs/release_notes/index.md b/docs/release_notes/index.md index 168074a6bf..c3598586ac 100644 --- a/docs/release_notes/index.md +++ b/docs/release_notes/index.md @@ -17,6 +17,12 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits ## Version 1.0 +### 1.0.4 (2023-10-xx) + +### Added + +- Add support for AnnData 0.10.0 {pr}`2271`. + ### 1.0.3 (2023-08-13) ### Changed diff --git a/scvi/data/_anntorchdataset.py b/scvi/data/_anntorchdataset.py index 7d8843afce..87aa9dd5a0 100644 --- a/scvi/data/_anntorchdataset.py +++ b/scvi/data/_anntorchdataset.py @@ -4,7 +4,15 @@ import h5py import numpy as np import pandas as pd -from anndata._core.sparse_dataset import SparseDataset + +try: + from anndata._core.sparse_dataset import SparseDataset +except ImportError: + # anndata >= 0.10.0 + from anndata._core.sparse_dataset import ( + BaseCompressedSparseDataset as SparseDataset, + ) + from scipy.sparse import issparse from torch.utils.data import Dataset diff --git a/scvi/data/_built_in_data/_synthetic.py b/scvi/data/_built_in_data/_synthetic.py index 81a1619694..a39f749101 100644 --- a/scvi/data/_built_in_data/_synthetic.py +++ b/scvi/data/_built_in_data/_synthetic.py @@ -54,16 +54,16 @@ def _generate_synthetic( labels = np.random.randint(0, n_labels, size=(n_obs,)) labels = np.array([f"label_{i}" for i in labels]) - adata = AnnData(rna, dtype=np.float32) + adata = AnnData(rna) if return_mudata: mod_dict = {rna_key: adata} if n_proteins > 0: - protein_adata = AnnData(protein, dtype=np.float32) + protein_adata = AnnData(protein) protein_adata.var_names = protein_names mod_dict[protein_expression_key] = protein_adata if n_regions > 0: - mod_dict[accessibility_key] = AnnData(accessibility, dtype=np.float32) + mod_dict[accessibility_key] = AnnData(accessibility) adata = MuData(mod_dict) else: diff --git a/scvi/data/_utils.py b/scvi/data/_utils.py index beb4a27299..35edd34d78 100644 --- a/scvi/data/_utils.py +++ b/scvi/data/_utils.py @@ -10,7 +10,14 @@ import pandas as pd import scipy.sparse as sp_sparse from anndata import AnnData -from anndata._core.sparse_dataset import SparseDataset + +try: + from anndata._core.sparse_dataset import SparseDataset +except ImportError: + # anndata >= 0.10.0 + from anndata._core.sparse_dataset import ( + BaseCompressedSparseDataset as SparseDataset, + ) # TODO use the experimental api once we lower bound to anndata 0.8 try: diff --git a/scvi/model/base/_archesmixin.py b/scvi/model/base/_archesmixin.py index 3163ea699f..25b834f79e 100644 --- a/scvi/model/base/_archesmixin.py +++ b/scvi/model/base/_archesmixin.py @@ -245,7 +245,7 @@ def prepare_query_anndata( if inplace: if adata_out is not adata: - adata._init_as_actual(adata_out, dtype=adata._X.dtype) + adata._init_as_actual(adata_out) else: return adata_out