Skip to content

Commit

Permalink
more unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerkuou committed Oct 21, 2024
1 parent 7455919 commit 2cb55d4
Showing 1 changed file with 92 additions and 1 deletion.
93 changes: 92 additions & 1 deletion tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
import xarray as xr

from pydepsi.classification import _idx_within_distance, _nad_block, _nmad_block, ps_selection
from pydepsi.classification import _idx_within_distance, _nad_block, _nmad_block, network_stm_seletcion, ps_selection

# Create a random number generator
rng = np.random.default_rng(42)
Expand Down Expand Up @@ -73,6 +73,97 @@ def test_ps_seletion_not_implemented():
ps_selection(slcs, 0.5, method="not_implemented", output_chunks=5)


def test_network_stm_seletcion_results():
stm = xr.Dataset(
data_vars={
"amplitude": (("space", "time"), np.ones((100, 10))),
"pnt_nad": (("space"), np.linspace(0, 1, 100)),
"pnt_nmad": (("space"), np.linspace(0, 1, 100)),
},
coords={"azimuth": (("space"), np.arange(100)), "range": (("space"), np.arange(100)), "time": np.arange(10)},
)
res_nad = network_stm_seletcion(stm, min_dist=20, sortby_var="pnt_nad", azimuth_spacing=10, range_spacing=10)
res_nmad = network_stm_seletcion(stm, min_dist=20, sortby_var="pnt_nmad", azimuth_spacing=10, range_spacing=10)
# Fields should remain the same
assert "pnt_nad" in res_nad
assert "azimuth" in res_nad
assert "range" in res_nad
assert "space" in res_nad.dims
assert "time" in res_nad.dims
# Dimensions should be half
assert res_nad.sizes["space"] == 50
assert res_nad.sizes["time"] == 10
assert res_nmad.sizes["space"] == 50
assert res_nmad.sizes["time"] == 10


def test_network_stm_seletcion_quality():
stm = xr.Dataset(
data_vars={
"amplitude": (("space", "time"), np.ones((5, 10))),
"pnt_nad": (("space"), np.array([0.9, 0.01, 0.9, 0.9, 0.01])),
"pnt_nmad": (("space"), np.array([0.01, 0.9, 0.9, 0.9, 0.01])),
},
coords={
"space": np.array([3, 1, 2, 5, 7]), # non monotonic space coords
"time": np.arange(10),
"azimuth": (("space"), np.arange(5)),
"range": (("space"), np.arange(5)),
},
)
res_nad = network_stm_seletcion(stm, min_dist=3, sortby_var="pnt_nad", azimuth_spacing=1, range_spacing=1)
res_nmad = network_stm_seletcion(stm, min_dist=3, sortby_var="pnt_nmad", azimuth_spacing=1, range_spacing=1)

# The two pixels with the lowest NAD should be selected
assert np.all(np.isclose(res_nad["pnt_nad"].values, 0.01, rtol=1e-09, atol=1e-09))
assert np.all(np.isclose(res_nmad["pnt_nmad"].values, 0.01, rtol=1e-09, atol=1e-09))
assert np.all(res_nad["space"].values == np.array([1, 7]))
assert np.all(res_nmad["space"].values == np.array([3, 7]))


def test_network_stm_seletcion_include_index():
stm = xr.Dataset(
data_vars={
"amplitude": (("space", "time"), np.ones((5, 10))),
"pnt_nad": (("space"), np.array([0.01, 0.01, 0.9, 0.9, 0.01])),
"pnt_nmad": (("space"), np.array([0.01, 0.01, 0.9, 0.9, 0.01])),
},
coords={
"space": np.array([1, 2, 5, 6, 7]), # non monotonic space coords
"time": np.arange(10),
"azimuth": (("space"), np.arange(5)),
"range": (("space"), np.arange(5)),
},
)
res_nad = network_stm_seletcion(
stm, min_dist=3, include_index=[1], sortby_var="pnt_nad", azimuth_spacing=1, range_spacing=1
)
res_nmad = network_stm_seletcion(
stm, min_dist=3, include_index=[1], sortby_var="pnt_nmad", azimuth_spacing=1, range_spacing=1
)

# The two pixels with the lowest NAD should be selected
assert np.all(np.isclose(res_nad["pnt_nad"].values, 0.01, rtol=1e-09, atol=1e-09))
assert np.all(np.isclose(res_nmad["pnt_nmad"].values, 0.01, rtol=1e-09, atol=1e-09))
assert np.all(res_nad["space"].values == np.array([2, 7]))
assert np.all(res_nmad["space"].values == np.array([2, 7]))


def test_network_stm_seletcion_wrong_csr():
stm = xr.Dataset(
data_vars={
"amplitude": (("space", "time"), np.ones((100, 10))),
"pnt_nad": (("space"), np.linspace(0, 1, 100)),
},
coords={"azimuth": (("space"), np.arange(100)), "range": (("space"), np.arange(100)), "time": np.arange(10)},
)
# catch not implemented method
with pytest.raises(NotImplementedError):
network_stm_seletcion(
stm, min_dist=20, sortby_var="pnt_nad", azimuth_spacing=10, range_spacing=10, crs="not_implemented"
)


def test_nad_block_zero_dispersion():
"""NAD for a constant array should be zero."""
slcs = xr.DataArray(
Expand Down

0 comments on commit 2cb55d4

Please sign in to comment.