From a4b96b2d401da805f6cd39773bb058d3b5790255 Mon Sep 17 00:00:00 2001 From: Xinyue Zhang Date: Wed, 21 Feb 2024 12:55:11 +0100 Subject: [PATCH] greatly improve the efficiency of the bottleneck - adding awkward array to adata.obsm using dataframe --- ehrdata/io/_omop.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/ehrdata/io/_omop.py b/ehrdata/io/_omop.py index e965519..60e3aa1 100644 --- a/ehrdata/io/_omop.py +++ b/ehrdata/io/_omop.py @@ -4,6 +4,7 @@ import awkward as ak import ehrapy as ep import pandas as pd +import pyarrow as pa from rich import print as rprint from ehrdata.utils._omop_utils import ( @@ -407,29 +408,29 @@ def extract_note( def from_dataframe(adata, feature: str, df): - grouped = df.groupby("visit_occurrence_id") - unique_visit_occurrence_ids = set(adata.obs.index) + # Add new rows for those visit_occurrence_id that don't have any data + new_row_dict = {col: [] for col in df.columns} + for key in new_row_dict.keys(): + if key == "visit_occurrence_id": + new_row_dict[key] = list(set(adata.obs.index) - set(df.visit_occurrence_id)) + else: + new_row_dict[key] = [None] * len(new_row_dict["visit_occurrence_id"]) + new_rows = pd.DataFrame(new_row_dict) + df = pd.concat([df, new_rows], ignore_index=True) - # Use set difference and intersection more efficiently - feature_ids = unique_visit_occurrence_ids.intersection(grouped.groups.keys()) - empty_entry = { - source_table_column: [] - for source_table_column in set(df.columns) - if source_table_column not in ["visit_occurrence_id"] - } + ak_array = ak.from_arrow(pa.Table.from_pandas(df), highlevel=True) + ak_array = ak.unflatten(ak_array, df["visit_occurrence_id"].value_counts(sort=False).values) + + # Need to sort the visit_occurrence_id in awkward array according to the sequence in the indices in the adata + id_in_df = list(df["visit_occurrence_id"].unique()) + id_in_adata = list(adata.obs.index) + index_dict = {value: index for index, 
value in enumerate(id_in_df)} + index = [index_dict[x] for x in id_in_adata] + + # Sort the ak_array to align with the adata + ak_array = ak_array[index] columns_in_ak_array = list(set(df.columns) - {"visit_occurrence_id"}) - # Creating the array more efficiently - ak_array = ak.Array( - [ - ( - grouped.get_group(visit_occurrence_id)[columns_in_ak_array].to_dict(orient="list") - if visit_occurrence_id in feature_ids - else empty_entry - ) - for visit_occurrence_id in unique_visit_occurrence_ids - ] - ) - adata.obsm[feature] = ak_array + adata.obsm[feature] = ak_array[columns_in_ak_array] return adata