Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor neighbors-based metrics to use NeighborsResults #129

Merged
merged 24 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.1
current_version = 0.5.0
tag = True
commit = True

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python: ["3.9", "3.10"]
python: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest]

env:
Expand Down
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@ repos:
mdformat-myst,
]
args: [--nbqa-md]
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
hooks:
- id: nbstripout
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## 0.5.0 (2024-MM-DD)

- Refactor all relevant metrics to use `NeighborsResults` as input instead of sparse distance/connectivity matrices.

## 0.4.1 (2023-10-08)

- Fix KMeans. All previous versions had a bug with KMeans and ARI/NMI metrics are not reliable with this clustering. ([#115][])
Expand Down
2 changes: 1 addition & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ scib_metrics.ilisi_knn(...)

nearest_neighbors.pynndescent
nearest_neighbors.jax_approx_min_k
nearest_neighbors.NeighborsOutput
nearest_neighbors.NeighborsResults
```

## Settings
Expand Down
179 changes: 105 additions & 74 deletions docs/notebooks/large_scale.ipynb

Large diffs are not rendered by default.

518 changes: 210 additions & 308 deletions docs/notebooks/lung_example.ipynb

Large diffs are not rendered by default.

60 changes: 26 additions & 34 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ requires = ["hatchling"]

[project]
name = "scib-metrics"
version = "0.4.1"
version = "0.5.0"
description = "Accelerated and Python-only scIB metrics"
readme = "README.md"
requires-python = ">=3.9"
license = {file = "LICENSE"}
authors = [
{name = "Adam Gayoso"},
]
maintainers = [
{name = "Adam Gayoso", email = "[email protected]"},
]
license = { file = "LICENSE" }
authors = [{ name = "Adam Gayoso" }]
maintainers = [{ name = "Adam Gayoso", email = "[email protected]" }]
urls.Documentation = "https://scib-metrics.readthedocs.io/"
urls.Source = "https://github.com/yoseflab/scib-metrics"
urls.Home-page = "https://github.com/yoseflab/scib-metrics"
Expand All @@ -35,13 +31,14 @@ dependencies = [
"matplotlib",
"plottable",
"tqdm",
"umap-learn>=0.5.0",
]

[project.optional-dependencies]
dev = [
# CLI for bumping the version number
"bump2version",
"pre-commit"
"pre-commit",
]
doc = [
"sphinx>=4",
Expand All @@ -66,30 +63,26 @@ test = [
"black",
"numba>=0.57.1",
]
parallel = [
"joblib"
]
parallel = ["joblib"]
tutorial = [
"rich",
"scanorama",
"harmony-pytorch",
"scvi-tools",
"pyliger",
"numexpr", # missing liger dependency
"plotnine", # missing liger dependency
"mygene", # missing liger dependency
"goatools", # missing liger dependency
"adjustText", # missing liger dependency
"numexpr", # missing liger dependency
"plotnine", # missing liger dependency
"mygene", # missing liger dependency
"goatools", # missing liger dependency
"adjustText", # missing liger dependency
]

[tool.hatch.build.targets.wheel]
packages = ['src/scib_metrics']

[tool.coverage.run]
source = ["scib_metrics"]
omit = [
"**/test_*.py",
]
omit = ["**/test_*.py"]

[tool.pytest.ini_options]
testpaths = ["tests"]
Expand All @@ -99,17 +92,17 @@ xfail_strict = true
src = ["src"]
line-length = 120
select = [
"F", # Errors detected by Pyflakes
"E", # Error detected by Pycodestyle
"W", # Warning detected by Pycodestyle
"I", # isort
"D", # pydocstyle
"B", # flake8-bugbear
"TID", # flake8-tidy-imports
"C4", # flake8-comprehensions
"BLE", # flake8-blind-except
"UP", # pyupgrade
"RUF100", # Report unused noqa directives
"F", # Errors detected by Pyflakes
"E", # Error detected by Pycodestyle
"W", # Warning detected by Pycodestyle
"I", # isort
"D", # pydocstyle
"B", # flake8-bugbear
"TID", # flake8-tidy-imports
"C4", # flake8-comprehensions
"BLE", # flake8-blind-except
"UP", # pyupgrade
"RUF100", # Report unused noqa directives
]
ignore = [
# line too long -> we accept long comment lines; black gets rid of long code lines
Expand All @@ -128,8 +121,7 @@ ignore = [
"B008",
# __magic__ methods are are often self-explanatory, allow missing docstrings
"D105",
# first line should end with a period [Bug: doesn't work with single-line docstrings]
"D400",
# first line should end with a period [Bug: doesn't work with single-line docstrings] "D400",
# First line should be in imperative mood; try rephrasing
"D401",
## Disable one in each pair of mutually incompatible rules
Expand Down Expand Up @@ -185,5 +177,5 @@ skip = [
"docs/changelog.md",
"docs/references.bib",
"docs/references.md",
"docs/notebooks/example.ipynb"
"docs/notebooks/example.ipynb",
]
9 changes: 6 additions & 3 deletions src/scib_metrics/_graph_connectivity.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

from scib_metrics.nearest_neighbors import NeighborsResults

def graph_connectivity(X: csr_matrix, labels: np.ndarray) -> float:

def graph_connectivity(X: NeighborsResults, labels: np.ndarray) -> float:
"""Quantify the connectivity of the subgraph per cell type label.

Parameters
Expand All @@ -19,9 +20,11 @@ def graph_connectivity(X: csr_matrix, labels: np.ndarray) -> float:
# TODO(adamgayoso): Utils for validating inputs
clust_res = []

graph = X.knn_graph_distances

for label in np.unique(labels):
mask = labels == label
graph_sub = X[mask]
graph_sub = graph[mask]
graph_sub = graph_sub[:, mask]
_, comps = connected_components(graph_sub, connection="strong")
tab = pd.value_counts(comps)
Expand Down
50 changes: 22 additions & 28 deletions src/scib_metrics/_kbet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import numpy as np
import pandas as pd
import scipy
from scipy.sparse import csr_matrix

from scib_metrics.utils import convert_knn_graph_to_idx, diffusion_nn, get_ndarray
from scib_metrics.nearest_neighbors import NeighborsResults
from scib_metrics.utils import diffusion_nn, get_ndarray

from ._types import NdArray

Expand Down Expand Up @@ -40,7 +40,7 @@ def _kbet(neigh_batch_ids: jnp.ndarray, batches: jnp.ndarray, n_batches: int) ->
return test_statistics, p_values


def kbet(X: csr_matrix, batches: np.ndarray, alpha: float = 0.05) -> float:
def kbet(X: NeighborsResults, batches: np.ndarray, alpha: float = 0.05) -> float:
"""Compute kbet :cite:p:`buttner2018`.

This implementation is inspired by the implementation in Pegasus:
Expand All @@ -57,8 +57,7 @@ def kbet(X: csr_matrix, batches: np.ndarray, alpha: float = 0.05) -> float:
Parameters
----------
X
Array of shape (n_cells, n_cells) with non-zero values
representing distances to exactly each cell's k nearest neighbors.
A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
batches
Array of shape (n_cells,) representing batch values
for each cell.
Expand All @@ -73,16 +72,10 @@ def kbet(X: csr_matrix, batches: np.ndarray, alpha: float = 0.05) -> float:
Mean Kbet chi-square statistic over all cells.
pvalue_mean
Mean Kbet p-value over all cells.

Notes
-----
This function requires X to be cell-cell distances, not connectivies.
"""
if len(batches) != X.shape[0]:
if len(batches) != len(X.indices):
raise ValueError("Length of batches does not match number of cells.")
_, knn_idx = convert_knn_graph_to_idx(X)
# Make sure self is included
knn_idx = np.concatenate([np.arange(knn_idx.shape[0])[:, None], knn_idx], axis=1)
knn_idx = X.indices
batches = np.asarray(pd.Categorical(batches).codes)
neigh_batch_ids = batches[knn_idx]
chex.assert_equal_shape([neigh_batch_ids, knn_idx])
Expand All @@ -96,7 +89,7 @@ def kbet(X: csr_matrix, batches: np.ndarray, alpha: float = 0.05) -> float:


def kbet_per_label(
X: csr_matrix,
X: NeighborsResults,
batches: np.ndarray,
labels: np.ndarray,
alpha: float = 0.05,
Expand All @@ -113,8 +106,7 @@ def kbet_per_label(
Parameters
----------
X
Array of shape (n_cells, n_cells) with non-zero values
representing connectivies to exactly each cell's k nearest neighbors.
A :class:`~scib_metrics.utils.nearest_neighbors.NeighborsResults` object.
batches
Array of shape (n_cells,) representing batch values
for each cell.
Expand All @@ -136,23 +128,25 @@ def kbet_per_label(
-----
This function requires X to be cell-cell connectivities, not distances.
"""
if len(batches) != X.shape[0]:
if len(batches) != len(X.indices):
raise ValueError("Length of batches does not match number of cells.")
if len(labels) != X.shape[0]:
if len(labels) != len(X.indices):
raise ValueError("Length of labels does not match number of cells.")
# set upper bound for k0
size_max = 2**31 - 1
batches = np.asarray(pd.Categorical(batches).codes)
labels = np.asarray(labels)

conn_graph = X.knn_graph_connectivities

# prepare call of kBET per cluster
kbet_scores = {"cluster": [], "kBET": []}
for clus in np.unique(labels):
# subset by label
mask = labels == clus
X_sub = X[mask, :][:, mask]
X_sub.sort_indices()
n_obs = X_sub.shape[0]
conn_graph_sub = conn_graph[mask, :][:, mask]
conn_graph_sub.sort_indices()
n_obs = conn_graph_sub.shape[0]
batches_sub = batches[mask]

# check if neighborhood size too small or only one batch in subset
Expand All @@ -166,12 +160,12 @@ def kbet_per_label(
if k0 * n_obs >= size_max:
k0 = np.floor(size_max / n_obs).astype("int")

n_comp, labs = scipy.sparse.csgraph.connected_components(X_sub, connection="strong")
n_comp, labs = scipy.sparse.csgraph.connected_components(conn_graph_sub, connection="strong")

if n_comp == 1: # a single component to compute kBET on
try:
diffusion_n_comps = np.min([diffusion_n_comps, n_obs - 1])
nn_graph_sub = diffusion_nn(X_sub, k=k0, n_comps=diffusion_n_comps).astype("float")
nn_graph_sub = diffusion_nn(conn_graph_sub, k=k0, n_comps=diffusion_n_comps)
# call kBET
score, _, _ = kbet(
nn_graph_sub,
Expand All @@ -192,15 +186,15 @@ def kbet_per_label(
# check if 75% of all cells can be used for kBET run
if len(idx_nonan) / len(labs) >= 0.75:
# create another subset of components, assume they are not visited in a diffusion process
X_sub_sub = X_sub[idx_nonan, :][:, idx_nonan]
X_sub_sub.sort_indices()
conn_graph_sub_sub = conn_graph_sub[idx_nonan, :][:, idx_nonan]
conn_graph_sub_sub.sort_indices()

try:
diffusion_n_comps = np.min([diffusion_n_comps, X_sub_sub.shape[0] - 1])
nn_graph_sub_sub = diffusion_nn(X_sub_sub, k=k0, n_comps=diffusion_n_comps).astype("float")
diffusion_n_comps = np.min([diffusion_n_comps, conn_graph_sub_sub.shape[0] - 1])
nn_results_sub_sub = diffusion_nn(conn_graph_sub_sub, k=k0, n_comps=diffusion_n_comps)
# call kBET
score, _, _ = kbet(
nn_graph_sub_sub,
nn_results_sub_sub,
batches=batches_sub[idx_nonan],
alpha=alpha,
)
Expand Down
Loading