From c9c13795ede2ffffad78b7373cbd1197f0277b32 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Thu, 9 Jan 2025 12:45:49 -0800 Subject: [PATCH 1/2] Improve speed in checkout fields. --- src/scvi/data/fields/_dataframe_field.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/scvi/data/fields/_dataframe_field.py b/src/scvi/data/fields/_dataframe_field.py index 633f3e8673..1c5c93acab 100644 --- a/src/scvi/data/fields/_dataframe_field.py +++ b/src/scvi/data/fields/_dataframe_field.py @@ -2,6 +2,7 @@ from typing import Literal import numpy as np +import pandas as pd import rich from anndata import AnnData from pandas.api.types import CategoricalDtype @@ -211,15 +212,15 @@ def transfer_field( mapping = state_registry[self.CATEGORICAL_MAPPING_KEY].copy() # extend mapping for new categories - for c in np.unique(self._get_original_column(adata_target)): - if c not in mapping: - if extend_categories: - mapping = np.concatenate([mapping, [c]]) - else: - raise ValueError( - f"Category {c} not found in source registry. " - f"Cannot transfer setup without `extend_categories = True`." - ) + missing_categories = pd.Index( + np.unique(self._get_original_column(adata_target)) + ).difference(pd.Index(mapping)).to_numpy() + if missing_categories.any() and not extend_categories: + raise ValueError( + f"Category {missing_categories[0]} not found in source registry. " + f"Cannot transfer setup without `extend_categories = True`." + ) + mapping = np.concatenate([mapping, missing_categories]) cat_dtype = CategoricalDtype(categories=mapping, ordered=True) new_mapping = _make_column_categorical( getattr(adata_target, self.attr_name), From 25045697de3fb9544d278b426d876c08734f4e05 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Jan 2025 20:50:38 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/data/fields/_dataframe_field.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/scvi/data/fields/_dataframe_field.py b/src/scvi/data/fields/_dataframe_field.py index 1c5c93acab..ef6fdfb618 100644 --- a/src/scvi/data/fields/_dataframe_field.py +++ b/src/scvi/data/fields/_dataframe_field.py @@ -212,9 +212,11 @@ def transfer_field( mapping = state_registry[self.CATEGORICAL_MAPPING_KEY].copy() # extend mapping for new categories - missing_categories = pd.Index( - np.unique(self._get_original_column(adata_target)) - ).difference(pd.Index(mapping)).to_numpy() + missing_categories = ( + pd.Index(np.unique(self._get_original_column(adata_target))) + .difference(pd.Index(mapping)) + .to_numpy() + ) if missing_categories.any() and not extend_categories: raise ValueError( f"Category {missing_categories[0]} not found in source registry. "