Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add seed to leiden clustering #173

Merged
merged 10 commits into from
Aug 22, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ and this project adheres to [Semantic Versioning][].

- Add `progress_bar` argument to {class}`scib_metrics.benchmark.Benchmarker` {pr}`152`.

### Changed

- Leiden clustering now has a seed argument for reproducibility {pr}`173`.

### Fixed

- Fix neighbors connectivities in test to use new scanpy fn {pr}`170`.
Expand Down
19 changes: 15 additions & 4 deletions src/scib_metrics/metrics/_nmi_ari.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import random
import warnings

import igraph
Expand All @@ -19,7 +20,9 @@ def _compute_clustering_kmeans(X: np.ndarray, n_clusters: int) -> np.ndarray:
return kmeans.labels_


def _compute_clustering_leiden(connectivity_graph: spmatrix, resolution: float) -> np.ndarray:
def _compute_clustering_leiden(connectivity_graph: spmatrix, resolution: float, seed: int) -> np.ndarray:
rng = random.Random(seed)
igraph.set_random_number_generator(rng)
# The connectivity graph with the umap method is symmetric, but we need to first make it directed
# to have both sets of edges as is done in scanpy. See test for more details.
g = igraph.Graph.Weighted_Adjacency(connectivity_graph, mode="directed")
Expand All @@ -33,8 +36,9 @@ def _compute_nmi_ari_cluster_labels(
X: spmatrix,
labels: np.ndarray,
resolution: float = 1.0,
seed: int = 42,
) -> tuple[float, float]:
labels_pred = _compute_clustering_leiden(X, resolution)
labels_pred = _compute_clustering_leiden(X, resolution, seed)
nmi = normalized_mutual_info_score(labels, labels_pred, average_method="arithmetic")
ari = adjusted_rand_score(labels, labels_pred)
return nmi, ari
Expand Down Expand Up @@ -71,7 +75,12 @@ def nmi_ari_cluster_labels_kmeans(X: np.ndarray, labels: np.ndarray) -> dict[str


def nmi_ari_cluster_labels_leiden(
X: NeighborsResults, labels: np.ndarray, optimize_resolution: bool = True, resolution: float = 1.0, n_jobs: int = 1
X: NeighborsResults,
labels: np.ndarray,
optimize_resolution: bool = True,
resolution: float = 1.0,
n_jobs: int = 1,
seed: int = 42,
) -> dict[str, float]:
"""Compute nmi and ari between leiden clusters and labels.

Expand All @@ -93,6 +102,8 @@ def nmi_ari_cluster_labels_leiden(
n_jobs
Number of jobs for parallelizing resolution optimization via joblib. If -1, all CPUs
are used.
seed
Seed used for reproducibility of clustering.

Returns
-------
Expand All @@ -113,7 +124,7 @@ def nmi_ari_cluster_labels_leiden(
)
except ImportError:
warnings.warn("Using for loop over clustering resolutions. `pip install joblib` for parallelization.")
out = [_compute_nmi_ari_cluster_labels(conn_graph, labels, r) for r in resolutions]
out = [_compute_nmi_ari_cluster_labels(conn_graph, labels, r, seed=seed) for r in resolutions]
nmi_ari = np.array(out)
nmi_ind = np.argmax(nmi_ari[:, 0])
nmi, ari = nmi_ari[nmi_ind, :]
Expand Down
10 changes: 10 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ def test_nmi_ari_cluster_labels_leiden_single_resolution():
assert isinstance(ari, float)


def test_nmi_ari_cluster_labels_leiden_reproducibility():
X, labels = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True)
out1 = scib_metrics.nmi_ari_cluster_labels_leiden(X, labels, optimize_resolution=False, resolution=3.0)
out2 = scib_metrics.nmi_ari_cluster_labels_leiden(X, labels, optimize_resolution=False, resolution=3.0)
nmi1, ari1 = out1["nmi"], out1["ari"]
nmi2, ari2 = out2["nmi"], out2["ari"]
assert nmi1 == nmi2
assert ari1 == ari2


def test_leiden_graph_construction():
X, _ = dummy_x_labels(symmetric_positive=True, x_is_neighbors_results=True)
conn_graph = X.knn_graph_connectivities
Expand Down