
Commit

final version of entropy hist pull request
Matteo NERI authored and committed on Sep 3, 2024
1 parent 4ed53ff commit c9f56b9
Showing 2 changed files with 53 additions and 21 deletions.
11 changes: 5 additions & 6 deletions hoi/core/entropies.py
@@ -12,7 +12,7 @@
 from jax.scipy.stats import gaussian_kde

 from hoi.utils.logging import logger
-from hoi.utils.stats import normalize, digitize
+from hoi.utils.stats import normalize, digitize_hist

 ###############################################################################
 ###############################################################################
@@ -299,7 +299,7 @@ def entropy_bin(x: jnp.array, base: float = 2) -> jnp.array:


 @partial(jax.jit, static_argnums=(1,))
-def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 5) -> jnp.array:
+def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
     """Entropy using binning.

     Parameters
@@ -309,7 +309,7 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 5) -> jnp.array:
         be discretized
     base : float | 2
         The logarithmic base to use. Default is base 2.
-    n_bins : int | 5
+    n_bins : int | 8
         The number of bins to be considered in the binning process

     Returns
@@ -318,8 +318,8 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 5) -> jnp.array:
         Entropy of x (in bits)
     """

-    x_binned, bin_size = digitize(
-        x, n_bins, axis=1, use_sklearn=False, bin_size=True
+    x_binned, bin_size = digitize_hist(
+        x, n_bins, axis=1
     )

     n_features, n_samples = x_binned.shape
@@ -339,7 +339,6 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 5) -> jnp.array:

     return -jax.scipy.special.rel_entr(probs, bin_s).sum() / jnp.log(base)
-

 ###############################################################################
 ###############################################################################
 # KNN
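Note on the change above: entropy_hist now delegates binning to digitize_hist and uses the returned bin size v to convert the discrete plug-in entropy into a differential-entropy estimate, since -sum(rel_entr(p, v)) = -sum(p * log(p / v)) = H(p) + log(v). A minimal standalone sketch of that estimator (hypothetical name entropy_hist_sketch, not the hoi API; assumes x is shaped (n_features, n_samples) with non-constant features):

import jax
import jax.numpy as jnp

def entropy_hist_sketch(x, n_bins=8, base=2.0):
    """Histogram entropy sketch for x of shape (n_features, n_samples)."""
    x_min = x.min(axis=1, keepdims=True)
    x_max = x.max(axis=1, keepdims=True)
    dx = (x_max - x_min) / n_bins  # per-feature bin width
    bin_volume = jnp.prod(dx)  # volume of one joint bin
    # equal-width digitization, clipped so maxima land in the last bin
    xb = jnp.minimum(((x - x_min) / dx).astype(int), n_bins - 1)
    # probabilities of the observed joint bin configurations
    _, counts = jnp.unique(xb, axis=1, return_counts=True)
    probs = counts / xb.shape[1]
    # -sum p*log(p/v) = H_discrete + log(v): the histogram estimate
    return -jax.scipy.special.rel_entr(probs, bin_volume).sum() / jnp.log(base)

For a single standard-normal feature, the estimate should approach 0.5 * log2(2 * pi * e) ≈ 2.05 bits as the number of samples (and, more slowly, of bins) grows.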
63 changes: 48 additions & 15 deletions hoi/utils/stats.py
@@ -123,14 +123,13 @@ def landscape(

     return lscp
-

 def digitize_1d(x, n_bins):
     """One dimensional digitization."""
     assert x.ndim == 1
     x_min, x_max = x.min(), x.max()
     dx = (x_max - x_min) / n_bins
     x_binned = ((x - x_min) / dx).astype(int)
-    x_binned = jnp.minimum(x_binned, n_bins - 1)
+    x_binned = np.minimum(x_binned, n_bins - 1)
     return x_binned.astype(int)


@@ -147,7 +146,7 @@ def digitize_sklearn(x, **kwargs):
     )


-def digitize(x, n_bins, axis=0, use_sklearn=False, bin_size=False, **kwargs):
+def digitize(x, n_bins, axis=0, use_sklearn=False, **kwargs):
     """Discretize a continuous variable.

     Parameters
@@ -166,31 +165,65 @@ def digitize(x, n_bins, axis=0, use_sklearn=False, bin_size=False, **kwargs):
         Additional arguments are passed to
         sklearn.preprocessing.KBinsDiscretizer. For example, use
         `strategy='quantile'` for equal population binning.
-    bin_size : bool | False
-        When true returns also the bin_sizes, note only in when
-        use_sklearn=False

     Returns
     -------
     x_binned : array_like
         Digitized array with the same shape as x
     """
-    # In case use_sklearn = False, all bins have the same size. In this case,
-    # in order to allow the histogram estimator, also the size of the bins is
-    # returned.
-    bins_arr = (x.max(axis=axis) - x.min(axis=axis)) / n_bins
-    b_size = jnp.prod(bins_arr)
-    if not use_sklearn and bin_size:
-        return jnp.apply_along_axis(digitize_1d, axis, x, n_bins), b_size
-    elif not use_sklearn and not bin_size:
-        return jnp.apply_along_axis(digitize_1d, axis, x, n_bins)
+    if not use_sklearn:
+        return np.apply_along_axis(digitize_1d, axis, x, n_bins)
     else:
         kwargs["n_bins"] = n_bins
         kwargs["encode"] = "ordinal"
         kwargs["subsample"] = None
         return np.apply_along_axis(digitize_sklearn, axis, x, **kwargs)
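A quick usage sketch of the simplified digitize above, on toy data (assumes scikit-learn is installed for the second call; strategy='quantile' comes from the docstring):

import numpy as np
from hoi.utils.stats import digitize

x = np.random.randn(1000, 3)  # (n_samples, n_features) toy data

# equal-width binning of each feature along the samples axis
x_eq = digitize(x, n_bins=4, axis=0, use_sklearn=False)

# equal-population binning through sklearn's KBinsDiscretizer
x_qt = digitize(x, n_bins=4, axis=0, use_sklearn=True, strategy="quantile")

print(x_eq.shape, x_qt.shape)  # both keep the input shape (1000, 3)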


+@partial(jax.jit, static_argnums=(1,))
+def digitize_1d_hist(x, n_bins):
+    """One dimensional digitization."""
+    assert x.ndim == 1
+    x_min, x_max = x.min(), x.max()
+    dx = (x_max - x_min) / n_bins
+    x_binned = ((x - x_min) / dx).astype(int)
+    x_binned = jnp.minimum(x_binned, n_bins - 1)
+    return x_binned.astype(int)


+def digitize_hist(x, n_bins, axis=0, **kwargs):
+    """Discretize a continuous variable with equal-width bins.
+
+    Parameters
+    ----------
+    x : array_like
+        Array to discretize
+    n_bins : int
+        Number of bins
+    axis : int | 0
+        Axis along which to perform the discretization. By default,
+        discretization is performed along the first axis (n_samples,)
+    kwargs : dict | {}
+        Additional arguments, accepted for API compatibility with
+        digitize but unused here
+
+    Returns
+    -------
+    x_binned : array_like
+        Digitized array with the same shape as x
+    b_size : float
+        Size of the bins used
+    """
+    # All bins have the same size. In order to allow the histogram
+    # estimator, the size of the bins is returned together with the
+    # digitized array.
+    bins_arr = (x.max(axis=axis) - x.min(axis=axis)) / n_bins
+    b_size = jnp.prod(bins_arr)
+    return jnp.apply_along_axis(digitize_1d_hist, axis, x, n_bins), b_size


+partial(jax.jit, static_argnums=(1, 2))


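Putting the new helpers together, a usage sketch on toy data (digitize_hist is imported from hoi.utils.stats, matching the import change in entropies.py above):

import numpy as np
import jax.numpy as jnp
from hoi.utils.stats import digitize_hist

# (n_features, n_samples): the layout entropy_hist expects
x = jnp.asarray(np.random.randn(3, 1000))

# digitize along the samples axis and recover the joint bin volume
x_binned, b_size = digitize_hist(x, n_bins=8, axis=1)

print(x_binned.shape)  # (3, 1000), integer bin indices in 0..7
print(float(b_size))  # product of the per-feature bin widths

Returning the bin volume alongside the indices is what lets entropy_hist turn the discrete entropy of the bin indices into a differential-entropy estimate without recomputing the bin edges.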
