Skip to content

Commit

Permalink
feat: added some checks in sm pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidLapous committed Sep 26, 2024
1 parent bb007e6 commit ef681e6
Showing 1 changed file with 58 additions and 23 deletions.
81 changes: 58 additions & 23 deletions multipers/ml/signed_measures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable, Iterable
from itertools import product
from typing import Callable, Iterable, Optional
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -53,15 +54,17 @@ def __init__(
# homological degrees + None for euler
degrees: list[int | None] = [],
rank_degrees: list[int] = [], # same for rank invariant
filtration_grid: Iterable[np.ndarray]
# filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
| None = None,
filtration_grid: (
Iterable[np.ndarray]
# filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
| None
) = None,
progress=False, # tqdm
num_collapses: int | str = 0, # edge collapses before computing
n_jobs=None,
resolution: Iterable[int]
| int
| None = None, # when filtration grid is not given, the resolution of the filtration grid to infer
resolution: (
Iterable[int] | int | None
) = None, # when filtration grid is not given, the resolution of the filtration grid to infer
# sparse=True, # sparse output # DEPRECATED TO Ssigned measure formatter
plot: bool = False,
filtration_quantile: float = 0.0, # quantile for inferring filtration grid
Expand All @@ -79,7 +82,7 @@ def __init__(
] = None, # Can be significantly faster for some grid strategies, but can drop statistical performance
enforce_null_mass: bool = False,
flatten=True,
backend:Optional[str]=None,
backend: Optional[str] = None,
):
super().__init__()
self.degrees = degrees
Expand Down Expand Up @@ -113,8 +116,32 @@ def __init__(
self._default_mass_location = None
self.flatten = flatten
self.num_parameters: int = 0
self._is_minpres: bool | None = None
return

@staticmethod
def _is_filtered_complex(input):
return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
input, allow_minpres=True
)

def _input_checks(self, X):
assert len(X) > 0, "No filtered complex found. Cannot fit."
assert self._is_filtered_complex(
X[0][0]
), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
self._num_axis = len(X[0])
first = X[0][0]
assert (
not mp.slicer.is_slicer(first) or self.expand is None
), "Cannot expand slicers."
self._is_minpres = mp.slicer.is_slicer(first) and isinstance(
first, Union[tuple, list]
)
assert not (
self._is_minpres and self.minpres_degrees is not None
), "Input is already a minpres. Cannot reduce again."

def _infer_filtration(self, X):
indices = np.random.choice(
len(X), min(int(self.fit_fraction * len(X)) + 1, len(X)), replace=False
Expand Down Expand Up @@ -646,8 +673,8 @@ def _check_axis(self, X):
self._num_axis = 1
self._axis_iterator = [slice(None)]
return
assert ( ## vaguely checks that its a signed measure
self._check_sm(_sm := X[0][0][0])
assert self._check_sm( ## vaguely checks that its a signed measure
_sm := X[0][0][0]
), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}"

self._has_axis = True
Expand Down Expand Up @@ -717,9 +744,11 @@ def _infer_grids(self, X):
filtration_values = [
np.concatenate(
[
stuff
if isinstance(stuff := x[ax][degree][0], np.ndarray)
else stuff.detach().numpy()
(
stuff
if isinstance(stuff := x[ax][degree][0], np.ndarray)
else stuff.detach().numpy()
)
for x in X
for degree in range(self._num_degrees)
]
Expand Down Expand Up @@ -1168,22 +1197,28 @@ def __init__(
return

def fit(self, X, y=None):
from multipers.ml.sliced_wasserstein import (SlicedWassersteinDistance,
WassersteinDistance)
from multipers.ml.sliced_wasserstein import (
SlicedWassersteinDistance,
WassersteinDistance,
)

# _DISTANCE = lambda : SlicedWassersteinDistance(num_directions=self.num_directions) if self._sliced else WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm) # WARNING if _sliced is false, this distance is not CNSD
if len(X) == 0:
return self
num_degrees = len(X[0])
self._SWD_list = [
SlicedWassersteinDistance(
num_directions=self.num_directions,
n_jobs=self.n_jobs,
scales=self.scales,
)
if self._sliced
else WassersteinDistance(
epsilon=self.epsilon, ground_norm=self.ground_norm, n_jobs=self.n_jobs
(
SlicedWassersteinDistance(
num_directions=self.num_directions,
n_jobs=self.n_jobs,
scales=self.scales,
)
if self._sliced
else WassersteinDistance(
epsilon=self.epsilon,
ground_norm=self.ground_norm,
n_jobs=self.n_jobs,
)
)
for _ in range(num_degrees)
]
Expand Down

0 comments on commit ef681e6

Please sign in to comment.