Skip to content

Commit

Permalink
add progress bar arg toBenchmarker (#152)
Browse files Browse the repository at this point in the history
* add progress bar arg to benchmarker

* add release note
  • Loading branch information
martinkim0 authored Mar 1, 2024
1 parent 4538478 commit 45bfcfd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][].

## 0.6.0 (unreleased)

### Added

- Add `progress_bar` argument to {class}`scib_metrics.benchmark.Benchmarker` {pr}`152`.

## 0.5.1 (2024-02-23)

### Changed
Expand Down
25 changes: 20 additions & 5 deletions src/scib_metrics/benchmark/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class Benchmarker:
in the prepare step. See the notes below for more information.
n_jobs
Number of jobs to use for parallelization of neighbor search.
progress_bar
Whether to show a progress bar for :meth:`~scib_metrics.benchmark.Benchmarker.prepare` and
:meth:`~scib_metrics.benchmark.Benchmarker.benchmark`.
Notes
-----
Expand All @@ -136,6 +139,7 @@ def __init__(
batch_correction_metrics: Optional[BatchCorrection] = None,
pre_integrated_embedding_obsm_key: Optional[str] = None,
n_jobs: int = 1,
progress_bar: bool = True,
):
self._adata = adata
self._embedding_obsm_keys = embedding_obsm_keys
Expand All @@ -150,6 +154,7 @@ def __init__(
self._batch_key = batch_key
self._label_key = label_key
self._n_jobs = n_jobs
self._progress_bar = progress_bar

self._metric_collection_dict = {
"Bio conservation": self._bio_conservation_metrics,
Expand Down Expand Up @@ -181,7 +186,11 @@ def prepare(self, neighbor_computer: Optional[Callable[[np.ndarray, int], Neighb
self._emb_adatas[emb_key].obsm[_X_PRE] = self._adata.obsm[self._pre_integrated_embedding_obsm_key]

# Compute neighbors
for ad in tqdm(self._emb_adatas.values(), desc="Computing neighbors"):
progress = self._emb_adatas.values()
if self._progress_bar:
progress = tqdm(progress, desc="Computing neighbors")

for ad in progress:
if neighbor_computer is not None:
neigh_result = neighbor_computer(ad.X, max(self._neighbor_values))
else:
Expand All @@ -208,12 +217,18 @@ def benchmark(self) -> None:
[sum([v is not False for v in asdict(met_col)]) for met_col in self._metric_collection_dict.values()]
)

for emb_key, ad in tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green"):
pbar = tqdm(total=num_metrics, desc="Metrics", position=1, leave=False, colour="blue")
progress_embs = self._emb_adatas.items()
if self._progress_bar:
progress_embs = tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green")

for emb_key, ad in progress_embs:
pbar = None
if self._progress_bar:
pbar = tqdm(total=num_metrics, desc="Metrics", position=1, leave=False, colour="blue")
for metric_type, metric_collection in self._metric_collection_dict.items():
for metric_name, use_metric_or_kwargs in asdict(metric_collection).items():
if use_metric_or_kwargs:
pbar.set_postfix_str(f"{metric_type}: {metric_name}")
pbar.set_postfix_str(f"{metric_type}: {metric_name}") if pbar is not None else None
metric_fn = getattr(scib_metrics, metric_name)
if isinstance(use_metric_or_kwargs, dict):
# Kwargs in this case
Expand All @@ -227,7 +242,7 @@ def benchmark(self) -> None:
else:
self._results.loc[metric_name, emb_key] = metric_value
self._results.loc[metric_name, _METRIC_TYPE] = metric_type
pbar.update(1)
pbar.update(1) if pbar is not None else None

self._benchmarked = True

Expand Down

0 comments on commit 45bfcfd

Please sign in to comment.