Encoder Features Extraction #288

Open
wants to merge 5 commits into base: master
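The diff below adds encoder feature extraction to HoVer-Net tile inference: the encoder feature map (`d3`) is returned by the network alongside the branch predictions, carried through `infer_step`/`run_step`, and saved per image as `<output_dir>/encoder_features/<image_name>.npy`. A minimal sketch of reading one of those files back afterwards, assuming a hypothetical output directory `out/` and image name `sample`:

    import numpy as np

    # Hypothetical paths: the real output_dir and image name depend on the run configuration.
    features = np.load("out/encoder_features/sample.npy")
    print(features.shape)  # array written by proc_callback via np.save
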
122 changes: 90 additions & 32 deletions infer/tile.py
@@ -1,6 +1,7 @@
import logging
import multiprocessing
from multiprocessing import Lock, Pool
import pickle as pk

multiprocessing.set_start_method("spawn", True) # ! must be at top for VScode debugging
import argparse
@@ -45,13 +46,13 @@
####
def _prepare_patching(img, window_size, mask_size, return_src_top_corner=False):
"""Prepare patch information for tile processing.

Args:
img: original input image
window_size: input patch size
mask_size: output patch size
return_src_top_corner: whether to return coordinate information for the top left corner of img

"""

win_size = window_size
@@ -96,10 +97,15 @@ def get_last_steps(length, msk_size, step_size):

####
def _post_process_patches(
post_proc_func, post_proc_kwargs, patch_info, image_info, overlay_kwargs,
post_proc_func,
post_proc_kwargs,
patch_info,
patch_info1,
image_info,
overlay_kwargs,
):
"""Apply post processing to patches.

Args:
post_proc_func: post processing function to use
post_proc_kwargs: keyword arguments used in post processing function
@@ -112,6 +118,9 @@ def _post_process_patches(
patch_info = sorted(patch_info, key=lambda x: [x[0][0], x[0][1]])
patch_info, patch_data = zip(*patch_info)

patch_info1 = sorted(patch_info1, key=lambda x: [x[0][0], x[0][1]])
patch_info1, patch_data1 = zip(*patch_info1)

src_shape = image_info["src_shape"]
src_image = image_info["src_image"]

@@ -121,12 +130,20 @@

nr_row = max([x[2] for x in patch_info]) + 1
nr_col = max([x[3] for x in patch_info]) + 1

pred_map = np.concatenate(patch_data, axis=0)
pred_map1 = np.concatenate(patch_data1, axis=0)
# print("pred_map0", pred_map.shape)

pred_map = np.reshape(pred_map, (nr_row, nr_col) + patch_shape)

pred_map = np.transpose(pred_map, axes)
# print("pred_map2", pred_map.shape)

pred_map = np.reshape(
pred_map, (patch_shape[0] * nr_row, patch_shape[1] * nr_col, ch)
)

# crop back to original shape
pred_map = np.squeeze(pred_map[: src_shape[0], : src_shape[1]])

@@ -140,7 +157,14 @@ def _post_process_patches(
src_image.copy(), inst_info_dict, **overlay_kwargs
)

return image_info["name"], pred_map, pred_inst, inst_info_dict, overlaid_img
return (
image_info["name"],
pred_map,
pred_inst,
inst_info_dict,
overlaid_img,
pred_map1,
)


class InferManager(base.InferManager):
@@ -159,45 +183,64 @@ def process_file_list(self, run_args):
patterning = lambda x: re.sub("([\[\]])", "[\\1]", x)
file_path_list = glob.glob(patterning("%s/*" % self.input_dir))
file_path_list.sort() # ensure same order
assert len(file_path_list) > 0, 'Not Detected Any Files From Path'

rm_n_mkdir(self.output_dir + '/json/')
rm_n_mkdir(self.output_dir + '/mat/')
rm_n_mkdir(self.output_dir + '/overlay/')
assert len(file_path_list) > 0, "Not Detected Any Files From Path"

rm_n_mkdir(self.output_dir + "/json/")
rm_n_mkdir(self.output_dir + "/mat/")
rm_n_mkdir(self.output_dir + "/overlay/")
rm_n_mkdir(self.output_dir + "/encoder_features/")

if self.save_qupath:
rm_n_mkdir(self.output_dir + "/qupath/")

def proc_callback(results):
"""Post processing callback.

Output format is an implicit assumption, taken from `_post_process_patches`

"""
img_name, pred_map, pred_inst, inst_info_dict, overlaid_img = results
(
img_name,
pred_map,
pred_inst,
inst_info_dict,
overlaid_img,
extract_features,
) = results

nuc_val_list = list(inst_info_dict.values())
# need singleton to make matlab happy
nuc_uid_list = np.array(list(inst_info_dict.keys()))[:,None]
nuc_type_list = np.array([v["type"] for v in nuc_val_list])[:,None]
nuc_uid_list = np.array(list(inst_info_dict.keys()))[:, None]
nuc_type_list = np.array([v["type"] for v in nuc_val_list])[:, None]
nuc_coms_list = np.array([v["centroid"] for v in nuc_val_list])

mat_dict = {
"inst_map" : pred_inst,
"inst_uid" : nuc_uid_list,
"inst_map": pred_inst,
"inst_uid": nuc_uid_list,
"inst_type": nuc_type_list,
"inst_centroid": nuc_coms_list
"inst_centroid": nuc_coms_list,
}
if self.nr_types is None: # matlab does not have None type array
mat_dict.pop("inst_type", None)
if self.nr_types is None: # matlab does not have None type array
mat_dict.pop("inst_type", None)

if self.save_raw_map:
mat_dict["raw_map"] = pred_map
save_path = "%s/mat/%s.mat" % (self.output_dir, img_name)
sio.savemat(save_path, mat_dict)

save_path = "%s/overlay/%s.png" % (self.output_dir, img_name)

# print("pred_map2", pred_map.shape)

cv2.imwrite(save_path, cv2.cvtColor(overlaid_img, cv2.COLOR_RGB2BGR))

save_path = "%s/encoder_features/%s.npy" % (self.output_dir, img_name)
print("extract_features", extract_features.shape)
with open(save_path, "wb") as f:
np.save(f, extract_features)
# pk.dump(extract_features, f)
# torch.save(extract_features,save_path)

if self.save_qupath:
nuc_val_list = list(inst_info_dict.values())
nuc_type_list = np.array([v["type"] for v in nuc_val_list])
@@ -303,17 +346,25 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
)

accumulated_patch_output = []
accumulated_np = []
for batch_idx, batch_data in enumerate(dataloader):
sample_data_list, sample_info_list = batch_data
sample_output_list = self.run_step(sample_data_list)
sample_output_list, encoder_features = self.run_step(sample_data_list)

sample_info_list_features = sample_info_list.numpy()
sample_info_list = sample_info_list.numpy()
curr_batch_size = sample_output_list.shape[0]
sample_output_list = np.split(
sample_output_list, curr_batch_size, axis=0
)
encoder_features = np.split(encoder_features, curr_batch_size, axis=0)
sample_info_list = np.split(sample_info_list, curr_batch_size, axis=0)
sample_output_list = list(zip(sample_info_list, sample_output_list))
accumulated_patch_output.extend(sample_output_list)

sample_info_list_features = np.split(sample_info_list_features, curr_batch_size, axis=0)
sample_info_list_features = list(zip(sample_info_list_features, encoder_features))
accumulated_np.extend(sample_info_list_features)
pbar.update()
pbar.close()

@@ -324,6 +375,12 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
file_ouput_data, accumulated_patch_output = detach_items_of_uid(
accumulated_patch_output, file_idx, image_info[1]
)
file_ouput_data1, accumulated_np = detach_items_of_uid(
accumulated_np, file_idx, image_info[1]
)
# print("file_ouput_data", len(file_ouput_data))

# print("file_ouput_data1",len(file_ouput_data1))

# * detach this into func and multiproc dispatch it
src_pos = image_info[2] # src top left corner within padded image
@@ -350,10 +407,12 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
"type_colour": self.type_info_dict,
"line_thickness": 2,
}

func_args = (
self.post_proc_func,
post_proc_kwargs,
file_ouput_data,
file_ouput_data1,
file_info,
overlay_kwargs,
)
@@ -373,16 +432,15 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
for future in as_completed(future_list):
# TODO: way to retrieve which file crashed ?
# ! silent crash, cancel all and raise error
if future.exception() is not None:
log_info("Silent Crash")
# ! cancel somehow leads to cascade error later
# ! so just poll it then crash once all future
# ! acquired for now
# for future in future_list:
# future.cancel()
# break
else:
file_path = proc_callback(future.result())
log_info("Done Assembling %s" % file_path)
# if future.exception() is not None:
# log_info("Silent Crash")
# # ! cancel somehow leads to cascade error later
# # ! so just poll it then crash once all future
# # ! acquired for now
# # for future in future_list:
# # future.cancel()
# # break
# else:
file_path = proc_callback(future.result())
log_info("Done Assembling %s" % file_path)
return
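Note that the exception check around the futures is commented out above, so a failed worker would surface only when `future.result()` raises. A purely illustrative sketch that keeps the original crash check while still forwarding the new six-item result to `proc_callback` (`future_list`, `log_info`, and `proc_callback` are the names used in the diff above):

    from concurrent.futures import as_completed

    # Illustrative only: restores the original crash check around the new six-item result.
    for future in as_completed(future_list):
        exc = future.exception()
        if exc is not None:
            log_info("Silent Crash: %r" % exc)  # log it and keep polling the remaining futures
            continue
        file_path = proc_callback(future.result())
        log_info("Done Assembling %s" % file_path)
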

2 changes: 2 additions & 0 deletions models/hovernet/net_desc.py
@@ -129,6 +129,8 @@ def forward(self, imgs):
d[1] = crop_op(d[1], [36, 36])

out_dict = OrderedDict()
out_dict["encoder_features"] = d3

for branch_name, branch_desc in self.decoder.items():
u3 = self.upsample2x(d[-1]) + d[-2]
u3 = branch_desc[0](u3)
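With the two-line change above, the `forward` output dictionary carries the encoder feature map under the key `encoder_features` next to the decoder branch outputs. A minimal usage sketch, assuming `net` is a constructed HoVerNet instance and `imgs` a batched image tensor (both hypothetical here):

    import torch

    # Hypothetical usage: `net` and `imgs` are not defined in the diff itself.
    with torch.no_grad():
        out = net(imgs)                        # OrderedDict of decoder branch outputs
        d3_features = out["encoder_features"]  # encoder feature map exposed by this change
        print(d3_features.shape)
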
6 changes: 4 additions & 2 deletions models/hovernet/run_desc.py
@@ -191,10 +191,12 @@ def infer_step(batch_data, model):
type_map = torch.argmax(type_map, dim=-1, keepdim=True)
type_map = type_map.type(torch.float32)
pred_dict["tp"] = type_map
pred_output1 = pred_dict['encoder_features']
del pred_dict['encoder_features']
pred_output = torch.cat(list(pred_dict.values()), -1)

# print("pred_output", pred_dict)
# * Its up to user to define the protocol to process the raw output per step!
return pred_output.cpu().numpy()
return pred_output.cpu().numpy(), pred_output1.cpu().numpy()


####
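`infer_step` now returns a pair instead of a single array: the encoder features are popped out of `pred_dict` before `torch.cat`, presumably because their spatial size differs from the cropped branch maps, and come back as a second NumPy array. Callers of `run_step` therefore have to unpack two values, which is exactly what the tile pipeline above does. A minimal sketch of the new calling convention, with hypothetical `batch_data` and `model`:

    # New contract: a tuple of NumPy arrays instead of a single array.
    pred_output, encoder_features = infer_step(batch_data, model)

    # Both arrays share the batch dimension, so they can be split per patch the same way
    # the tile pipeline does with np.split(..., curr_batch_size, axis=0).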