Skip to content

Commit

Permalink
bench: added gudhi slices
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidLapous committed Oct 4, 2024
1 parent 627d95b commit b203952
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions benchmarks/stuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import multipers.grids as mpg
import multipers.ml.point_clouds as mmp
from multipers.data import noisy_annulus, orbit, three_annulus
from multipers.simplex_tree_multi import SimplexTreeMulti_type
from multipers.slicer import Slicer_type, available_columns

np.random.seed(0)
Expand All @@ -25,9 +26,9 @@

datasets: Sequence[str] = list(available_dataset.keys())
degrees: Sequence[int] = [0, 1]
num_pts: Sequence[int] = [200]
num_pts: Sequence[int] = [100, 300, 500]
complexes = ["delaunay", "rips"]
invariants = ["mma", "slice", "hilbert", "rank"]
invariants = ["mma", "slice", "hilbert", "rank", "gudhi_slice"]
vineyard = ["vine", "novine"]
num_lines = 50
num_repetition = 5
Expand All @@ -36,7 +37,7 @@


def fill_timing(arg, f):
timings[arg] = timeit(f, number=num_repetition)
timings[arg] = timeit(f, number=num_repetition) / num_repetition
terminal_width = shutil.get_terminal_size().columns
left = str(args)
right = f"{timings[arg]:.4f}"
Expand All @@ -56,15 +57,16 @@ def fill_timing(arg, f):
):
n, dataset, cplx, inv, degree, vine, dtype, col = args
pts = np.asarray(available_dataset[dataset](n))
s: Slicer_type = mmp.PointCloud2FilteredComplex(
st: SimplexTreeMulti_type = mmp.PointCloud2FilteredComplex(
complex=cplx,
bandwidths=[0.2],
num_collapses=2,
output_type="slicer",
reduce_degrees=[degree],
output_type="simplextree",
expand_dim=degree + 1,
).fit_transform([pts])[0][0]
s = mp.Slicer(s, vineyard=(vine == "vine"), dtype=dtype, column_type=col)
s = mp.Slicer(st, vineyard=(vine == "vine"), dtype=dtype, column_type=col).minpres(
degree=degree
)
box = mpg.compute_bounding_box(s)
s.minpres_degree = -1 ## makes it non-minpres again
if inv == "mma":
Expand All @@ -80,6 +82,18 @@ def fill_timing(arg, f):
)
directions = [np.ones(s.num_parameters)] * num_lines
f = lambda: s.persistence_on_lines(basepoints, directions)
elif inv == "gudhi_slice":
basepoints = np.random.uniform(
low=box[None, :, 0],
high=box[None, :, 1],
size=(num_lines, s.num_parameters),
)
directions = [np.ones(s.num_parameters)] * num_lines

def f():
for bp, dir in zip(basepoints, directions):
st.project_on_line(basepoint=bp, direction=dir).persistence()

elif inv == "hilbert":
grid = mpg.compute_grid(s, resolution=50, strategy="regular")
f = lambda: mp.signed_measure(s, grid=grid, degree=degree, invariant="hilbert")
Expand Down

0 comments on commit b203952

Please sign in to comment.