diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..e69de29 diff --git a/multi_plankton_separation/api.py b/multi_plankton_separation/api.py index 87b27e1..c10daec 100644 --- a/multi_plankton_separation/api.py +++ b/multi_plankton_separation/api.py @@ -75,34 +75,6 @@ def get_metadata(): return meta -# def get_train_args(): -# arg_dict = { -# "epoch_num": fields.Int( -# required=False, -# missing=10, -# description="Total number of training epochs", -# ), -# } -# return arg_dict - - -# def train(**kwargs): -# """ -# Dummy training. We just sleep for some number of epochs (1 epoch = 1 second) -# mimicking some computation taking place. -# We can log some random losses in Tensorboard to mimic monitoring. -# """ -# logdir = BASE_DIR / "runs" / time.strftime("%Y-%m-%d_%H-%M-%S") -# writer = SummaryWriter(logdir=logdir) -# launch_tensorboard(logdir=logdir) -# for epoch in range(kwargs["epoch_num"]): -# time.sleep(1.) -# writer.add_scalar("scalars/loss", - math.log(epoch + 1), epoch) -# writer.close() - -# return {"status": "done", "final accuracy": 0.9} - - def get_predict_args(): """ Get the list of arguments for the predict function @@ -125,11 +97,16 @@ def get_predict_args(): enum=list_models, description="The model used to perform instance segmentation" ), - "threshold": fields.Float( + "min_mask_score": fields.Float( required=False, missing=0.9, description="The minimum confidence score for a mask to be selected" ), + "min_mask_value": fields.Float( + required=False, + missing=0.5, + description="The minimum value for a pixel to belong to a mask" + ), "accept" : fields.Str( required=False, missing='image/png', @@ -162,23 +139,31 @@ def predict(**kwargs): img = transform(orig_img) # Get predicted masks - pred_masks, pred_masks_probs = get_predicted_masks(model, img, kwargs["threshold"]) + pred_masks, pred_masks_probs = get_predicted_masks( + model, img, kwargs["min_mask_score"] + ) # Get sum of masks probabilities and mask centers mask_sum = np.zeros(pred_masks[0].shape) mask_centers_x = [] mask_centers_y = [] + # Get sum of masks and mask centers for the watershed for mask in pred_masks_probs: - mask_sum += mask + to_add = mask + to_add[to_add < kwargs["min_mask_value"]] = 0 + mask_sum += to_add center_x, center_y = np.unravel_index(np.argmax(mask), mask.shape) mask_centers_x.append(center_x) mask_centers_y.append(center_y) mask_centers = zip(mask_centers_x, mask_centers_y) + # Get silhouette of objects to use as a mask for the watershed + binary_img = (img[0, :, :] + img[1, :, :] + img[2, :, :] != 3).numpy().astype(float) + # Apply watershed algorithm - watershed_labels = get_watershed_result(mask_sum, mask_centers) + watershed_labels = get_watershed_result(mask_sum, mask_centers, mask=binary_img) # Save output separations separation_mask = np.ones(watershed_labels.shape) diff --git a/multi_plankton_separation/utils.py b/multi_plankton_separation/utils.py index 83d7176..11f70cc 100644 --- a/multi_plankton_separation/utils.py +++ b/multi_plankton_separation/utils.py @@ -6,7 +6,7 @@ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor from scipy import ndimage as ndi -from skimage.segmentation import watershed +from skimage.segmentation import watershed, find_boundaries import multi_plankton_separation.config as cfg @@ -75,26 +75,45 @@ def get_predicted_masks(model, image, score_threshold=0.9, mask_threshold=0.7): return pred_masks, pred_masks_probs -def get_watershed_result(mask_map, mask_centers): +def get_watershed_result(mask_map, mask_centers, mask=None): """ Apply the watershed algorithm on the predicted mask map, using the mask centers as markers """ + # Prepare watershed markers markers_mask = np.zeros(mask_map.shape, dtype=bool) for (x, y) in mask_centers: markers_mask[x, y] = True markers, _ = ndi.label(markers_mask) - watershed_mask = np.zeros(mask_map.shape, dtype='int64') - watershed_mask[mask_map > .01] = 1 + # Prepare watershed mask + if mask is None: + watershed_mask = np.zeros(mask_map.shape, dtype='int64') + watershed_mask[mask_map > .01] = 1 + else: + watershed_mask = mask + # Apply watershed labels = watershed( -mask_map, markers, mask=watershed_mask, watershed_line=False ) - labels_with_lines = watershed( - -mask_map, markers, mask=watershed_mask, watershed_line=True - ) - labels_with_lines[labels == 0] = -1 + + # Derive separation lines + lines = np.zeros(labels.shape) + unique_labels = list(np.unique(labels)) + unique_labels.remove(0) + + for value in unique_labels: + single_shape = (labels == value).astype(int) + boundaries = find_boundaries( + single_shape, connectivity=2, mode='outer', background=0 + ) + boundaries[(labels == 0) | (labels == value)] = 0 + lines[boundaries == 1] = 1 + + labels_with_lines = labels + labels_with_lines[labels_with_lines == 0] = -1 + labels_with_lines[lines == 1] = 0 return labels_with_lines