Skip to content

Commit

Permalink
update separation lines computation
Browse files Browse the repository at this point in the history
  • Loading branch information
emmaamblard committed Nov 29, 2023
1 parent 46ff750 commit 5e23822
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 40 deletions.
Empty file added .gitattributes
Empty file.
49 changes: 17 additions & 32 deletions multi_plankton_separation/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,34 +75,6 @@ def get_metadata():
return meta


# def get_train_args():
# arg_dict = {
# "epoch_num": fields.Int(
# required=False,
# missing=10,
# description="Total number of training epochs",
# ),
# }
# return arg_dict


# def train(**kwargs):
# """
# Dummy training. We just sleep for some number of epochs (1 epoch = 1 second)
# mimicking some computation taking place.
# We can log some random losses in Tensorboard to mimic monitoring.
# """
# logdir = BASE_DIR / "runs" / time.strftime("%Y-%m-%d_%H-%M-%S")
# writer = SummaryWriter(logdir=logdir)
# launch_tensorboard(logdir=logdir)
# for epoch in range(kwargs["epoch_num"]):
# time.sleep(1.)
# writer.add_scalar("scalars/loss", - math.log(epoch + 1), epoch)
# writer.close()

# return {"status": "done", "final accuracy": 0.9}


def get_predict_args():
"""
Get the list of arguments for the predict function
Expand All @@ -125,11 +97,16 @@ def get_predict_args():
enum=list_models,
description="The model used to perform instance segmentation"
),
"threshold": fields.Float(
"min_mask_score": fields.Float(
required=False,
missing=0.9,
description="The minimum confidence score for a mask to be selected"
),
"min_mask_value": fields.Float(
required=False,
missing=0.5,
description="The minimum value for a pixel to belong to a mask"
),
"accept" : fields.Str(
required=False,
missing='image/png',
Expand Down Expand Up @@ -162,23 +139,31 @@ def predict(**kwargs):
img = transform(orig_img)

# Get predicted masks
pred_masks, pred_masks_probs = get_predicted_masks(model, img, kwargs["threshold"])
pred_masks, pred_masks_probs = get_predicted_masks(
model, img, kwargs["min_mask_score"]
)

# Get sum of masks probabilities and mask centers
mask_sum = np.zeros(pred_masks[0].shape)
mask_centers_x = []
mask_centers_y = []

# Get sum of masks and mask centers for the watershed
for mask in pred_masks_probs:
mask_sum += mask
to_add = mask
to_add[to_add < kwargs["min_mask_value"]] = 0
mask_sum += to_add
center_x, center_y = np.unravel_index(np.argmax(mask), mask.shape)
mask_centers_x.append(center_x)
mask_centers_y.append(center_y)

mask_centers = zip(mask_centers_x, mask_centers_y)

# Get silhouette of objects to use as a mask for the watershed
binary_img = (img[0, :, :] + img[1, :, :] + img[2, :, :] != 3).numpy().astype(float)

# Apply watershed algorithm
watershed_labels = get_watershed_result(mask_sum, mask_centers)
watershed_labels = get_watershed_result(mask_sum, mask_centers, mask=binary_img)

# Save output separations
separation_mask = np.ones(watershed_labels.shape)
Expand Down
35 changes: 27 additions & 8 deletions multi_plankton_separation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from scipy import ndimage as ndi
from skimage.segmentation import watershed
from skimage.segmentation import watershed, find_boundaries

import multi_plankton_separation.config as cfg

Expand Down Expand Up @@ -75,26 +75,45 @@ def get_predicted_masks(model, image, score_threshold=0.9, mask_threshold=0.7):
return pred_masks, pred_masks_probs


def get_watershed_result(mask_map, mask_centers):
def get_watershed_result(mask_map, mask_centers, mask=None):
"""
Apply the watershed algorithm on the predicted mask map,
using the mask centers as markers
"""
# Prepare watershed markers
markers_mask = np.zeros(mask_map.shape, dtype=bool)
for (x, y) in mask_centers:
markers_mask[x, y] = True
markers, _ = ndi.label(markers_mask)

watershed_mask = np.zeros(mask_map.shape, dtype='int64')
watershed_mask[mask_map > .01] = 1
# Prepare watershed mask
if mask is None:
watershed_mask = np.zeros(mask_map.shape, dtype='int64')
watershed_mask[mask_map > .01] = 1
else:
watershed_mask = mask

# Apply watershed
labels = watershed(
-mask_map, markers, mask=watershed_mask, watershed_line=False
)
labels_with_lines = watershed(
-mask_map, markers, mask=watershed_mask, watershed_line=True
)
labels_with_lines[labels == 0] = -1

# Derive separation lines
lines = np.zeros(labels.shape)
unique_labels = list(np.unique(labels))
unique_labels.remove(0)

for value in unique_labels:
single_shape = (labels == value).astype(int)
boundaries = find_boundaries(
single_shape, connectivity=2, mode='outer', background=0
)
boundaries[(labels == 0) | (labels == value)] = 0
lines[boundaries == 1] = 1

labels_with_lines = labels
labels_with_lines[labels_with_lines == 0] = -1
labels_with_lines[lines == 1] = 0

return labels_with_lines

Expand Down

0 comments on commit 5e23822

Please sign in to comment.