Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyuejohn committed Mar 5, 2024
1 parent 18e23d9 commit 852a3af
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 25 deletions.
50 changes: 36 additions & 14 deletions ehrdata/pl/_omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import ehrapy as ep
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import seaborn as sns
from anndata import AnnData
Expand All @@ -16,7 +17,7 @@


def feature_counts(
adata,
adata: AnnData,
source: Literal[
"observation",
"measurement",
Expand All @@ -26,18 +27,21 @@ def feature_counts(
"drug_exposure",
"condition_occurrence",
],
number=20,
key=None,
use_dask=None,
):
# if source == 'measurement':
# columns = ["value_as_number", "time", "visit_occurrence_id", "measurement_concept_id"]
# elif source == 'observation':
# columns = ["value_as_number", "value_as_string", "measurement_datetime"]
# elif source == 'condition_occurrence':
# columns = None
# else:
# raise KeyError(f"Extracting data from {source} is not supported yet")
number: int = 20,
use_dask: bool = None,
) -> pd.DataFrame:
"""Plot feature counts for a given source table and return a dataframe with feature names and counts.
Args:
adata (AnnData): Anndata object
source (Literal[ "observation", "measurement", "procedure_occurrence", "specimen", "device_exposure", "drug_exposure", "condition_occurrence", ]): source table name.
number (int, optional): Number of top features to plot. Defaults to 20.
use_dask (bool, optional): If True, dask will be used to read the tables. For large tables, it is highly recommended to use dask. If None, it will be set to adata.uns["use_dask"]. Defaults to None.
Returns
-------
pd.DataFrame: Dataframe with feature names and counts
"""
path = adata.uns["filepath_dict"][source]
if isinstance(path, list):
if not use_dask or use_dask is None:
Expand Down Expand Up @@ -80,7 +84,21 @@ def plot_timeseries(
value_key: str = "value_as_number",
time_key: str = "measurement_datetime",
x_label: str = None,
show: Optional[bool] = None,
):
"""Plot timeseries data using data from adata.obsm.
Args:
adata (AnnData): Anndata object
visit_occurrence_id (int): visit_occurrence_id to plot
key (Union[str, list[str]]): feature key or list of keys in adata.obsm to plot
slot (Union[str, None], optional): Slot to use. Defaults to "obsm".
value_key (str, optional): key in awkward array in adata.obsm to be used as value. Defaults to "value_as_number".
time_key (str, optional): key in awkward array in adata.obsm to be used as time. Defaults to "measurement_datetime".
x_label (str, optional): x-axis label. If None, "Hours since ICU admission" is used. Defaults to None.
show (Optional[bool], optional): If True, show the plot and do not return the axis; otherwise the axis is returned. Defaults to None.
"""
if isinstance(key, str):
key_list = [key]
else:
Expand Down Expand Up @@ -122,7 +140,10 @@ def plot_timeseries(
plt.xlabel(x_label if x_label else "Hours since ICU admission")

plt.tight_layout()
plt.show()
if not show:
return ax
else:
plt.show()


def violin(
Expand Down Expand Up @@ -153,6 +174,7 @@ def violin(
Args:
adata: :class:`~anndata.AnnData` object containing all observations.
obsm_key: feature key or list of keys in adata.obsm to plot
keys: Keys for accessing variables of `.var_names` or fields of `.obs`.
groupby: The key of the observation grouping to consider.
log: Plot on logarithmic axis.
Expand Down
2 changes: 1 addition & 1 deletion ehrdata/pp/_omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_feature_statistics(
Args:
adata (AnnData): Anndata object
source (Literal[ "observation", "measurement", "procedure_occurrence", "specimen", "device_exposure", "drug_exposure", "condition_occurrence", ]): source table name. Defaults to None.
features (Union[str, int, list[Union[str, int]]], optional): feature_id or feature_name, or list of feature_id or feature_name. Defaults to None.
features (Union[str, int, list[Union[str, int]]], optional): concept_id or concept_name, or list of concept_id or concept_name. Defaults to None.
level (Literal["stay_level", "patient_level"], optional): For stay level, statistics are calculated for each stay. For patient level, statistics are calculated for each patient. It should be aligned with the setting of the adata object. Defaults to "stay_level".
value_col (str, optional): column name in source table to extract value from. Defaults to None.
aggregation_methods (Union[ Literal["min", "max", "mean", "std", "count"], list[Literal["min", "max", "mean", "std", "count"]] ], optional): aggregation methods to calculate statistics. Defaults to ["min", "max", "mean", "std", "count"].
Expand Down
45 changes: 35 additions & 10 deletions ehrdata/tl/_omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,22 @@
from ehrdata.utils._omop_utils import df_to_dict, get_column_types, read_table


def get_concept_name(adata: Union[AnnData, dict], concept_id: Union[str, list], raise_error=False, verbose=True):
def get_concept_name(
adata: Union[AnnData, dict],
concept_id: Union[str, list],
raise_error: bool = False,
) -> Union[str, list[str]]:
"""Get concept name from concept_id using concept table
Args:
adata (Union[AnnData, dict]): Anndata object or adata.uns
concept_id (Union[str, list]): concept_id or list of concept_id
raise_error (bool, optional): If True, raise error if concept_id not found. Defaults to False.
Returns
-------
Union[str, list[str]]: concept_name or list of concept_name
"""
if isinstance(concept_id, numbers.Integral):
concept_id = [concept_id]

Expand All @@ -38,8 +53,7 @@ def get_concept_name(adata: Union[AnnData, dict], concept_id: Union[str, list],
concept_name_not_found.append(id)
if len(concept_name_not_found) > 0:
# warnings.warn(f"Couldn't find concept {id} in concept table!")
if verbose:
rprint(f"Couldn't find concept {concept_name_not_found} in concept table!")
rprint(f"Couldn't find concept {concept_name_not_found} in concept table!")
if raise_error:
raise KeyError
if len(concept_name) == 1:
Expand All @@ -50,15 +64,31 @@ def get_concept_name(adata: Union[AnnData, dict], concept_id: Union[str, list],

# downsampling
def aggregate_timeseries_in_bins(
adata,
adata: AnnData,
features: Union[str, list[str]],
slot: Union[str, None] = "obsm",
value_key: str = "value_as_number",
time_key: str = "measurement_datetime",
time_binning_method: Literal["floor", "ceil", "round"] = "floor",
bin_size: Union[str, Offset] = "h",
aggregation_method: Literal["median", "mean", "min", "max"] = "median",
):
) -> AnnData:
"""Aggregate timeseries data in bins
Args:
adata (AnnData): Anndata object
features (Union[str, list[str]]): concept_id or concept_name, or list of concept_id or concept_name.
slot (Union[str, None], optional): Slot to read the data. Defaults to "obsm".
value_key (str, optional): key in awkward array in adata.obsm to be used as value. Defaults to "value_as_number".
time_key (str, optional): key in awkward array in adata.obsm to be used as time. Defaults to "measurement_datetime".
time_binning_method (Literal["floor", "ceil", "round"], optional): Time binning method. Defaults to "floor".
bin_size (Union[str, Offset], optional): Time bin size. Defaults to "h".
aggregation_method (Literal["median", "mean", "min", "max"], optional): Aggregation method. Defaults to "median".
Returns
-------
AnnData: Anndata object
"""
if isinstance(features, str):
features_list = [features]
else:
Expand Down Expand Up @@ -103,11 +133,6 @@ def aggregate_timeseries_in_bins(
return adata


# TODO
def get_concept_id():
pass


# TODO
def note_nlp_map(
self,
Expand Down

0 comments on commit 852a3af

Please sign in to comment.