Skip to content

Commit

Permalink
Merge pull request #30 from MotionbyLearning/17_sparse_selection
Browse files Browse the repository at this point in the history
17 sparse selection
  • Loading branch information
rogerkuou authored Dec 11, 2024
2 parents 4dfd08f + de4b146 commit 102ba6f
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import sarxarray
import stmtools

from pydepsi.classification import ps_selection
from pydepsi.io import read_metadata
from pydepsi.classification import ps_selection, network_stm_selection

# Make a logger to log the stages of processing
logger = logging.getLogger(__name__)
Expand All @@ -37,14 +38,19 @@ def get_free_port():


# ---- Config 1: Human Input ----


# Parameters
method = 'nmad' # Method for selection
threshold = 0.45 # Threshold for selection

# Input data paths
path_slc_zarr = Path("/project/caroline/slc_file.zarr") # Zarr file of all SLCs
path_metadata = Path("/project/caroline/metadata.res") # Metadata file

# Parameters PS selection
ps_selection_method = 'nmad' # Method for PS selection
ps_selection_threshold = 0.45 # Threshold for PS selection

# Parameters network selection
network_stm_quality_metric = 'nmad' # Quality metric for network selection
network_stm_quality_threshold = 0.45 # Quality threshold for network selection
min_dist = 200 # Distance threshold for network selection, in meters
include_index = [57, 101, 189] # Force including the points with index 57, 101, and 189, use None if no point need to be included

# Output config
overwrite_zarr = False # Flag for zarr overwrite
Expand Down Expand Up @@ -87,8 +93,10 @@ def get_free_port():
)

if __name__ == "__main__":
# ---- Processing Stage 0: Initialization ----
logger.info("Initializing ...")

# Initiate a Dask client
if cluster is None:
# Use existing cluster
client = Client(ADDRESS)
Expand All @@ -98,8 +106,13 @@ def get_free_port():
cluster.scale(jobs=N_WORKERS)
client = Client(cluster)

# Load metadata
metadata = read_metadata(path_metadata)

# ---- Processing Stage 1: Pixel Classification ----
# Load the SLC data
logger.info("Loading data ...")
logger.info("Processing Stage 1: Pixel Classification")
logger.info("Loading SLC data ...")
ds = xr.open_zarr(path_slc_zarr) # Load the zarr file as a xr.Dataset
# Construct SLCs from xr.Dataset
# this construct three datavariables: complex, amplitude, and phase
Expand All @@ -110,16 +123,35 @@ def get_free_port():
# slcs = slcs.chunk({"azimuth":1000, "range":1000, "time":-1})

# Select PS
stm_ps = ps_selection(method, threshold, method='nmad', output_chunks=chunk_space)
logger.info("PS Selection ...")
stm_ps = ps_selection(method, threshold, method=ps_selection_method, output_chunks=chunk_space)

# Re-order the PS to make the spatially adjacent PS in the same chunk
logger.info("Reorder selected scatterers ...")
stm_ps_reordered = stm_ps.stm.reorder(xlabel='lon', ylabel='lat')

# Save the PS to zarr
logger.info("Writting selected pixels to Zarr ...")
if overwrite_zarr:
stm_ps_reordered.to_zarr(path_ps_zarr, mode="w")
else:
stm_ps_reordered.to_zarr(path_ps_zarr)

# ---- Processing Stage 2: Network Processing ----
# Uncomment the following line to load the PS data from zarr
# stm_ps_reordered = xr.open_zarr(path_ps_zarr)

# Select network points
logger.info("Select network scatterers ...")
# Apply a pre-filter
stm_network_candidates = xr.where(stm_ps_reordered[network_stm_quality_metric]<network_stm_quality_threshold)
# Select based on sparsity and quality
stm_network = network_stm_selection(stm_network_candidates,
min_dist,
include_index=include_index,
sortby_var=network_stm_quality_metric,
azimuth_spacing=metadata['azimuth_spacing'],
range_spacing=metadata['range_spacing'])

# Close the client when finishing
client.close()
148 changes: 148 additions & 0 deletions pydepsi/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import xarray as xr
from scipy.spatial import KDTree


def ps_selection(
Expand Down Expand Up @@ -126,6 +127,126 @@ def ps_selection(
return stm_masked


def network_stm_selection(
stm: xr.Dataset,
min_dist: int | float,
include_index: list[int] = None,
sortby_var: str = "pnt_nmad",
crs: int | str = "radar",
x_var: str = "azimuth",
y_var: str = "range",
azimuth_spacing: float = None,
range_spacing: float = None,
):
"""Select a Space-Time Matrix (STM) from a candidate STM for network processing.
The selection is based on two criteria:
1. A minimum distance between selected points.
2. A sorting metric to select better points.
The candidate STM will be sorted by the sorting metric.
The selection will be performed iteratively, starting from the best point.
In each iteration, the best point will be selected, and points within the minimum distance will be removed.
The process will continue until no points are left in the candidate STM.
Parameters
----------
stm : xr.Dataset
candidate Space-Time Matrix (STM).
min_dist : int | float
Minimum distance between selected points.
include_index : list[int], optional
Index of points in the candidate STM that must be included in the selection, by default None
sortby_var : str, optional
Sorting metric for selecting points, by default "pnt_nmad"
crs : int | str, optional
EPSG code of Coordinate Reference System of `x_var` and `y_var`, by default "radar".
If crs is "radar", the distance will be calculated based on radar coordinates, and
azimuth_spacing and range_spacing must be provided.
x_var : str, optional
Data variable name for x coordinate, by default "azimuth"
y_var : str, optional
Data variable name for y coordinate, by default "range"
azimuth_spacing : float, optional
Azimuth spacing, by default None. Required if crs is "radar".
range_spacing : float, optional
Range spacing, by default None. Required if crs is "radar".
Returns
-------
xr.Dataset
Selected network Space-Time Matrix (STM).
Raises
------
ValueError
Raised when `azimuth_spacing` or `range_spacing` is not provided for radar coordinates.
NotImplementedError
Raised when an unsupported Coordinate Reference System is provided.
"""
match crs:
case "radar":
if (azimuth_spacing is None) or (range_spacing is None):
raise ValueError("Azimuth and range spacing must be provided for radar coordinates.")
case _:
raise NotImplementedError

# Get coordinates and sorting metric, load them into memory
stm_select = None
stm_remain = stm[[x_var, y_var, sortby_var]].compute()

# Select the include_index if provided
if include_index is not None:
stm_select = stm_remain.isel(space=include_index)

# Remove points within min_dist of the included points
coords_include = np.column_stack(
[stm_select["azimuth"].values * azimuth_spacing, stm_select["range"].values * range_spacing]
)
coords_remain = np.column_stack(
[stm_remain["azimuth"].values * azimuth_spacing, stm_remain["range"].values * range_spacing]
)
idx_drop = _idx_within_distance(coords_include, coords_remain, min_dist)
if idx_drop is not None:
stm_remain = stm_remain.where(~(stm_remain["space"].isin(idx_drop)), drop=True)

# Reorder the remaining points by the sorting metric
stm_remain = stm_remain.sortby(sortby_var)

# Build a list of the index of selected points
if stm_select is None:
space_idx_sel = []
else:
space_idx_sel = stm_select["space"].values.tolist()

while stm_remain.sizes["space"] > 0:
# Select one point with best sorting metric
stm_now = stm_remain.isel(space=0)

# Append the selected point index
space_idx_sel.append(stm_now["space"].values.tolist())

# Remove the selected point from the remaining points
stm_remain = stm_remain.isel(space=slice(1, None)).copy()

# Remove points in stm_remain within min_dist of stm_now
coords_remain = np.column_stack(
[stm_remain["azimuth"].values * azimuth_spacing, stm_remain["range"].values * range_spacing]
)
coords_stmnow = np.column_stack(
[stm_now["azimuth"].values * azimuth_spacing, stm_now["range"].values * range_spacing]
)
idx_drop = _idx_within_distance(coords_stmnow, coords_remain, min_dist)
if idx_drop is not None:
stm_drop = stm_remain.isel(space=idx_drop)
stm_remain = stm_remain.where(~(stm_remain["space"].isin(stm_drop["space"])), drop=True)

# Get the selected points by space index from the original stm
stm_out = stm.sel(space=space_idx_sel)

return stm_out


def _nad_block(amp: xr.DataArray) -> xr.DataArray:
"""Compute Normalized Amplitude Dispersion (NAD) for a block of amplitude data.
Expand Down Expand Up @@ -170,3 +291,30 @@ def _nmad_block(amp: xr.DataArray) -> xr.DataArray:
nmad = mad / (median_amplitude + np.finfo(amp.dtype).eps) # Normalized Median Absolute Deviation

return nmad


def _idx_within_distance(coords_ref, coords_others, min_dist):
"""Get the index of points in coords_others that are within min_dist of coords_ref.
Parameters
----------
coords_ref : np.ndarray
Coordinates of reference points. Shape (n, 2).
coords_others : np.ndarray
Coordinates of other points. Shape (m, 2).
min_dist : int, float
distance threshold.
Returns
-------
np.ndarray
Index of points in coords_others that are within `min_dist` of `coords_ref`.
"""
kd_ref = KDTree(coords_ref)
kd_others = KDTree(coords_others)
sdm = kd_ref.sparse_distance_matrix(kd_others, min_dist)
if len(sdm) > 0:
idx = np.array(list(sdm.keys()))[:, 1]
return idx
else:
return None
Loading

0 comments on commit 102ba6f

Please sign in to comment.