From 39b5a835751d9373672aa6c0c54e7af175e3c8a4 Mon Sep 17 00:00:00 2001 From: Ou Ku Date: Mon, 21 Oct 2024 17:04:12 +0200 Subject: [PATCH] more unit test --- tests/test_classification.py | 93 +++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/tests/test_classification.py b/tests/test_classification.py index a39613d..be405b7 100644 --- a/tests/test_classification.py +++ b/tests/test_classification.py @@ -5,7 +5,7 @@ import pytest import xarray as xr -from pydepsi.classification import _idx_within_distance, _nad_block, _nmad_block, ps_selection +from pydepsi.classification import _idx_within_distance, _nad_block, _nmad_block, network_stm_seletcion, ps_selection # Create a random number generator rng = np.random.default_rng(42) @@ -73,6 +73,97 @@ def test_ps_seletion_not_implemented(): ps_selection(slcs, 0.5, method="not_implemented", output_chunks=5) +def test_network_stm_seletcion_results(): + stm = xr.Dataset( + data_vars={ + "amplitude": (("space", "time"), np.ones((100, 10))), + "pnt_nad": (("space"), np.linspace(0, 1, 100)), + "pnt_nmad": (("space"), np.linspace(0, 1, 100)), + }, + coords={"azimuth": (("space"), np.arange(100)), "range": (("space"), np.arange(100)), "time": np.arange(10)}, + ) + res_nad = network_stm_seletcion(stm, min_dist=20, sortby_var="pnt_nad", azimuth_spacing=10, range_spacing=10) + res_nmad = network_stm_seletcion(stm, min_dist=20, sortby_var="pnt_nmad", azimuth_spacing=10, range_spacing=10) + # Fields should remain the same + assert "pnt_nad" in res_nad + assert "azimuth" in res_nad + assert "range" in res_nad + assert "space" in res_nad.dims + assert "time" in res_nad.dims + # Dimensions should be half + assert res_nad.sizes["space"] == 50 + assert res_nad.sizes["time"] == 10 + assert res_nmad.sizes["space"] == 50 + assert res_nmad.sizes["time"] == 10 + + +def test_network_stm_seletcion_quality(): + stm = xr.Dataset( + data_vars={ + "amplitude": (("space", "time"), np.ones((5, 10))), + "pnt_nad": (("space"), np.array([0.9, 0.01, 0.9, 0.9, 0.01])), + "pnt_nmad": (("space"), np.array([0.01, 0.9, 0.9, 0.9, 0.01])), + }, + coords={ + "space": np.array([3, 1, 2, 5, 7]), # non monotonic space coords + "time": np.arange(10), + "azimuth": (("space"), np.arange(5)), + "range": (("space"), np.arange(5)), + }, + ) + res_nad = network_stm_seletcion(stm, min_dist=3, sortby_var="pnt_nad", azimuth_spacing=1, range_spacing=1) + res_nmad = network_stm_seletcion(stm, min_dist=3, sortby_var="pnt_nmad", azimuth_spacing=1, range_spacing=1) + + # The two pixels with the lowest NAD should be selected + assert np.all(np.isclose(res_nad["pnt_nad"].values, 0.01, rtol=1e-09, atol=1e-09)) + assert np.all(np.isclose(res_nmad["pnt_nmad"].values, 0.01, rtol=1e-09, atol=1e-09)) + assert np.all(res_nad["space"].values == np.array([1, 7])) + assert np.all(res_nmad["space"].values == np.array([3, 7])) + + +def test_network_stm_seletcion_include_index(): + stm = xr.Dataset( + data_vars={ + "amplitude": (("space", "time"), np.ones((5, 10))), + "pnt_nad": (("space"), np.array([0.01, 0.01, 0.9, 0.9, 0.01])), + "pnt_nmad": (("space"), np.array([0.01, 0.01, 0.9, 0.9, 0.01])), + }, + coords={ + "space": np.array([1, 2, 5, 6, 7]), # non monotonic space coords + "time": np.arange(10), + "azimuth": (("space"), np.arange(5)), + "range": (("space"), np.arange(5)), + }, + ) + res_nad = network_stm_seletcion( + stm, min_dist=3, include_index=[1], sortby_var="pnt_nad", azimuth_spacing=1, range_spacing=1 + ) + res_nmad = network_stm_seletcion( + stm, min_dist=3, include_index=[1], sortby_var="pnt_nmad", azimuth_spacing=1, range_spacing=1 + ) + + # The two pixels with the lowest NAD should be selected + assert np.all(np.isclose(res_nad["pnt_nad"].values, 0.01, rtol=1e-09, atol=1e-09)) + assert np.all(np.isclose(res_nmad["pnt_nmad"].values, 0.01, rtol=1e-09, atol=1e-09)) + assert np.all(res_nad["space"].values == np.array([2, 7])) + assert np.all(res_nmad["space"].values == np.array([2, 7])) + + +def test_network_stm_seletcion_wrong_csr(): + stm = xr.Dataset( + data_vars={ + "amplitude": (("space", "time"), np.ones((100, 10))), + "pnt_nad": (("space"), np.linspace(0, 1, 100)), + }, + coords={"azimuth": (("space"), np.arange(100)), "range": (("space"), np.arange(100)), "time": np.arange(10)}, + ) + # catch not implemented method + with pytest.raises(NotImplementedError): + network_stm_seletcion( + stm, min_dist=20, sortby_var="pnt_nad", azimuth_spacing=10, range_spacing=10, crs="not_implemented" + ) + + def test_nad_block_zero_dispersion(): """NAD for a constant array should be zero.""" slcs = xr.DataArray(