Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

17 sparse selection #30

Merged
merged 13 commits into from
Dec 11, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import sarxarray
import stmtools

from pydepsi.classification import ps_selection
from pydepsi.io import read_metadata
from pydepsi.classification import ps_selection, network_stm_selection

# Make a logger to log the stages of processing
logger = logging.getLogger(__name__)
Expand All @@ -37,14 +38,19 @@ def get_free_port():


# ---- Config 1: Human Input ----


# Parameters
method = 'nmad' # Method for selection
threshold = 0.45 # Threshold for selection

# Input data paths
path_slc_zarr = Path("/project/caroline/slc_file.zarr") # Zarr file of all SLCs

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make this a variable, specify file path at the beginning of the script.

path_metadata = Path("/project/caroline/metadata.res") # Metadata file

# Parameters PS selection
ps_selection_method = 'nmad' # Method for PS selection
ps_selection_threshold = 0.45 # Threshold for PS selection

# Parameters network selection
network_stm_quality_metric = 'nmad' # Quality metric for network selection
network_stm_quality_threshold = 0.45 # Quality threshold for network selection
min_dist = 200 # Distance threshold for network selection, in meters
include_index = [57, 101, 189] # Force including the points with index 57, 101, and 189, use None if no point need to be included

# Output config
overwrite_zarr = False # Flag for zarr overwrite
Expand Down Expand Up @@ -87,8 +93,10 @@ def get_free_port():
)

if __name__ == "__main__":
# ---- Processing Stage 0: Initialization ----
logger.info("Initializing ...")

# Initiate a Dask client
if cluster is None:
# Use existing cluster
client = Client(ADDRESS)
Expand All @@ -98,8 +106,13 @@ def get_free_port():
cluster.scale(jobs=N_WORKERS)
client = Client(cluster)

# Load metadata
metadata = read_metadata(path_metadata)

# ---- Processing Stage 1: Pixel Classification ----
# Load the SLC data
logger.info("Loading data ...")
logger.info("Processing Stage 1: Pixel Classification")
logger.info("Loading SLC data ...")
ds = xr.open_zarr(path_slc_zarr) # Load the zarr file as a xr.Dataset
# Construct SLCs from xr.Dataset
# this construct three datavariables: complex, amplitude, and phase
Expand All @@ -110,16 +123,35 @@ def get_free_port():
# slcs = slcs.chunk({"azimuth":1000, "range":1000, "time":-1})

# Select PS
stm_ps = ps_selection(method, threshold, method='nmad', output_chunks=chunk_space)
logger.info("PS Selection ...")
stm_ps = ps_selection(method, threshold, method=ps_selection_method, output_chunks=chunk_space)

# Re-order the PS to make the spatially adjacent PS in the same chunk
logger.info("Reorder selected scatterers ...")
stm_ps_reordered = stm_ps.stm.reorder(xlabel='lon', ylabel='lat')

# Save the PS to zarr
logger.info("Writting selected pixels to Zarr ...")
if overwrite_zarr:
stm_ps_reordered.to_zarr(path_ps_zarr, mode="w")
else:
stm_ps_reordered.to_zarr(path_ps_zarr)

# ---- Processing Stage 2: Network Processing ----
# Uncomment the following line to load the PS data from zarr
# stm_ps_reordered = xr.open_zarr(path_ps_zarr)

# Select network points
logger.info("Select network scatterers ...")
# Apply a pre-filter
stm_network_candidates = xr.where(stm_ps_reordered[network_stm_quality_metric]<network_stm_quality_threshold)
# Select based on sparsity and quality
stm_network = network_stm_selection(stm_network_candidates,
min_dist,
include_index=include_index,
sortby_var=network_stm_quality_metric,
azimuth_spacing=metadata['azimuth_spacing'],
range_spacing=metadata['range_spacing'])

# Close the client when finishing
client.close()
148 changes: 148 additions & 0 deletions pydepsi/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import xarray as xr
from scipy.spatial import KDTree


def ps_selection(
Expand Down Expand Up @@ -126,6 +127,126 @@ def ps_selection(
return stm_masked


def network_stm_selection(
stm: xr.Dataset,
min_dist: int | float,
include_index: list[int] = None,
sortby_var: str = "pnt_nmad",
crs: int | str = "radar",
x_var: str = "azimuth",
y_var: str = "range",
azimuth_spacing: float = None,
range_spacing: float = None,
):
"""Select a Space-Time Matrix (STM) from a candidate STM for network processing.

The selection is based on two criteria:
1. A minimum distance between selected points.
2. A sorting metric to select better points.

The candidate STM will be sorted by the sorting metric.
The selection will be performed iteratively, starting from the best point.
In each iteration, the best point will be selected, and points within the minimum distance will be removed.
The process will continue until no points are left in the candidate STM.

Parameters
----------
stm : xr.Dataset
candidate Space-Time Matrix (STM).
min_dist : int | float
Minimum distance between selected points.
include_index : list[int], optional
Index of points in the candidate STM that must be included in the selection, by default None
sortby_var : str, optional
Sorting metric for selecting points, by default "pnt_nmad"
crs : int | str, optional
EPSG code of Coordinate Reference System of `x_var` and `y_var`, by default "radar".
If crs is "radar", the distance will be calculated based on radar coordinates, and
azimuth_spacing and range_spacing must be provided.
x_var : str, optional
Data variable name for x coordinate, by default "azimuth"
y_var : str, optional
Data variable name for y coordinate, by default "range"
azimuth_spacing : float, optional
Azimuth spacing, by default None. Required if crs is "radar".
range_spacing : float, optional
Range spacing, by default None. Required if crs is "radar".

Returns
-------
xr.Dataset
Selected network Space-Time Matrix (STM).

Raises
------
ValueError
Raised when `azimuth_spacing` or `range_spacing` is not provided for radar coordinates.
NotImplementedError
Raised when an unsupported Coordinate Reference System is provided.
"""
match crs:
case "radar":
if (azimuth_spacing is None) or (range_spacing is None):
raise ValueError("Azimuth and range spacing must be provided for radar coordinates.")
case _:
raise NotImplementedError

# Get coordinates and sorting metric, load them into memory
stm_select = None
stm_remain = stm[[x_var, y_var, sortby_var]].compute()

# Select the include_index if provided
if include_index is not None:
stm_select = stm_remain.isel(space=include_index)

# Remove points within min_dist of the included points
coords_include = np.column_stack(
[stm_select["azimuth"].values * azimuth_spacing, stm_select["range"].values * range_spacing]
)
coords_remain = np.column_stack(
[stm_remain["azimuth"].values * azimuth_spacing, stm_remain["range"].values * range_spacing]
)
idx_drop = _idx_within_distance(coords_include, coords_remain, min_dist)
if idx_drop is not None:
stm_remain = stm_remain.where(~(stm_remain["space"].isin(idx_drop)), drop=True)

# Reorder the remaining points by the sorting metric
stm_remain = stm_remain.sortby(sortby_var)

# Build a list of the index of selected points
if stm_select is None:
space_idx_sel = []
else:
space_idx_sel = stm_select["space"].values.tolist()

while stm_remain.sizes["space"] > 0:
# Select one point with best sorting metric
stm_now = stm_remain.isel(space=0)

# Append the selected point index
space_idx_sel.append(stm_now["space"].values.tolist())

# Remove the selected point from the remaining points
stm_remain = stm_remain.isel(space=slice(1, None)).copy()

# Remove points in stm_remain within min_dist of stm_now
coords_remain = np.column_stack(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This calculation is now repeated 1000+ times. Move outside loop. So store the coordinates, work with index.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see the comment above

[stm_remain["azimuth"].values * azimuth_spacing, stm_remain["range"].values * range_spacing]
)
coords_stmnow = np.column_stack(
[stm_now["azimuth"].values * azimuth_spacing, stm_now["range"].values * range_spacing]
)
idx_drop = _idx_within_distance(coords_stmnow, coords_remain, min_dist)
if idx_drop is not None:
stm_drop = stm_remain.isel(space=idx_drop)
stm_remain = stm_remain.where(~(stm_remain["space"].isin(stm_drop["space"])), drop=True)

# Get the selected points by space index from the original stm
stm_out = stm.sel(space=space_idx_sel)

return stm_out


def _nad_block(amp: xr.DataArray) -> xr.DataArray:
"""Compute Normalized Amplitude Dispersion (NAD) for a block of amplitude data.

Expand Down Expand Up @@ -170,3 +291,30 @@ def _nmad_block(amp: xr.DataArray) -> xr.DataArray:
nmad = mad / (median_amplitude + np.finfo(amp.dtype).eps) # Normalized Median Absolute Deviation

return nmad


def _idx_within_distance(coords_ref, coords_others, min_dist):
"""Get the index of points in coords_others that are within min_dist of coords_ref.

Parameters
----------
coords_ref : np.ndarray
Coordinates of reference points. Shape (n, 2).
coords_others : np.ndarray
Coordinates of other points. Shape (m, 2).
min_dist : int, float
distance threshold.

Returns
-------
np.ndarray
Index of points in coords_others that are within `min_dist` of `coords_ref`.
"""
kd_ref = KDTree(coords_ref)
kd_others = KDTree(coords_others)
sdm = kd_ref.sparse_distance_matrix(kd_others, min_dist)
if len(sdm) > 0:
idx = np.array(list(sdm.keys()))[:, 1]
return idx
else:
return None
Loading
Loading