diff --git a/pyproject.toml b/pyproject.toml index 75196ca..89b0a57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ xfail_strict = true [tool.ruff] src = ["src"] line-length = 120 -select = [ +lint.select = [ "F", # Errors detected by Pyflakes "E", # Error detected by Pycodestyle "W", # Warning detected by Pycodestyle @@ -107,7 +107,7 @@ select = [ "UP", # pyupgrade "RUF100", # Report unused noqa directives ] -ignore = [ +lint.ignore = [ # line too long -> we accept long comment lines; formatter gets rid of long code lines "E501", # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient diff --git a/src/scib_metrics/_types.py b/src/scib_metrics/_types.py index ad0c4ac..418bbfb 100644 --- a/src/scib_metrics/_types.py +++ b/src/scib_metrics/_types.py @@ -1,10 +1,10 @@ from typing import Union -import jax import jax.numpy as jnp import numpy as np import scipy.sparse as sp +from jax import Array NdArray = Union[np.ndarray, jnp.ndarray] -IntOrKey = Union[int, jax.random.KeyArray] +IntOrKey = Union[int, Array] ArrayLike = Union[np.ndarray, sp.spmatrix, jnp.ndarray] diff --git a/src/scib_metrics/utils/_kmeans.py b/src/scib_metrics/utils/_kmeans.py index 644c4a7..8e67567 100644 --- a/src/scib_metrics/utils/_kmeans.py +++ b/src/scib_metrics/utils/_kmeans.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp import numpy as np +from jax import Array from sklearn.utils import check_array from scib_metrics._types import IntOrKey @@ -18,7 +19,7 @@ def _tolerance(X: jnp.ndarray, tol: float) -> float: return np.mean(variances) * tol -def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray: +def _initialize_random(X: jnp.ndarray, n_clusters: int, key: Array) -> jnp.ndarray: """Initialize cluster centroids randomly.""" n_obs = X.shape[0] key, subkey = jax.random.split(key) @@ -28,7 +29,7 @@ def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray @partial(jax.jit, static_argnums=1) -def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray: +def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: Array) -> jnp.ndarray: """Initialize cluster centroids with k-means++ algorithm.""" n_obs = X.shape[0] key, subkey = jax.random.split(key) @@ -111,7 +112,7 @@ def __init__( self.n_init = n_init self.max_iter = max_iter self.tol_scale = tol - self.seed: jax.random.KeyArray = validate_seed(seed) + self.seed: jax.Array = validate_seed(seed) if init not in ["k-means++", "random"]: raise ValueError("Invalid init method, must be one of ['k-means++' or 'random'].") diff --git a/src/scib_metrics/utils/_utils.py b/src/scib_metrics/utils/_utils.py index 51ff901..b093431 100644 --- a/src/scib_metrics/utils/_utils.py +++ b/src/scib_metrics/utils/_utils.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import numpy as np from chex import ArrayDevice -from jax import nn +from jax import Array, nn from scipy.sparse import csr_matrix from sklearn.neighbors import NearestNeighbors from sklearn.utils import check_array @@ -37,7 +37,7 @@ def one_hot(y: NdArray, n_classes: Optional[int] = None) -> jnp.ndarray: return nn.one_hot(jnp.ravel(y), n_classes) -def validate_seed(seed: IntOrKey) -> jax.random.KeyArray: +def validate_seed(seed: IntOrKey) -> Array: """Validate a seed and return a Jax random key.""" return jax.random.PRNGKey(seed) if isinstance(seed, int) else seed diff --git a/tests/utils/sampling.py b/tests/utils/sampling.py index 34a8e50..0ecf812 100644 --- a/tests/utils/sampling.py +++ b/tests/utils/sampling.py @@ -2,11 +2,12 @@ import jax import jax.numpy as jnp +from jax import Array -IntOrKey = Union[int, jax.random.KeyArray] +IntOrKey = Union[int, Array] -def _validate_seed(seed: IntOrKey) -> jax.random.KeyArray: +def _validate_seed(seed: IntOrKey) -> Array: return jax.random.PRNGKey(seed) if isinstance(seed, int) else seed