Merge pull request #24 from MotionbyLearning/10_point_selection
10 point selection
rogerkuou authored Oct 21, 2024
2 parents 2f14db4 + f358f09 commit d769c42
Showing 4 changed files with 421 additions and 0 deletions.
125 changes: 125 additions & 0 deletions examples/scripts/script_ps_selection.py
@@ -0,0 +1,125 @@
"""Example script for selecting PS from SLCs.
This .py script is designed to be executed with a Dask SLURMCluster on a SLURM managed HPC system.
It should be executed through a SLURM script by `sbatch` command.
Please do not run this script by "python xxx.py" on a login node.
"""

import logging
import os
import socket
from pathlib import Path

import numpy as np
import sarxarray
import stmtools
import xarray as xr
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from matplotlib import pyplot as plt

from pydepsi.classification import ps_selection

# Make a logger to log the stages of processing
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ch = logging.StreamHandler() # create console handler
ch.setLevel(logging.INFO)
logger.addHandler(ch)


def get_free_port():
    """Get a non-occupied port number."""
    sock = socket.socket()
    sock.bind(("", 0))  # Bind to a port; it will be busy now
    freesock = sock.getsockname()[1]  # Get the port number
    sock.close()  # Free the port, so it can be used later
    return freesock


# ---- Config 1: Human Input ----


# Parameters
method = 'nmad' # Method for selection
threshold = 0.45 # Threshold for selection

# Input data paths
path_slc_zarr = Path("/project/caroline/slc_file.zarr") # Zarr file of all SLCs

# Output config
overwrite_zarr = False # Flag for zarr overwrite
chunk_space = 10000 # Output chunk size in space dimension
path_figure = Path("./figure") # Output path for figure
path_ps_zarr = Path("./ps.zarr") # output file for selected PS

path_figure.mkdir(exist_ok=True)  # Create the figure directory if it does not exist

# ---- Config 2: Dask configuration ----

# Option 1: Initiate a new SLURMCluster
# Uncomment the following part to set up a new Dask SLURMCluster
# N_WORKERS = 4  # Manual input: number of workers to spin up
# FREE_SOCKET = get_free_port()  # Get a free port
# cluster = SLURMCluster(
#     name="dask-worker",  # Name of the Slurm job
#     queue="normal",  # Name of the node partition on your SLURM system
#     cores=4,  # Number of cores per worker
#     memory="32 GB",  # Total amount of memory per worker
#     processes=1,  # Number of Python processes per worker
#     walltime="3:00:00",  # Reserve each worker for X hours
#     scheduler_options={"dashboard_address": f":{FREE_SOCKET}"},  # Host the dashboard on a free port
# )
# logger.info(f"Dask dashboard hosted at port: {FREE_SOCKET}.")
# logger.info(
#     f"If you are forwarding Jupyter Server to a local port 8889, \
# you can access it at: localhost:8889/proxy/{FREE_SOCKET}/status"
# )

# Option 2: Use an existing SLURMCluster by giving the scheduler address
# Uncomment the following part to use an existing Dask SLURMCluster
ADDRESS = "tcp://XX.X.X.XX:12345"  # Manual input: Dask scheduler address
SOCKET = 12345  # Manual input: port number. It should be the number after ":" in ADDRESS
cluster = None  # Keep this None, needed for an if statement
logger.info(f"Dask dashboard hosted at port: {SOCKET}.")
logger.info(
    f"If you are forwarding Jupyter Server to a local port 8889, \
you can access it at: localhost:8889/proxy/{SOCKET}/status"
)
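
# Tip (illustrative): if you started the cluster yourself in another Python
# session, the address above can be read from that session's cluster object
# instead of being typed by hand, e.g.:
#   print(cluster.scheduler_address)  # -> "tcp://10.0.0.5:42193" (example value)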

if __name__ == "__main__":
logger.info("Initializing ...")

if cluster is None:
# Use existing cluster
client = Client(ADDRESS)
else:
# Scale a certain number workers
# each worker will appear as a Slurm job
cluster.scale(jobs=N_WORKERS)
client = Client(cluster)

# Load the SLC data
logger.info("Loading data ...")
ds = xr.open_zarr(path_slc_zarr) # Load the zarr file as a xr.Dataset
# Construct SLCs from xr.Dataset
# this construct three datavariables: complex, amplitude, and phase
slcs = sarxarray.from_dataset(slcs)

# A rechunk might be needed to make a optimal usage of the resources
# Uncomment the following line to apply a rechunk after loading the data
# slcs = slcs.chunk({"azimuth":1000, "range":1000, "time":-1})

# Select PS
stm_ps = ps_selection(method, threshold, method='nmad', output_chunks=chunk_space)

# Re-order the PS to make the spatially adjacent PS in the same chunk
stm_ps_reordered = stm_ps.stm.reorder(xlabel='lon', ylabel='lat')

# Save the PS to zarr
if overwrite_zarr:
stm_ps_reordered.to_zarr(path_ps_zarr, mode="w")
else:
stm_ps_reordered.to_zarr(path_ps_zarr)

# Close the client when finishing
client.close()
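
    # ---- Appendix (illustrative): quick local sanity check ----
    # A minimal sketch for trying `ps_selection` on synthetic data without SLURM.
    # All values below (sizes, seed, threshold) are made up, and the block is
    # commented out on purpose so the SLURM workflow above stays unchanged.
    #
    # import numpy as np
    # import xarray as xr
    # from pydepsi.classification import ps_selection
    #
    # rng = np.random.default_rng(42)
    # toy = xr.Dataset(
    #     {"amplitude": (("azimuth", "range", "time"), rng.random((20, 10, 30)) + 0.1)},
    #     coords={"azimuth": np.arange(20), "range": np.arange(10), "time": np.arange(30)},
    # ).chunk({"azimuth": 10, "range": 10, "time": -1})
    # stm = ps_selection(toy, threshold=0.6, method="nmad", output_chunks=100)
    # print(stm.sizes)  # expect two dimensions: ("space", "time")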
172 changes: 172 additions & 0 deletions pydepsi/classification.py
@@ -0,0 +1,172 @@
"""Functions for scatterer selection related operations."""

from typing import Literal

import numpy as np
import xarray as xr


def ps_selection(
    slcs: xr.Dataset,
    threshold: float,
    method: Literal["nad", "nmad"] = "nad",
    output_chunks: int = 10000,
    mem_persist: bool = False,
) -> xr.Dataset:
    """Select Persistent Scatterers (PS) from an SLC stack, and return a Space-Time Matrix.

    The selection method is defined by `method` and `threshold`.
    The selected pixels will be reshaped to (space, time), where `space` is the number of selected pixels.
    The unselected pixels will be discarded.
    The original `azimuth` and `range` coordinates will be persisted.
    The computed NAD or NMAD will be added to the output dataset as a new variable. It can be persisted in
    memory if `mem_persist` is True.

    Parameters
    ----------
    slcs : xr.Dataset
        Input SLC stack. It should have the following dimensions: ("azimuth", "range", "time").
        There should be an `amplitude` variable in the dataset.
    threshold : float
        Threshold value for selection.
    method : Literal["nad", "nmad"], optional
        Method of selection, by default "nad".
        - "nad": Normalized Amplitude Dispersion
        - "nmad": Normalized Median Absolute Deviation
    output_chunks : int, optional
        Chunk size in the `space` dimension, by default 10000.
    mem_persist : bool, optional
        If true, persist the NAD or NMAD in memory, by default False.

    Returns
    -------
    xr.Dataset
        Selected STM, in the form of an xarray.Dataset with two dimensions: (space, time).

    Raises
    ------
    NotImplementedError
        Raised when an unsupported method is provided.
    """
    # Make sure there is no chunking in the time dimension,
    # since the block functions below assume all temporal data is available in a spatial block
    slcs = slcs.chunk({"time": -1})

    # Calculate the selection mask
    match method:
        case "nad":
            nad = xr.map_blocks(
                _nad_block, slcs["amplitude"], template=slcs["amplitude"].isel(time=0).drop_vars("time")
            )
            nad = nad.compute() if mem_persist else nad
            slcs = slcs.assign(pnt_nad=nad)
            mask = nad < threshold
        case "nmad":
            nmad = xr.map_blocks(
                _nmad_block, slcs["amplitude"], template=slcs["amplitude"].isel(time=0).drop_vars("time")
            )
            nmad = nmad.compute() if mem_persist else nmad
            slcs = slcs.assign(pnt_nmad=nmad)
            mask = nmad < threshold
        case _:
            raise NotImplementedError

    # Get the 1D index on the space dimension
    mask_1d = mask.stack(space=("azimuth", "range")).drop_vars(["azimuth", "range", "space"])  # Drop multi-index coords
    index = mask_1d["space"].where(mask_1d.compute(), other=0, drop=True)  # Evaluate the 1D mask to an index

    # Reshape from Stack ("azimuth", "range", "time") to Space-Time Matrix ("space", "time")
    stacked = slcs.stack(space=("azimuth", "range"))

    # Drop multi-index coords for the space dimension
    # This will also drop the azimuth and range coordinates, as they are part of the multi-index coordinates
    stm = stacked.drop_vars(["space", "azimuth", "range"])

    # Assign a continuous index along the space dimension
    # Assign azimuth and range back as coordinates
    stm = stm.assign_coords(
        {
            "space": (["space"], range(stm.sizes["space"])),
            "azimuth": (["space"], stacked["azimuth"].values),
            "range": (["space"], stacked["range"].values),
        }
    )  # keep azimuth and range as coordinates

    # Apply the selection
    stm_masked = stm.sel(space=index)

    # Re-order the dimensions to the community-preferred ("space", "time") order
    stm_masked = stm_masked.transpose("space", "time")

    # A rechunk is needed because the chunk sizes become inconsistent after masking
    stm_masked = stm_masked.chunk(
        {
            "space": output_chunks,
            "time": -1,
        }
    )

    # Reset the space coordinates
    stm_masked = stm_masked.assign_coords(
        {
            "space": (["space"], range(stm_masked.sizes["space"])),
        }
    )

    # Compute NAD or NMAD if mem_persist is True
    # This only evaluates a very short task graph, since NAD or NMAD is already in memory
    if mem_persist:
        match method:
            case "nad":
                stm_masked["pnt_nad"] = stm_masked["pnt_nad"].compute()
            case "nmad":
                stm_masked["pnt_nmad"] = stm_masked["pnt_nmad"].compute()

    return stm_masked
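
# For intuition: the reshape used above, in isolation (a standalone sketch
# with made-up sizes, not part of the selection API):
#
#     da = xr.DataArray(np.arange(6).reshape(2, 3), dims=("azimuth", "range"))
#     flat = da.stack(space=("azimuth", "range"))  # dims ("space",), size 6
#     subset = flat.isel(space=[0, 2, 5])  # keep only the selected pixels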


def _nad_block(amp: xr.DataArray) -> xr.DataArray:
"""Compute Normalized Amplitude Dispersion (NAD) for a block of amplitude data.
Parameters
----------
amp : xr.DataArray
Amplitude data, with dimensions ("azimuth", "range", "time").
This can be extracted from an SLC xr.Dataset.
Returns
-------
xr.DataArray
Normalized Amplitude Dispersion (NAD) data, with dimensions ("azimuth", "range").
"""
# Compute amplitude dispersion
# By defalut, the mean and std function from Xarray will skip NaN values
# However, if there is NaN value in time series, we want to discard the pixel
# Therefore, we set skipna=False
# Adding epsilon to avoid zero division
nad_da = amp.std(dim="time", skipna=False) / (amp.mean(dim="time", skipna=False) + np.finfo(amp.dtype).eps)

return nad_da
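
# Worked example (illustrative): for a pixel with amplitudes [1, 2, 3] over time,
# mean = 2 and std = sqrt(2/3) ≈ 0.816 (xarray uses ddof=0 by default),
# so NAD ≈ 0.816 / 2 ≈ 0.408 — below a typical threshold such as 0.45.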


def _nmad_block(amp: xr.DataArray) -> xr.DataArray:
"""Compute Normalized Median Absolute Deviation(NMAD) for a block of amplitude data.
Parameters
----------
amp : xr.DataArray
Amplitude data, with dimensions ("azimuth", "range", "time").
This can be extracted from an SLC xr.Dataset.
Returns
-------
xr.DataArray
Normalized Median Absolute Dispersion (NMAD) data, with dimensions ("azimuth", "range").
"""
# Compoute NMAD
median_amplitude = amp.median(dim="time", skipna=False)
mad = (np.abs(amp - median_amplitude)).median(dim="time") # Median Absolute Deviation
nmad = mad / (median_amplitude + np.finfo(amp.dtype).eps) # Normalized Median Absolute Deviation

return nmad
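
# Worked example (illustrative): for amplitudes [1, 2, 3] over time,
# median = 2, absolute deviations = [1, 0, 1], MAD = median([1, 0, 1]) = 1,
# so NMAD = 1 / 2 = 0.5 (the epsilon term is negligible at this scale).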
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ dev = [
"pycodestyle",
"pre-commit",
"ruff",
"graphviz",
]
docs = [
"mkdocs",