From 9e78e52503896a45553cacf03f5b26dc104f43f4 Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Thu, 18 Jan 2024 01:32:29 -0500 Subject: [PATCH] Ignore points outside image boundary --- samgeo/common.py | 28 +++++++++++++++++++++------- samgeo/samgeo.py | 10 +++++++++- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/samgeo/common.py b/samgeo/common.py index 6c03d334..83d5f34f 100644 --- a/samgeo/common.py +++ b/samgeo/common.py @@ -780,7 +780,7 @@ def geojson_to_coords( def coords_to_xy( - src_fp: str, coords: list, coord_crs: str = "epsg:4326", **kwargs + src_fp: str, coords: list, coord_crs: str = "epsg:4326", return_out_of_bounds=False, **kwargs ) -> list: """Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates. @@ -788,11 +788,14 @@ def coords_to_xy( src_fp: The source raster file path. coords: A list of coordinates in the format of [[x1, y1], [x2, y2], ...] coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326". + return_out_of_bounds: Whether to return out of bounds coordinates. Defaults to False. **kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol. Returns: A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...] """ + out_of_bounds = [] + if isinstance(coords, np.ndarray): coords = coords.tolist() @@ -805,15 +808,26 @@ def coords_to_xy( rows, cols = rasterio.transform.rowcol(src.transform, xs, ys, **kwargs) result = [[col, row] for col, row in zip(cols, rows)] - result = [ - [x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height - ] - if len(result) == 0: + output = [] + + for i, (x, y) in enumerate(result): + if x >= 0 and y >= 0 and x < width and y < height: + output.append([x, y]) + else: + out_of_bounds.append(i) + + # output = [ + # [x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height + # ] + if len(output) == 0: print("No valid pixel coordinates found.") - elif len(result) < len(coords): + elif len(output) < len(coords): print("Some coordinates are out of the image boundary.") - return result + if return_out_of_bounds: + return output, out_of_bounds + else: + return output def boxes_to_vector(coords, src_crs, dst_crs="EPSG:4326", output=None, **kwargs): diff --git a/samgeo/samgeo.py b/samgeo/samgeo.py index fd3a4532..86bc2160 100644 --- a/samgeo/samgeo.py +++ b/samgeo/samgeo.py @@ -503,6 +503,7 @@ def predict( return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False. """ + out_of_bounds = [] if isinstance(boxes, str): gdf = gpd.read_file(boxes) @@ -529,7 +530,7 @@ def predict( point_labels = self.point_labels if (point_crs is not None) and (point_coords is not None): - point_coords = coords_to_xy(self.source, point_coords, point_crs) + point_coords, out_of_bounds = coords_to_xy(self.source, point_coords, point_crs, return_out_of_bounds=True) if isinstance(point_coords, list): point_coords = np.array(point_coords) @@ -544,6 +545,13 @@ def predict( if len(point_labels) != len(point_coords): if len(point_labels) == 1: point_labels = point_labels * len(point_coords) + elif len(out_of_bounds) > 0: + print(f"Removing {len(out_of_bounds)} out-of-bound points.") + point_labels_new = [] + for i, p in enumerate(point_labels): + if i not in out_of_bounds: + point_labels_new.append(p) + point_labels = point_labels_new else: raise ValueError( "The length of point_labels must be equal to the length of point_coords."