diff --git a/multipers/ml/signed_measures.py b/multipers/ml/signed_measures.py index 0355398..dedd272 100644 --- a/multipers/ml/signed_measures.py +++ b/multipers/ml/signed_measures.py @@ -1,5 +1,6 @@ +from collections.abc import Callable, Iterable from itertools import product -from typing import Callable, Iterable, Optional +from typing import Optional import matplotlib.pyplot as plt import numpy as np @@ -53,15 +54,17 @@ def __init__( # homological degrees + None for euler degrees: list[int | None] = [], rank_degrees: list[int] = [], # same for rank invariant - filtration_grid: Iterable[np.ndarray] - # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i] - | None = None, + filtration_grid: ( + Iterable[np.ndarray] + # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i] + | None + ) = None, progress=False, # tqdm num_collapses: int | str = 0, # edge collapses before computing n_jobs=None, - resolution: Iterable[int] - | int - | None = None, # when filtration grid is not given, the resolution of the filtration grid to infer + resolution: ( + Iterable[int] | int | None + ) = None, # when filtration grid is not given, the resolution of the filtration grid to infer # sparse=True, # sparse output # DEPRECATED TO Ssigned measure formatter plot: bool = False, filtration_quantile: float = 0.0, # quantile for inferring filtration grid @@ -79,7 +82,7 @@ def __init__( ] = None, # Can be significantly faster for some grid strategies, but can drop statistical performance enforce_null_mass: bool = False, flatten=True, - backend:Optional[str]=None, + backend: Optional[str] = None, ): super().__init__() self.degrees = degrees @@ -113,8 +116,32 @@ def __init__( self._default_mass_location = None self.flatten = flatten self.num_parameters: int = 0 + self._is_minpres: bool | None = None return + @staticmethod + def _is_filtered_complex(input): + return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer( + input, allow_minpres=True + ) + + def _input_checks(self, X): + assert len(X) > 0, "No filtered complex found. Cannot fit." + assert self._is_filtered_complex( + X[0][0] + ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]." + self._num_axis = len(X[0]) + first = X[0][0] + assert ( + not mp.slicer.is_slicer(first) or self.expand is None + ), "Cannot expand slicers." + self._is_minpres = mp.slicer.is_slicer(first) and isinstance( + first, Union[tuple, list] + ) + assert not ( + self._is_minpres and self.minpres_degrees is not None + ), "Input is already a minpres. Cannot reduce again." + def _infer_filtration(self, X): indices = np.random.choice( len(X), min(int(self.fit_fraction * len(X)) + 1, len(X)), replace=False @@ -646,8 +673,8 @@ def _check_axis(self, X): self._num_axis = 1 self._axis_iterator = [slice(None)] return - assert ( ## vaguely checks that its a signed measure - self._check_sm(_sm := X[0][0][0]) + assert self._check_sm( ## vaguely checks that its a signed measure + _sm := X[0][0][0] ), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}" self._has_axis = True @@ -717,9 +744,11 @@ def _infer_grids(self, X): filtration_values = [ np.concatenate( [ - stuff - if isinstance(stuff := x[ax][degree][0], np.ndarray) - else stuff.detach().numpy() + ( + stuff + if isinstance(stuff := x[ax][degree][0], np.ndarray) + else stuff.detach().numpy() + ) for x in X for degree in range(self._num_degrees) ] @@ -1168,22 +1197,28 @@ def __init__( return def fit(self, X, y=None): - from multipers.ml.sliced_wasserstein import (SlicedWassersteinDistance, - WassersteinDistance) + from multipers.ml.sliced_wasserstein import ( + SlicedWassersteinDistance, + WassersteinDistance, + ) # _DISTANCE = lambda : SlicedWassersteinDistance(num_directions=self.num_directions) if self._sliced else WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm) # WARNING if _sliced is false, this distance is not CNSD if len(X) == 0: return self num_degrees = len(X[0]) self._SWD_list = [ - SlicedWassersteinDistance( - num_directions=self.num_directions, - n_jobs=self.n_jobs, - scales=self.scales, - ) - if self._sliced - else WassersteinDistance( - epsilon=self.epsilon, ground_norm=self.ground_norm, n_jobs=self.n_jobs + ( + SlicedWassersteinDistance( + num_directions=self.num_directions, + n_jobs=self.n_jobs, + scales=self.scales, + ) + if self._sliced + else WassersteinDistance( + epsilon=self.epsilon, + ground_norm=self.ground_norm, + n_jobs=self.n_jobs, + ) ) for _ in range(num_degrees) ]