Skip to content

Commit

Permalink
final changes to entropy_hist
Browse files Browse the repository at this point in the history
  • Loading branch information
Matteo NERI authored and Matteo NERI committed Sep 3, 2024
1 parent c9f56b9 commit f088ce0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 2 additions & 3 deletions hoi/core/entropies.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,7 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:
Entropy of x (in bits)
"""

x_binned, bin_size = digitize_hist(
x, n_bins, axis=1
)
x_binned, bin_size = digitize_hist(x, n_bins, axis=1)

n_features, n_samples = x_binned.shape

Expand All @@ -339,6 +337,7 @@ def entropy_hist(x: jnp.array, base: float = 2, n_bins: int = 8) -> jnp.array:

return -jax.scipy.special.rel_entr(probs, bin_s).sum() / jnp.log(base)


###############################################################################
###############################################################################
# KNN
Expand Down
4 changes: 3 additions & 1 deletion hoi/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def landscape(

return lscp


def digitize_1d(x, n_bins):
"""One dimensional digitization."""
assert x.ndim == 1
Expand Down Expand Up @@ -182,6 +183,7 @@ def digitize(x, n_bins, axis=0, use_sklearn=False, **kwargs):

partial(jax.jit, static_argnums=(1, 2))


def digitize_1d_hist(x, n_bins):
"""One dimensional digitization."""
assert x.ndim == 1
Expand Down Expand Up @@ -214,7 +216,7 @@ def digitize_hist(x, n_bins, axis=0, **kwargs):
x_binned : array_like
Digitized array with the same shape as x
b_size : float
Size of the bin used
Size of the bin used
"""
# In case use_sklearn = False, all bins have the same size. In this case,
# in order to allow the histogram estimator, also the size of the bins is
Expand Down

0 comments on commit f088ce0

Please sign in to comment.