Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup Benchmarker.prepare (compute_connectivities_umap) #128

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/scib_metrics/benchmark/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import scib_metrics
from scib_metrics.nearest_neighbors import NeighborsOutput, pynndescent
from scib_metrics.utils._utils import compute_connectivities_umap

Kwargs = dict[str, Any]
MetricType = Union[bool, Kwargs]
Expand Down Expand Up @@ -190,7 +191,7 @@ def prepare(self, neighbor_computer: Optional[Callable[[np.ndarray, int], Neighb
)
indices, distances = neigh_output.indices, neigh_output.distances
for n in self._neighbor_values:
sp_distances, sp_conns = sc.neighbors._compute_connectivities_umap(
sp_distances, sp_conns = compute_connectivities_umap(
indices[:, :n], distances[:, :n], ad.n_obs, n_neighbors=n
)
ad.obsp[f"{n}_connectivities"] = sp_conns
Expand Down
51 changes: 50 additions & 1 deletion src/scib_metrics/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from chex import ArrayDevice
from jax import nn
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix, csr_matrix
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_array

Expand Down Expand Up @@ -63,3 +63,52 @@ def convert_knn_graph_to_idx(X: csr_matrix) -> tuple[np.ndarray, np.ndarray]:
nn_obj = NearestNeighbors(n_neighbors=n_neighbors, metric="precomputed").fit(X)
kneighbors = nn_obj.kneighbors(X)
return kneighbors


def compute_connectivities_umap(
knn_indices,
knn_dists,
n_obs,
n_neighbors,
set_op_mix_ratio=1.0,
local_connectivity=1.0,
):
Comment on lines +69 to +75
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add typing

"""Sped up version of sc.neighbors._compute_connectivities_umap."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put a more general docstring? An overview of the method, and a note that it matches how connectivities are computed in scanpy?

with warnings.catch_warnings():
# umap 0.5.0
warnings.filterwarnings("ignore", message=r"Tensorflow not installed")
from umap.umap_ import fuzzy_simplicial_set

X = coo_matrix(([], ([], [])), shape=(n_obs, 1))
connectivities = fuzzy_simplicial_set(
X,
n_neighbors,
None,
None,
Comment on lines +84 to +87
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I prefer using keywords everywhere

knn_indices=knn_indices,
knn_dists=knn_dists,
set_op_mix_ratio=set_op_mix_ratio,
local_connectivity=local_connectivity,
)

if isinstance(connectivities, tuple):
# In umap-learn 0.4, this returns (result, sigmas, rhos)
connectivities = connectivities[0]
Comment on lines +94 to +96
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this bit still necessary? What's the lower bound for scanpy? If we merge this PR we will need to make umap a direct dependency, so please also add that and potentially remove this block.


n_samples = knn_indices.shape[0]
distances = knn_dists.ravel()
indices = knn_indices.ravel()

# Check for self-connections
self_connections = not np.all(knn_indices != np.arange(n_samples)[:, None])

# Efficient creation of row pointer
rowptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors)

# Create CSR matrix
dist_sparse_csr = csr_matrix((distances, indices, rowptr), shape=(n_samples, n_samples))

# Set diagonal to zero if self-connections exist
if self_connections:
dist_sparse_csr.setdiag(0.0)
return dist_sparse_csr, connectivities.tocsr()
41 changes: 41 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time

import numpy as np
import pytest
import scanpy as sc
from sklearn.neighbors import NearestNeighbors

from scib_metrics.utils._utils import compute_connectivities_umap
from tests.utils.data import dummy_benchmarker_adata


@pytest.mark.parametrize("n", [5, 10, 20, 21])
def test_compute_connectivities_umap(n):
    """Check that the local implementation reproduces scanpy's result for ``n`` neighbors.

    A 25-neighbor graph is computed once on the first embedding of the dummy
    AnnData, truncated to its first ``n`` columns, and passed to both
    ``compute_connectivities_umap`` and
    ``sc.neighbors._compute_connectivities_umap``. The returned distance and
    connectivity matrices must be element-wise identical.
    """
    adata, embedding_keys, *_ = dummy_benchmarker_adata()
    neigh = NearestNeighbors(n_neighbors=25).fit(adata.obsm[embedding_keys[0]])
    dist, ind = neigh.kneighbors()

    # Both implementations receive exactly the same truncated kNN graph.
    knn_ind, knn_dist = ind[:, :n], dist[:, :n]
    new_dist, new_connect = compute_connectivities_umap(knn_ind, knn_dist, adata.n_obs, n_neighbors=n)
    sc_dist, sc_connect = sc.neighbors._compute_connectivities_umap(knn_ind, knn_dist, adata.n_obs, n_neighbors=n)

    # Compare densified arrays: np.testing reports the mismatching entries on
    # failure, and this avoids the inefficient sparse ``==`` comparison.
    np.testing.assert_array_equal(new_dist.toarray(), sc_dist.toarray())
    np.testing.assert_array_equal(new_connect.toarray(), sc_connect.toarray())
Comment on lines +19 to +20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use np.testing



def test_timing_compute_connectivities_umap():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test is necessary if you add a reproducible script to this PR description

n_obs = 10_000
X = np.random.normal(size=(n_obs, 10))
neigh = NearestNeighbors(n_neighbors=90).fit(X)
dist, ind = neigh.kneighbors()

new_start = time.perf_counter()
compute_connectivities_umap(ind, dist, n_obs, n_neighbors=90)
new_end = time.perf_counter()

sc_start = time.perf_counter()
sc.neighbors._compute_connectivities_umap(ind, dist, n_obs, n_neighbors=90)
sc_end = time.perf_counter()

assert new_end - new_start < sc_end - sc_start


if __name__ == "__main__":
pytest.main([__file__])