diff --git a/examples/scripts/script_ps_selection.py b/examples/scripts/script_ps_selection.py
new file mode 100644
index 0000000..90a5d45
--- /dev/null
+++ b/examples/scripts/script_ps_selection.py
@@ -0,0 +1,125 @@
"""Example script for selecting PS from SLCs.

This .py script is designed to be executed with a Dask SLURMCluster on a SLURM-managed HPC system.
It should be submitted through a SLURM script with the `sbatch` command.
Please do not run this script with "python xxx.py" on a login node.
"""

import logging
import os
import socket
import xarray as xr
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
import sarxarray
import stmtools

from pydepsi.classification import ps_selection

# Make a logger to log the stages of processing
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()  # Create a console handler
ch.setLevel(logging.INFO)
logger.addHandler(ch)


def get_free_port():
    """Get a non-occupied port number."""
    sock = socket.socket()
    sock.bind(("", 0))  # Bind a port; it will be busy now
    freesock = sock.getsockname()[1]  # Get the port number
    sock.close()  # Free the port, so it can be used later
    return freesock


# ---- Config 1: Human Input ----

# Parameters
method = "nmad"  # Method for selection
threshold = 0.45  # Threshold for selection

# Input data paths
path_slc_zarr = Path("/project/caroline/slc_file.zarr")  # Zarr file of all SLCs

# Output config
overwrite_zarr = False  # Flag for Zarr overwrite
chunk_space = 10000  # Output chunk size in the space dimension
path_figure = Path("./figure")  # Output path for figures
path_ps_zarr = Path("./ps.zarr")  # Output file for selected PS

path_figure.mkdir(exist_ok=True)  # Make the figure directory if it does not exist

# ---- Config 2: Dask configuration ----

# Option 1: Initiate a new SLURMCluster
# Uncomment the following part to set up a new Dask SLURMCluster
# N_WORKERS = 4  # Manual input: number of workers to spin up
# FREE_SOCKET = get_free_port()  # Get a free port
# cluster = SLURMCluster(
#     name="dask-worker",  # Name of the Slurm job
#     queue="normal",  # Name of the node partition on your SLURM system
#     cores=4,  # Number of cores per worker
#     memory="32 GB",  # Total amount of memory per worker
#     processes=1,  # Number of Python processes per worker
#     walltime="3:00:00",  # Reserve each worker for X hours
#     scheduler_options={"dashboard_address": f":{FREE_SOCKET}"},  # Host the dashboard on a free port
# )
# logger.info(f"Dask dashboard hosted at port: {FREE_SOCKET}.")
# logger.info(
#     f"If you are forwarding Jupyter Server to a local port 8889, \
#     you can access it at: localhost:8889/proxy/{FREE_SOCKET}/status"
# )

# Option 2: Use an existing SLURMCluster by giving the scheduler address
# Uncomment the following part to use an existing Dask SLURMCluster
ADDRESS = "tcp://XX.X.X.XX:12345"  # Manual input: Dask scheduler address
SOCKET = 12345  # Manual input: port number; the number after ":" in ADDRESS
cluster = None  # Keep this None; needed for an if statement
logger.info(f"Dask dashboard hosted at port: {SOCKET}.")
logger.info(
    f"If you are forwarding Jupyter Server to a local port 8889, \
    you can access it at: localhost:8889/proxy/{SOCKET}/status"
)
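
# Optional convenience (a sketch, not part of the original workflow): assuming
# ADDRESS always ends in ":<port>", SOCKET could be parsed instead of typed twice:
# SOCKET = int(ADDRESS.rsplit(":", 1)[-1])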

if __name__ == "__main__":
    logger.info("Initializing ...")

    if cluster is None:
        # Use an existing cluster
        client = Client(ADDRESS)
    else:
        # Scale to a certain number of workers
        # Each worker will appear as a Slurm job
        cluster.scale(jobs=N_WORKERS)
        client = Client(cluster)

    # Load the SLC data
    logger.info("Loading data ...")
    ds = xr.open_zarr(path_slc_zarr)  # Load the Zarr file as an xr.Dataset
    # Construct SLCs from the xr.Dataset
    # This constructs three data variables: complex, amplitude, and phase
    slcs = sarxarray.from_dataset(ds)

    # A rechunk might be needed to make optimal usage of the resources
    # Uncomment the following line to apply a rechunk after loading the data
    # slcs = slcs.chunk({"azimuth": 1000, "range": 1000, "time": -1})

    # Select PS
    stm_ps = ps_selection(slcs, threshold, method=method, output_chunks=chunk_space)

    # Re-order the PS so that spatially adjacent PS end up in the same chunk
    stm_ps_reordered = stm_ps.stm.reorder(xlabel="lon", ylabel="lat")

    # Save the PS to Zarr
    if overwrite_zarr:
        stm_ps_reordered.to_zarr(path_ps_zarr, mode="w")
    else:
        stm_ps_reordered.to_zarr(path_ps_zarr)

    # Close the client when finished
    client.close()
diff --git a/pydepsi/classification.py b/pydepsi/classification.py
new file mode 100644
index 0000000..14592e8
--- /dev/null
+++ b/pydepsi/classification.py
@@ -0,0 +1,172 @@
"""Functions for scatterer selection related operations."""

from typing import Literal

import numpy as np
import xarray as xr


def ps_selection(
    slcs: xr.Dataset,
    threshold: float,
    method: Literal["nad", "nmad"] = "nad",
    output_chunks: int = 10000,
    mem_persist: bool = False,
) -> xr.Dataset:
    """Select Persistent Scatterers (PS) from an SLC stack, and return a Space-Time Matrix.

    The selection method is defined by `method` and `threshold`.
    The selected pixels will be reshaped to (space, time), where `space` is the number of selected pixels.
    The unselected pixels will be discarded.
    The original `azimuth` and `range` coordinates will be persisted.
    The computed NAD or NMAD will be added to the output dataset as a new variable. It can be persisted in
    memory if `mem_persist` is True.

    Parameters
    ----------
    slcs : xr.Dataset
        Input SLC stack. It should have the following dimensions: ("azimuth", "range", "time").
        There should be an `amplitude` variable in the dataset.
    threshold : float
        Threshold value for selection.
    method : Literal["nad", "nmad"], optional
        Method of selection, by default "nad".
        - "nad": Normalized Amplitude Dispersion
        - "nmad": Normalized Median Absolute Deviation
    output_chunks : int, optional
        Chunk size in the `space` dimension, by default 10000.
    mem_persist : bool, optional
        If True, persist the NAD or NMAD in memory, by default False.

    Returns
    -------
    xr.Dataset
        Selected STM, in the form of an xarray.Dataset with two dimensions: (space, time).

    Raises
    ------
    NotImplementedError
        Raised when an unsupported method is provided.
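
    Examples
    --------
    A minimal sketch on synthetic data (the all-ones amplitudes here are
    illustrative, not a real SLC stack):

    >>> import numpy as np
    >>> import xarray as xr
    >>> slcs = xr.Dataset(
    ...     data_vars={"amplitude": (("azimuth", "range", "time"), np.ones((10, 10, 10)))},
    ...     coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)},
    ... )
    >>> stm = ps_selection(slcs, threshold=0.5, method="nmad", output_chunks=5)
    >>> sorted(stm.dims)
    ['space', 'time']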
+ """ + # Make sure there is no temporal chunk + # since later a block function assumes all temporal data is available in a spatial block + slcs = slcs.chunk({"time": -1}) + + # Calculate selection mask + match method: + case "nad": + nad = xr.map_blocks( + _nad_block, slcs["amplitude"], template=slcs["amplitude"].isel(time=0).drop_vars("time") + ) + nad = nad.compute() if mem_persist else nad + slcs = slcs.assign(pnt_nad=nad) + mask = nad < threshold + case "nmad": + nmad = xr.map_blocks( + _nmad_block, slcs["amplitude"], template=slcs["amplitude"].isel(time=0).drop_vars("time") + ) + nmad = nmad.compute() if mem_persist else nmad + slcs = slcs.assign(pnt_nmad=nmad) + mask = nmad < threshold + case _: + raise NotImplementedError + + # Get the 1D index on space dimension + mask_1d = mask.stack(space=("azimuth", "range")).drop_vars(["azimuth", "range", "space"]) # Drop multi-index coords + index = mask_1d["space"].where(mask_1d.compute(), other=0, drop=True) # Evaluate the 1D mask to index + + # Reshape from Stack ("azimuth", "range", "time") to Space-Time Matrix ("space", "time") + stacked = slcs.stack(space=("azimuth", "range")) + + # Drop multi-index coords for space coordinates + # This will also azimuth and range coordinates, as they are part of the multi-index coordinates + stm = stacked.drop_vars(["space", "azimuth", "range"]) + + # Assign a continuous index the space dimension + # Assign azimuth and range back as coordinates + stm = stm.assign_coords( + { + "space": (["space"], range(stm.sizes["space"])), + "azimuth": (["space"], stacked["azimuth"].values), + "range": (["space"], stacked["range"].values), + } + ) # keep azimuth and range as coordinates + + # Apply selection + stm_masked = stm.sel(space=index) + + # Re-order the dimensions to community preferred ("space", "time") order + stm_masked = stm_masked.transpose("space", "time") + + # Rechunk is needed because after apply maksing, the chunksize will be inconsistant + stm_masked = stm_masked.chunk( + { + "space": output_chunks, + "time": -1, + } + ) + + # Reset space coordinates + stm_masked = stm_masked.assign_coords( + { + "space": (["space"], range(stm_masked.sizes["space"])), + } + ) + + # Compute NAD or NMAD if mem_persist is True + # This only evaluate a very short task graph, since NAD or NMAD is already in memory + if mem_persist: + match method: + case "nad": + stm_masked["pnt_nad"] = stm_masked["pnt_nad"].compute() + case "nmad": + stm_masked["pnt_nmad"] = stm_masked["pnt_nmad"].compute() + + return stm_masked + + +def _nad_block(amp: xr.DataArray) -> xr.DataArray: + """Compute Normalized Amplitude Dispersion (NAD) for a block of amplitude data. + + Parameters + ---------- + amp : xr.DataArray + Amplitude data, with dimensions ("azimuth", "range", "time"). + This can be extracted from an SLC xr.Dataset. + + Returns + ------- + xr.DataArray + Normalized Amplitude Dispersion (NAD) data, with dimensions ("azimuth", "range"). + """ + # Compute amplitude dispersion + # By defalut, the mean and std function from Xarray will skip NaN values + # However, if there is NaN value in time series, we want to discard the pixel + # Therefore, we set skipna=False + # Adding epsilon to avoid zero division + nad_da = amp.std(dim="time", skipna=False) / (amp.mean(dim="time", skipna=False) + np.finfo(amp.dtype).eps) + + return nad_da + + +def _nmad_block(amp: xr.DataArray) -> xr.DataArray: + """Compute Normalized Median Absolute Deviation(NMAD) for a block of amplitude data. 


def _nmad_block(amp: xr.DataArray) -> xr.DataArray:
    """Compute the Normalized Median Absolute Deviation (NMAD) for a block of amplitude data.

    Parameters
    ----------
    amp : xr.DataArray
        Amplitude data, with dimensions ("azimuth", "range", "time").
        This can be extracted from an SLC xr.Dataset.

    Returns
    -------
    xr.DataArray
        Normalized Median Absolute Deviation (NMAD) data, with dimensions ("azimuth", "range").
    """
    # Compute the NMAD
    # skipna=False is set so that a NaN anywhere in a time series discards the pixel,
    # consistent with _nad_block
    median_amplitude = amp.median(dim="time", skipna=False)
    mad = (np.abs(amp - median_amplitude)).median(dim="time", skipna=False)  # Median Absolute Deviation
    nmad = mad / (median_amplitude + np.finfo(amp.dtype).eps)  # Normalized Median Absolute Deviation

    return nmad
diff --git a/pyproject.toml b/pyproject.toml
index df9f88a..f8a7e04 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ dev = [
     "pycodestyle",
     "pre-commit",
     "ruff",
+    "graphviz",
 ]
 docs = [
     "mkdocs",
diff --git a/tests/test_classification.py b/tests/test_classification.py
new file mode 100644
index 0000000..076d0ae
--- /dev/null
+++ b/tests/test_classification.py
@@ -0,0 +1,123 @@
"""Tests for pydepsi.classification."""

import dask.array as da
import numpy as np
import pytest
import xarray as xr

from pydepsi.classification import _nad_block, _nmad_block, ps_selection

# Create a random number generator
rng = np.random.default_rng(42)


def test_ps_selection_nad():
    slcs = xr.Dataset(
        data_vars={"amplitude": (("azimuth", "range", "time"), np.ones((10, 10, 10)))},
        coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)},
    )
    res = ps_selection(slcs, 0.5, method="nad", output_chunks=5)
    assert res.sizes["time"] == 10
    assert res.sizes["space"] == 100
    assert "pnt_nad" in res
    assert "azimuth" in res
    assert "range" in res
    assert "space" in res.dims
    assert "time" in res.dims
    assert isinstance(res["pnt_nad"].data, da.core.Array)


def test_ps_selection_nmad():
    slcs = xr.Dataset(
        data_vars={"amplitude": (("azimuth", "range", "time"), np.ones((10, 10, 10)))},
        coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)},
    )
    res = ps_selection(slcs, 0.5, method="nmad", output_chunks=5)
    assert res.sizes["time"] == 10
    assert res.sizes["space"] == 100
    assert "pnt_nmad" in res
    assert "azimuth" in res
    assert "range" in res
    assert "space" in res.dims
    assert "time" in res.dims
    assert isinstance(res["pnt_nmad"].data, da.core.Array)


def test_ps_selection_nad_mempersist():
    """When mem_persist=True, results should be a numpy array."""
    slcs = xr.Dataset(
        data_vars={"amplitude": (("azimuth", "range", "time"), np.ones((10, 10, 10)))},
        coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)},
    )
    res = ps_selection(slcs, 0.5, method="nad", output_chunks=5, mem_persist=True)
    assert isinstance(res["pnt_nad"].data, np.ndarray)


def test_ps_selection_nmad_mempersist():
    """When mem_persist=True, results should be a numpy array."""
    slcs = xr.Dataset(
        data_vars={"amplitude": (("azimuth", "range", "time"), np.ones((10, 10, 10)))},
        coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)},
    )
    res = ps_selection(slcs, 0.5, method="nmad", output_chunks=5, mem_persist=True)
    assert isinstance(res["pnt_nmad"].data, np.ndarray)


def test_ps_selection_not_implemented():
    slcs = xr.Dataset(
        data_vars={"amplitude": (("azimuth", "range", "time"), np.ones((10, 10, 10)))},
        coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)},
    )
    # Catch the not-implemented method
    with pytest.raises(NotImplementedError):
        ps_selection(slcs, 0.5, method="not_implemented", output_chunks=5)
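

# The synthetic datasets above are built inline for clarity. A small factory like
# the sketch below could reduce the duplication; it is hypothetical and not used
# by the tests:
#
# def _ones_slcs(n=10):
#     """Hypothetical helper: an all-ones amplitude stack of shape (n, n, n)."""
#     return xr.Dataset(
#         data_vars={"amplitude": (("azimuth", "range", "time"), np.ones((n, n, n)))},
#         coords={"azimuth": np.arange(n), "range": np.arange(n), "time": np.arange(n)},
#     )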
method="not_implemented", output_chunks=5) + + +def test_nad_block_zero_dispersion(): + """NAD for a constant array should be zero.""" + slcs = xr.DataArray( + data=np.ones((10, 10, 10)), + dims=("azimuth", "range", "time"), + coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)}, + ) + res = _nad_block(slcs) + assert res.shape == (10, 10) + assert np.all(res == 0) + + +def test_nmad_block_zero_dispersion(): + """NMAD for a constant array should be zero.""" + slcs = xr.DataArray( + data=np.ones((10, 10, 10)), + dims=("azimuth", "range", "time"), + coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)}, + ) + res = _nmad_block(slcs) + assert res.shape == (10, 10) + assert np.all(res == 0) + + +def test_nad_block_select_two(): + """Should select two pixels with zero dispersion.""" + amp = rng.random((10, 10, 10)) # Random amplitude data + amp[0, 0:2, :] = 1.0 # Two pixels with constant amplitude + slcs = xr.Dataset( + data_vars={"amplitude": (("azimuth", "range", "time"), amp)}, + coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)}, + ) + res = ps_selection(slcs, 1e-10, method="nad", output_chunks=5) # Select pixels with dispersion lower than 0.00001 + assert res.sizes["time"] == 10 + assert res.sizes["space"] == 2 + + +def test_nmad_block_select_two(): + """Should select two pixels with zero dispersion.""" + amp = rng.random((10, 10, 10)) # Random amplitude data + amp[0, 0:2, :] = 1.0 # Two pixels with constant amplitude + slcs = xr.Dataset( + data_vars={"amplitude": (("azimuth", "range", "time"), amp)}, + coords={"azimuth": np.arange(10), "range": np.arange(10), "time": np.arange(10)}, + ) + res = ps_selection(slcs, 1e-10, method="nmad", output_chunks=5) # Select pixels with dispersion lower than 0.00001 + assert res.sizes["time"] == 10 + assert res.sizes["space"] == 2