-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24 from MotionbyLearning/10_point_selection
10 point selection
- Loading branch information
Showing
4 changed files
with
421 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
"""Example script for selecting PS from SLCs. | ||
This .py script is designed to be executed with a Dask SLURMCluster on a SLURM managed HPC system. | ||
It should be executed through a SLURM script by `sbatch` command. | ||
Please do not run this script by "python xxx.py" on a login node. | ||
""" | ||
|
||
import logging | ||
import os | ||
import socket | ||
import xarray as xr | ||
from pathlib import Path | ||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
from dask.distributed import Client | ||
from dask_jobqueue import SLURMCluster | ||
import sarxarray | ||
import stmtools | ||
|
||
from pydepsi.classification import ps_selection | ||
|
||
# Console logging for the processing stages.
# The handler is prepared first, then attached to the module-level logger;
# both are set to INFO so stage messages are emitted.
ch = logging.StreamHandler()  # console handler
ch.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(ch)
|
||
|
||
def get_free_port():
    """Get a non-occupied port number.

    A socket is bound to port 0 so the operating system picks a free port.
    The socket is closed again before returning, so the port can be used later.

    Returns
    -------
    int
        The port number that was assigned to the temporary socket.
    """
    # The context manager guarantees the socket is closed even if bind() or
    # getsockname() raises, so no socket is leaked on failure.
    with socket.socket() as sock:
        sock.bind(("", 0))  # Bind a port, it will be busy now
        return sock.getsockname()[1]  # The port number; socket is closed on exit
|
||
|
||
# ---- Config 1: Human Input ----

# Selection parameters
threshold = 0.45  # Threshold for selection
method = 'nmad'  # Method for selection

# Input data paths
path_slc_zarr = Path("/project/caroline/slc_file.zarr")  # Zarr file of all SLCs

# Output config
path_ps_zarr = Path("./ps.zarr")  # Output file for selected PS
path_figure = Path("./figure")  # Output path for figure
chunk_space = 10000  # Output chunk size in space dimension
overwrite_zarr = False  # Flag for zarr overwrite

# Make sure the figure directory exists (no error if it is already there)
path_figure.mkdir(exist_ok=True)
|
||
# ---- Config 2: Dask configuration ----

# Option 1: Initiate a new SLURMCluster
# Uncomment the following part to setup a new Dask SLURMCluster
# N_WORKERS = 4  # Manual input: number of workers to spin-up
# FREE_SOCKET = get_free_port()  # Get a free port
# cluster = SLURMCluster(
#     name="dask-worker",  # Name of the Slurm job
#     queue="normal",  # Name of the node partition on your SLURM system
#     cores=4,  # Number of cores per worker
#     memory="32 GB",  # Total amount of memory per worker
#     processes=1,  # Number of Python processes per worker
#     walltime="3:00:00",  # Reserve each worker for X hour
#     scheduler_options={"dashboard_address": f":{FREE_SOCKET}"},  # Host Dashboard in a free socket
# )
# logger.info(f"Dask dashboard hosted at port: {FREE_SOCKET}.")
# logger.info(
#     f"If you are forwarding Jupyter Server to a local port 8889, \
#     you can access it at: localhost:8889/proxy/{FREE_SOCKET}/status"
# )

# Option 2: Use an existing SLURMCluster by giving the scheduler address
# Uncomment the following part to use an existing Dask SLURMCluster
cluster = None  # Keep this None: the main section checks it to decide how to connect
ADDRESS = "tcp://XX.X.X.XX:12345"  # Manual input: Dask scheduler address
SOCKET = 12345  # Manual input: port number. It should be the number after ":" of ADDRESS
logger.info(f"Dask dashboard hosted at port: {SOCKET}.")
logger.info(
    f"If you are forwarding Jupyter Server to a local port 8889, \
you can access it at: localhost:8889/proxy/{SOCKET}/status"
)
|
||
if __name__ == "__main__":
    logger.info("Initializing ...")

    if cluster is None:
        # Use existing cluster
        client = Client(ADDRESS)
    else:
        # Scale a certain number workers
        # each worker will appear as a Slurm job
        cluster.scale(jobs=N_WORKERS)
        client = Client(cluster)

    # Run the processing inside try/finally so the client is always closed,
    # also when one of the steps below raises.
    try:
        # Load the SLC data
        logger.info("Loading data ...")
        ds = xr.open_zarr(path_slc_zarr)  # Load the zarr file as a xr.Dataset
        # Construct SLCs from xr.Dataset
        # this construct three datavariables: complex, amplitude, and phase
        # NOTE: the loaded dataset `ds` is the input here; `slcs` does not exist yet
        slcs = sarxarray.from_dataset(ds)

        # A rechunk might be needed to make a optimal usage of the resources
        # Uncomment the following line to apply a rechunk after loading the data
        # slcs = slcs.chunk({"azimuth":1000, "range":1000, "time":-1})

        # Select PS
        # The SLC stack is the first argument; the method and threshold come from Config 1
        stm_ps = ps_selection(slcs, threshold, method=method, output_chunks=chunk_space)

        # Re-order the PS to make the spatially adjacent PS in the same chunk
        stm_ps_reordered = stm_ps.stm.reorder(xlabel='lon', ylabel='lat')

        # Save the PS to zarr
        if overwrite_zarr:
            stm_ps_reordered.to_zarr(path_ps_zarr, mode="w")
        else:
            stm_ps_reordered.to_zarr(path_ps_zarr)
    finally:
        # Close the client when finishing
        client.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
"""Functions for scatterer selection related operations.""" | ||
|
||
from typing import Literal | ||
|
||
import numpy as np | ||
import xarray as xr | ||
|
||
|
||
def ps_selection( | ||
slcs: xr.Dataset, | ||
threshold: float, | ||
method: Literal["nad", "nmad"] = "nad", | ||
output_chunks: int = 10000, | ||
mem_persist: bool = False, | ||
) -> xr.Dataset: | ||
"""Select Persistent Scatterers (PS) from an SLC stack, and return a Space-Time Matrix. | ||
The selection method is defined by `method` and `threshold`. | ||
The selected pixels will be reshaped to (space, time), where `space` is the number of selected pixels. | ||
The unselected pixels will be discarded. | ||
The original `azimuth` and `range` coordinates will be persisted. | ||
The computed NAD or NMAD will be added to the output dataset as a new variable. It can be persisted in | ||
memory if `mem_persist` is True. | ||
Parameters | ||
---------- | ||
slcs : xr.Dataset | ||
Input SLC stack. It should have the following dimensions: ("azimuth", "range", "time"). | ||
There should be a `amplitude` variable in the dataset. | ||
threshold : float | ||
Threshold value for selection. | ||
method : Literal["nad", "nmad"], optional | ||
Method of selection, by default "nad". | ||
- "nad": Normalized Amplitude Dispersion | ||
- "nmad": Normalized median absolute deviation | ||
output_chunks : int, optional | ||
Chunk size in the `space` dimension, by default 10000 | ||
mem_persist : bool, optional | ||
If true persist the NAD or NMAD in memory, by default False. | ||
Returns | ||
------- | ||
xr.Dataset | ||
Selected STM, in form of an xarray.Dataset with two dimensions: (space, time). | ||
Raises | ||
------ | ||
NotImplementedError | ||
Raised when an unsupported method is provided. | ||
""" | ||
# Make sure there is no temporal chunk | ||
# since later a block function assumes all temporal data is available in a spatial block | ||
slcs = slcs.chunk({"time": -1}) | ||
|
||
# Calculate selection mask | ||
match method: | ||
case "nad": | ||
nad = xr.map_blocks( | ||
_nad_block, slcs["amplitude"], template=slcs["amplitude"].isel(time=0).drop_vars("time") | ||
) | ||
nad = nad.compute() if mem_persist else nad | ||
slcs = slcs.assign(pnt_nad=nad) | ||
mask = nad < threshold | ||
case "nmad": | ||
nmad = xr.map_blocks( | ||
_nmad_block, slcs["amplitude"], template=slcs["amplitude"].isel(time=0).drop_vars("time") | ||
) | ||
nmad = nmad.compute() if mem_persist else nmad | ||
slcs = slcs.assign(pnt_nmad=nmad) | ||
mask = nmad < threshold | ||
case _: | ||
raise NotImplementedError | ||
|
||
# Get the 1D index on space dimension | ||
mask_1d = mask.stack(space=("azimuth", "range")).drop_vars(["azimuth", "range", "space"]) # Drop multi-index coords | ||
index = mask_1d["space"].where(mask_1d.compute(), other=0, drop=True) # Evaluate the 1D mask to index | ||
|
||
# Reshape from Stack ("azimuth", "range", "time") to Space-Time Matrix ("space", "time") | ||
stacked = slcs.stack(space=("azimuth", "range")) | ||
|
||
# Drop multi-index coords for space coordinates | ||
# This will also azimuth and range coordinates, as they are part of the multi-index coordinates | ||
stm = stacked.drop_vars(["space", "azimuth", "range"]) | ||
|
||
# Assign a continuous index the space dimension | ||
# Assign azimuth and range back as coordinates | ||
stm = stm.assign_coords( | ||
{ | ||
"space": (["space"], range(stm.sizes["space"])), | ||
"azimuth": (["space"], stacked["azimuth"].values), | ||
"range": (["space"], stacked["range"].values), | ||
} | ||
) # keep azimuth and range as coordinates | ||
|
||
# Apply selection | ||
stm_masked = stm.sel(space=index) | ||
|
||
# Re-order the dimensions to community preferred ("space", "time") order | ||
stm_masked = stm_masked.transpose("space", "time") | ||
|
||
# Rechunk is needed because after apply maksing, the chunksize will be inconsistant | ||
stm_masked = stm_masked.chunk( | ||
{ | ||
"space": output_chunks, | ||
"time": -1, | ||
} | ||
) | ||
|
||
# Reset space coordinates | ||
stm_masked = stm_masked.assign_coords( | ||
{ | ||
"space": (["space"], range(stm_masked.sizes["space"])), | ||
} | ||
) | ||
|
||
# Compute NAD or NMAD if mem_persist is True | ||
# This only evaluate a very short task graph, since NAD or NMAD is already in memory | ||
if mem_persist: | ||
match method: | ||
case "nad": | ||
stm_masked["pnt_nad"] = stm_masked["pnt_nad"].compute() | ||
case "nmad": | ||
stm_masked["pnt_nmad"] = stm_masked["pnt_nmad"].compute() | ||
|
||
return stm_masked | ||
|
||
|
||
def _nad_block(amp: xr.DataArray) -> xr.DataArray:
    """Compute Normalized Amplitude Dispersion (NAD) for a block of amplitude data.

    NAD is the temporal standard deviation divided by the temporal mean.

    Parameters
    ----------
    amp : xr.DataArray
        Amplitude data, with dimensions ("azimuth", "range", "time").
        This can be extracted from an SLC xr.Dataset.

    Returns
    -------
    xr.DataArray
        Normalized Amplitude Dispersion (NAD) data, with dimensions ("azimuth", "range").
    """
    # skipna=False is deliberate: a NaN anywhere in a pixel's time series must
    # yield NaN for that pixel, so the pixel is discarded instead of being
    # evaluated on the remaining samples.
    temporal_std = amp.std(dim="time", skipna=False)
    temporal_mean = amp.mean(dim="time", skipna=False)

    # Machine epsilon of the input dtype guards against division by zero
    eps = np.finfo(amp.dtype).eps

    return temporal_std / (temporal_mean + eps)
|
||
|
||
def _nmad_block(amp: xr.DataArray) -> xr.DataArray:
    """Compute Normalized Median Absolute Deviation (NMAD) for a block of amplitude data.

    NMAD is the median absolute deviation from the temporal median, divided by
    that temporal median.

    Parameters
    ----------
    amp : xr.DataArray
        Amplitude data, with dimensions ("azimuth", "range", "time").
        This can be extracted from an SLC xr.Dataset.

    Returns
    -------
    xr.DataArray
        Normalized Median Absolute Deviation (NMAD) data, with dimensions ("azimuth", "range").
    """
    # Temporal median per pixel; skipna=False so a NaN in the time series gives a NaN median
    center = amp.median(dim="time", skipna=False)

    # Median Absolute Deviation: median over time of the absolute deviation from the median
    abs_deviation = np.abs(amp - center)
    mad = abs_deviation.median(dim="time")

    # Machine epsilon of the input dtype guards against division by zero
    eps = np.finfo(amp.dtype).eps

    return mad / (center + eps)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ dev = [ | |
"pycodestyle", | ||
"pre-commit", | ||
"ruff", | ||
"graphviz", | ||
] | ||
docs = [ | ||
"mkdocs", | ||
|
Oops, something went wrong.