Skip to content

Commit

Permalink
Backport PR #3132 on branch 1.2.x (feat: Improve speed for transfer f…
Browse files Browse the repository at this point in the history
…ields if using many categories.) (#3140)

Backport PR #3132: feat: Improve speed for transfer fields if using many
categories.

Co-authored-by: Can Ergen <[email protected]>
  • Loading branch information
meeseeksmachine and canergen authored Jan 12, 2025
1 parent 935b56b commit 2e7933c
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/scvi/data/fields/_dataframe_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Literal

import numpy as np
import pandas as pd
import rich
from anndata import AnnData
from pandas.api.types import CategoricalDtype
Expand Down Expand Up @@ -211,15 +212,17 @@ def transfer_field(
mapping = state_registry[self.CATEGORICAL_MAPPING_KEY].copy()

# extend mapping for new categories
for c in np.unique(self._get_original_column(adata_target)):
if c not in mapping:
if extend_categories:
mapping = np.concatenate([mapping, [c]])
else:
raise ValueError(
f"Category {c} not found in source registry. "
f"Cannot transfer setup without `extend_categories = True`."
)
missing_categories = (
pd.Index(np.unique(self._get_original_column(adata_target)))
.difference(pd.Index(mapping))
.to_numpy()
)
if missing_categories.any() and not extend_categories:
raise ValueError(
f"Category {missing_categories[0]} not found in source registry. "
f"Cannot transfer setup without `extend_categories = True`."
)
mapping = np.concatenate([mapping, missing_categories])
cat_dtype = CategoricalDtype(categories=mapping, ordered=True)
new_mapping = _make_column_categorical(
getattr(adata_target, self.attr_name),
Expand Down

0 comments on commit 2e7933c

Please sign in to comment.