From 837f5a5803fb3736e1b7bd07624eb42ff90e9103 Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Sat, 5 Oct 2024 15:13:46 -0400 Subject: [PATCH] Fix save masks bug (#332) * Fix save masks bug * Fix save masks bug for SAM 1 --- samgeo/samgeo.py | 7 ++++--- samgeo/samgeo2.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/samgeo/samgeo.py b/samgeo/samgeo.py index 1c6f1bc4..638326f9 100644 --- a/samgeo/samgeo.py +++ b/samgeo/samgeo.py @@ -265,8 +265,8 @@ def save_masks( # Generate a mask of objects with unique values if unique: - # Sort the masks by area in ascending order - sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False) + # Sort the masks by area in descending order + sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) # Create an output image with the same size as the input image objects = np.zeros( @@ -276,9 +276,10 @@ def save_masks( ) ) # Assign a unique value to each object + count = len(sorted_masks) for index, ann in enumerate(sorted_masks): m = ann["segmentation"] - objects[m] = index + 1 + objects[m] = count - index # Generate a binary mask else: diff --git a/samgeo/samgeo2.py b/samgeo/samgeo2.py index d87a7748..ac4ce657 100644 --- a/samgeo/samgeo2.py +++ b/samgeo/samgeo2.py @@ -293,8 +293,8 @@ def save_masks( # Generate a mask of objects with unique values if unique: - # Sort the masks by area in ascending order - sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=False) + # Sort the masks by area in descending order + sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) # Create an output image with the same size as the input image objects = np.zeros( @@ -304,9 +304,10 @@ def save_masks( ) ) # Assign a unique value to each object + count = len(sorted_masks) for index, ann in enumerate(sorted_masks): m = ann["segmentation"] - objects[m] = index + 1 + objects[m] = count - index # Generate a binary mask else: