Skip to content

Commit

Permalink
feat: update to anndata 0.11 and memory efficient reads + writes (#1152)
Browse files Browse the repository at this point in the history
  • Loading branch information
nayib-jose-gloria authored Jan 10, 2025
1 parent e8c97c0 commit 6f7f496
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 186 deletions.
5 changes: 3 additions & 2 deletions cellxgene_schema_cli/cellxgene_schema/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def schema_cli(verbose):
type=click.Path(exists=False, dir_okay=False, writable=True),
)
@click.option("-i", "--ignore-labels", help="Ignore ontology labels when validating", is_flag=True)
def schema_validate(h5ad_file, add_labels_file, ignore_labels):
@click.option("-n", "--num-workers", help="Number of workers to use for parallel processing", default=1, type=int)
def schema_validate(h5ad_file, add_labels_file, ignore_labels, num_workers):
# Imports are very slow so we defer loading until Click arg validation has passed
logger.info("Loading dependencies")
try:
Expand All @@ -47,7 +48,7 @@ def schema_validate(h5ad_file, add_labels_file, ignore_labels):
logger.info("Loading validator modules")
from .validate import validate

is_valid, _, _ = validate(h5ad_file, add_labels_file, ignore_labels=ignore_labels)
is_valid, _, _ = validate(h5ad_file, add_labels_file, ignore_labels=ignore_labels, n_workers=num_workers)
if is_valid:
sys.exit(0)
else:
Expand Down
57 changes: 44 additions & 13 deletions cellxgene_schema_cli/cellxgene_schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from typing import Dict, List, Union

import anndata as ad
import h5py
import numpy as np
from anndata.compat import DaskArray
from anndata.experimental import read_dispatched, read_elem_as_dask
from cellxgene_ontology_guide.ontology_parser import OntologyParser
from scipy import sparse
from xxhash import xxh3_64_intdigest
Expand Down Expand Up @@ -68,7 +71,7 @@ def remap_deprecated_features(*, adata: ad.AnnData, remapped_features: Dict[str,
return adata


def get_matrix_format(adata: ad.AnnData, matrix: Union[np.ndarray, sparse.spmatrix]) -> str:
def get_matrix_format(matrix: DaskArray) -> str:
"""
Given a matrix, returns the format as one of: csc, csr, coo, dense
or unknown.
Expand All @@ -84,15 +87,11 @@ def get_matrix_format(adata: ad.AnnData, matrix: Union[np.ndarray, sparse.spmatr
# >>> return getattr(matrix, "format_str", "dense")
#
matrix_format = "unknown"
if adata.n_obs == 0 or adata.n_vars == 0:
matrix_slice = matrix[0:1, 0:1].compute()
if isinstance(matrix_slice, sparse.spmatrix):
matrix_format = matrix_slice.format
elif isinstance(matrix_slice, np.ndarray):
matrix_format = "dense"
else:
matrix_slice = matrix[0:1, 0:1]
if isinstance(matrix_slice, sparse.spmatrix):
matrix_format = matrix_slice.format
elif isinstance(matrix_slice, np.ndarray):
matrix_format = "dense"

assert matrix_format in ["unknown", "csr", "csc", "coo", "dense"]
return matrix_format

Expand All @@ -116,21 +115,53 @@ def getattr_anndata(adata: ad.AnnData, attr: str = None):
return getattr(adata, attr)


def read_h5ad(h5ad_path: Union[str, bytes, os.PathLike]) -> ad.AnnData:
def read_backed(f: h5py.File, chunk_size: int) -> ad.AnnData:
    """
    Build an AnnData object from an h5py.File object via read_dispatched, returning matrix elements as dask arrays
    instead of handing them to the default reader.

    An element is treated as a matrix when its name is "/X" or "/raw/X", or contains "/layers". Such an element
    is read with read_elem_as_dask when it is encoded as a CSR/CSC sparse matrix or as a 2-dimensional array.
    Every other element is read with the reader function passed to the callback.
    :param f: h5py.File object
    :param chunk_size: chunk length along the first axis; every chunk spans the whole second axis
    :return: ad.AnnData object
    """
    sparse_encodings = ("csr_matrix", "csc_matrix")

    def callback(func, elem_name: str, elem, iospec):
        is_matrix_elem = elem_name == "/X" or elem_name == "/raw/X" or "/layers" in elem_name
        if not is_matrix_elem:
            return func(elem)

        encoding = iospec.encoding_type
        if encoding in sparse_encodings:
            # Sparse groups carry their dimensions in the "shape" attribute.
            n_cols = elem.attrs.get("shape")[1]
        elif encoding == "array" and len(elem.shape) == 2:
            n_cols = elem.shape[1]
        else:
            # Any other encoding (or a non-2D array) falls back to the supplied reader.
            return func(elem)

        return read_elem_as_dask(elem, chunks=(chunk_size, n_cols))

    return read_dispatched(f, callback=callback)


def read_h5ad(h5ad_path: Union[str, bytes, os.PathLike], chunk_size: int = 5000) -> ad.AnnData:
"""
Reads h5ad into adata
:params Union[str, bytes, os.PathLike] h5ad_path: path to h5ad to read
:rtype None
"""
try:
adata = ad.read_h5ad(h5ad_path, backed="r")
f = h5py.File(h5ad_path)
adata = read_backed(f, chunk_size)

# This code, and AnnData in general, is optimized for row access.
# Running backed, with CSC, is prohibitively slow. Read the entire
# AnnData into memory if it is CSC.
if (get_matrix_format(adata, adata.X) == "csc") or (
(adata.raw is not None) and (get_matrix_format(adata, adata.raw.X) == "csc")
if (get_matrix_format(adata.X) == "csc") or (
(adata.raw is not None) and (get_matrix_format(adata.raw.X) == "csc")
):
logger.warning("Matrices are in CSC format; loading entire dataset into memory.")
adata = adata.to_memory()
Expand Down
Loading

0 comments on commit 6f7f496

Please sign in to comment.