-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speed up Benchmarker.prepare (compute_connectivities_umap)
#128
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
import numpy as np | ||
from chex import ArrayDevice | ||
from jax import nn | ||
from scipy.sparse import csr_matrix | ||
from scipy.sparse import coo_matrix, csr_matrix | ||
from sklearn.neighbors import NearestNeighbors | ||
from sklearn.utils import check_array | ||
|
||
|
@@ -63,3 +63,52 @@ def convert_knn_graph_to_idx(X: csr_matrix) -> tuple[np.ndarray, np.ndarray]: | |
nn_obj = NearestNeighbors(n_neighbors=n_neighbors, metric="precomputed").fit(X) | ||
kneighbors = nn_obj.kneighbors(X) | ||
return kneighbors | ||
|
||
|
||
def compute_connectivities_umap( | ||
knn_indices, | ||
knn_dists, | ||
n_obs, | ||
n_neighbors, | ||
set_op_mix_ratio=1.0, | ||
local_connectivity=1.0, | ||
): | ||
"""Sped up version of sc.neighbors._compute_connectivities_umap.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you write a more general docstring, with an overview of the method and a note that it matches how connectivities are computed in scanpy? |
||
with warnings.catch_warnings(): | ||
# umap 0.5.0 | ||
warnings.filterwarnings("ignore", message=r"Tensorflow not installed") | ||
from umap.umap_ import fuzzy_simplicial_set | ||
|
||
X = coo_matrix(([], ([], [])), shape=(n_obs, 1)) | ||
connectivities = fuzzy_simplicial_set( | ||
X, | ||
n_neighbors, | ||
None, | ||
None, | ||
Comment on lines +84 to +87
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: I prefer using keyword arguments everywhere. |
||
knn_indices=knn_indices, | ||
knn_dists=knn_dists, | ||
set_op_mix_ratio=set_op_mix_ratio, | ||
local_connectivity=local_connectivity, | ||
) | ||
|
||
if isinstance(connectivities, tuple): | ||
# In umap-learn 0.4, this returns (result, sigmas, rhos) | ||
connectivities = connectivities[0] | ||
Comment on lines +94 to +96
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this bit still necessary? What's the lower bound for scanpy? If we merge this PR, we will need to make umap a direct dependency, so please add that as well and potentially remove this block. |
||
|
||
n_samples = knn_indices.shape[0] | ||
distances = knn_dists.ravel() | ||
indices = knn_indices.ravel() | ||
|
||
# Check for self-connections | ||
self_connections = not np.all(knn_indices != np.arange(n_samples)[:, None]) | ||
|
||
# Efficient creation of row pointer | ||
rowptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors) | ||
|
||
# Create CSR matrix | ||
dist_sparse_csr = csr_matrix((distances, indices, rowptr), shape=(n_samples, n_samples)) | ||
|
||
# Set diagonal to zero if self-connections exist | ||
if self_connections: | ||
dist_sparse_csr.setdiag(0.0) | ||
return dist_sparse_csr, connectivities.tocsr() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import time | ||
|
||
import numpy as np | ||
import pytest | ||
import scanpy as sc | ||
from sklearn.neighbors import NearestNeighbors | ||
|
||
from scib_metrics.utils._utils import compute_connectivities_umap | ||
from tests.utils.data import dummy_benchmarker_adata | ||
|
||
|
||
@pytest.mark.parametrize("n", [5, 10, 20, 21]) | ||
def test_compute_connectivities_umap(n): | ||
adata, embedding_keys, *_ = dummy_benchmarker_adata() | ||
neigh = NearestNeighbors(n_neighbors=25).fit(adata.obsm[embedding_keys[0]]) | ||
dist, ind = neigh.kneighbors() | ||
new_dist, new_connect = compute_connectivities_umap(ind[:, :n], dist[:, :n], adata.n_obs, n_neighbors=n) | ||
sc_dist, sc_connect = sc.neighbors._compute_connectivities_umap(ind[:, :n], dist[:, :n], adata.n_obs, n_neighbors=n) | ||
assert (new_dist == sc_dist).todense().all() | ||
assert (new_connect == sc_connect).todense().all() | ||
Comment on lines +19 to +20
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: use |
||
|
||
|
||
def test_timing_compute_connectivities_umap(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this test is necessary if you add a reproducible script to this PR description. |
||
n_obs = 10_000 | ||
X = np.random.normal(size=(n_obs, 10)) | ||
neigh = NearestNeighbors(n_neighbors=90).fit(X) | ||
dist, ind = neigh.kneighbors() | ||
|
||
new_start = time.perf_counter() | ||
compute_connectivities_umap(ind, dist, n_obs, n_neighbors=90) | ||
new_end = time.perf_counter() | ||
|
||
sc_start = time.perf_counter() | ||
sc.neighbors._compute_connectivities_umap(ind, dist, n_obs, n_neighbors=90) | ||
sc_end = time.perf_counter() | ||
|
||
assert new_end - new_start < sc_end - sc_start | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add type annotations.