From 28b6e467dd58b5db89b7ed754259bcbc6d0745fb Mon Sep 17 00:00:00 2001 From: Rohit Khati <50332813+ro-hit81@users.noreply.github.com> Date: Tue, 22 Oct 2024 19:13:22 +0200 Subject: [PATCH] Simplify Mask Handling When Only One Object is Detected (#352) * Fix empty mask overlay handling when no objects are detected. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * mask overlay corrected * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Handle case for single objects detected. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- samgeo/text_sam.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/samgeo/text_sam.py b/samgeo/text_sam.py index aa5679c3..1c010111 100644 --- a/samgeo/text_sam.py +++ b/samgeo/text_sam.py @@ -337,7 +337,10 @@ def predict( masks = torch.tensor([]) if len(boxes) > 0: masks = self.predict_sam(image_pil, boxes) - if 1 in masks.shape: + # If masks have 4 dimensions and the second dimension is 1 (e.g., [boxes, 1, height, width]), + # squeeze that dimension to reduce it to 3 dimensions ([boxes, height, width]). + # If boxes = 1, the mask's shape will be [1, height, width] after squeezing. + if masks.ndim == 4 and masks.shape[1] == 1: masks = masks.squeeze(1) if boxes.nelement() == 0: # No "object" instances found