From d56a576189c510f52534720c13480c9a0b914d68 Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Sun, 12 Jan 2025 03:15:54 -0800 Subject: [PATCH] Backport PR #3132: feat: Improve speed for transfer fields if using many categories. --- src/scvi/data/fields/_dataframe_field.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/scvi/data/fields/_dataframe_field.py b/src/scvi/data/fields/_dataframe_field.py index 633f3e8673..ef6fdfb618 100644 --- a/src/scvi/data/fields/_dataframe_field.py +++ b/src/scvi/data/fields/_dataframe_field.py @@ -2,6 +2,7 @@ from typing import Literal import numpy as np +import pandas as pd import rich from anndata import AnnData from pandas.api.types import CategoricalDtype @@ -211,15 +212,17 @@ def transfer_field( mapping = state_registry[self.CATEGORICAL_MAPPING_KEY].copy() # extend mapping for new categories - for c in np.unique(self._get_original_column(adata_target)): - if c not in mapping: - if extend_categories: - mapping = np.concatenate([mapping, [c]]) - else: - raise ValueError( - f"Category {c} not found in source registry. " - f"Cannot transfer setup without `extend_categories = True`." - ) + missing_categories = ( + pd.Index(np.unique(self._get_original_column(adata_target))) + .difference(pd.Index(mapping)) + .to_numpy() + ) + if missing_categories.any() and not extend_categories: + raise ValueError( + f"Category {missing_categories[0]} not found in source registry. " + f"Cannot transfer setup without `extend_categories = True`." + ) + mapping = np.concatenate([mapping, missing_categories]) cat_dtype = CategoricalDtype(categories=mapping, ordered=True) new_mapping = _make_column_categorical( getattr(adata_target, self.attr_name),