diff --git a/infer/tile.py b/infer/tile.py
index 3fa60dbe..f66bbd9e 100755
--- a/infer/tile.py
+++ b/infer/tile.py
@@ -1,6 +1,7 @@
 import logging
 import multiprocessing
 from multiprocessing import Lock, Pool
+import pickle as pk
 
 multiprocessing.set_start_method("spawn", True)  # ! must be at top for VScode debugging
 import argparse
@@ -45,13 +46,13 @@
 
 ####
 def _prepare_patching(img, window_size, mask_size, return_src_top_corner=False):
     """Prepare patch information for tile processing.
-    
+
     Args:
         img: original input image
         window_size: input patch size
         mask_size: output patch size
         return_src_top_corner: whether to return coordinate information for top left corner of img
-    
+
     """
     win_size = window_size
@@ -96,10 +97,15 @@ def get_last_steps(length, msk_size, step_size):
 
 ####
 def _post_process_patches(
-    post_proc_func, post_proc_kwargs, patch_info, image_info, overlay_kwargs,
+    post_proc_func,
+    post_proc_kwargs,
+    patch_info,
+    patch_info1,
+    image_info,
+    overlay_kwargs,
 ):
     """Apply post processing to patches.
-    
+
     Args:
         post_proc_func: post processing function to use
         post_proc_kwargs: keyword arguments used in post processing function
@@ -112,6 +118,9 @@ def _post_process_patches(
     patch_info = sorted(patch_info, key=lambda x: [x[0][0], x[0][1]])
     patch_info, patch_data = zip(*patch_info)
 
+    patch_info1 = sorted(patch_info1, key=lambda x: [x[0][0], x[0][1]])
+    patch_info1, patch_data1 = zip(*patch_info1)
+
     src_shape = image_info["src_shape"]
     src_image = image_info["src_image"]
 
@@ -121,12 +130,20 @@ def _post_process_patches(
 
     nr_row = max([x[2] for x in patch_info]) + 1
     nr_col = max([x[3] for x in patch_info]) + 1
+
     pred_map = np.concatenate(patch_data, axis=0)
+    pred_map1 = np.concatenate(patch_data1, axis=0)
+    # print("pred_map0", pred_map.shape)
+
     pred_map = np.reshape(pred_map, (nr_row, nr_col) + patch_shape)
     pred_map = np.transpose(pred_map, axes)
+
+    # print("pred_map2", pred_map.shape)
+
     pred_map = np.reshape(
         pred_map, (patch_shape[0] * nr_row, patch_shape[1] * nr_col, ch)
     )
+
     # crop back to original shape
     pred_map = np.squeeze(pred_map[: src_shape[0], : src_shape[1]])
 
@@ -140,7 +157,14 @@ def _post_process_patches(
     overlaid_img = visualize_instances_dict(
         src_image.copy(), inst_info_dict, **overlay_kwargs
     )
-    return image_info["name"], pred_map, pred_inst, inst_info_dict, overlaid_img
+    return (
+        image_info["name"],
+        pred_map,
+        pred_inst,
+        inst_info_dict,
+        overlaid_img,
+        pred_map1,
+    )
 
 
 class InferManager(base.InferManager):
@@ -159,36 +183,45 @@ def process_file_list(self, run_args):
         patterning = lambda x: re.sub("([\[\]])", "[\\1]", x)
         file_path_list = glob.glob(patterning("%s/*" % self.input_dir))
         file_path_list.sort()  # ensure same order
-        assert len(file_path_list) > 0, 'Not Detected Any Files From Path'
-        
-        rm_n_mkdir(self.output_dir + '/json/')
-        rm_n_mkdir(self.output_dir + '/mat/')
-        rm_n_mkdir(self.output_dir + '/overlay/')
+        assert len(file_path_list) > 0, "Not Detected Any Files From Path"
+
+        rm_n_mkdir(self.output_dir + "/json/")
+        rm_n_mkdir(self.output_dir + "/mat/")
+        rm_n_mkdir(self.output_dir + "/overlay/")
+        rm_n_mkdir(self.output_dir + "/encoder_features/")
+
         if self.save_qupath:
             rm_n_mkdir(self.output_dir + "/qupath/")
 
         def proc_callback(results):
             """Post processing callback.
-            
+
             Output format is implicit assumption, taken from `_post_process_patches`
 
             """
-            img_name, pred_map, pred_inst, inst_info_dict, overlaid_img = results
+            (
+                img_name,
+                pred_map,
+                pred_inst,
+                inst_info_dict,
+                overlaid_img,
+                extract_features,
+            ) = results
             nuc_val_list = list(inst_info_dict.values())
             # need singleton to make matlab happy
-            nuc_uid_list = np.array(list(inst_info_dict.keys()))[:,None]
-            nuc_type_list = np.array([v["type"] for v in nuc_val_list])[:,None]
+            nuc_uid_list = np.array(list(inst_info_dict.keys()))[:, None]
+            nuc_type_list = np.array([v["type"] for v in nuc_val_list])[:, None]
             nuc_coms_list = np.array([v["centroid"] for v in nuc_val_list])
 
             mat_dict = {
-                "inst_map" : pred_inst,
-                "inst_uid" : nuc_uid_list,
+                "inst_map": pred_inst,
+                "inst_uid": nuc_uid_list,
                 "inst_type": nuc_type_list,
-                "inst_centroid": nuc_coms_list
+                "inst_centroid": nuc_coms_list,
             }
-            if self.nr_types is None: # matlab does not have None type array
-                mat_dict.pop("inst_type", None) 
+            if self.nr_types is None:  # matlab does not have None type array
+                mat_dict.pop("inst_type", None)
 
             if self.save_raw_map:
                 mat_dict["raw_map"] = pred_map
 
@@ -196,8 +229,18 @@ def proc_callback(results):
             sio.savemat(save_path, mat_dict)
 
             save_path = "%s/overlay/%s.png" % (self.output_dir, img_name)
+
+            # print("pred_map2", pred_map.shape)
+
             cv2.imwrite(save_path, cv2.cvtColor(overlaid_img, cv2.COLOR_RGB2BGR))
 
+            save_path = "%s/encoder_features/%s.npy" % (self.output_dir, img_name)
+            print("extract_features", extract_features.shape)
+            with open(save_path, "wb") as f:
+                np.save(f, extract_features)
+            # pk.dump(extract_features, f)
+            # torch.save(extract_features,save_path)
+
             if self.save_qupath:
                 nuc_val_list = list(inst_info_dict.values())
                 nuc_type_list = np.array([v["type"] for v in nuc_val_list])
@@ -303,17 +346,25 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
         )
 
         accumulated_patch_output = []
+        accumulated_np = []
        for batch_idx, batch_data in enumerate(dataloader):
             sample_data_list, sample_info_list = batch_data
-            sample_output_list = self.run_step(sample_data_list)
+            sample_output_list, encoder_features = self.run_step(sample_data_list)
+
+            sample_info_list_features = sample_info_list.numpy()
             sample_info_list = sample_info_list.numpy()
 
             curr_batch_size = sample_output_list.shape[0]
             sample_output_list = np.split(
                 sample_output_list, curr_batch_size, axis=0
             )
+            encoder_features = np.split(encoder_features, curr_batch_size, axis=0)
             sample_info_list = np.split(sample_info_list, curr_batch_size, axis=0)
             sample_output_list = list(zip(sample_info_list, sample_output_list))
             accumulated_patch_output.extend(sample_output_list)
+
+            sample_info_list_features = np.split(sample_info_list_features, curr_batch_size, axis=0)
+            sample_info_list_features = list(zip(sample_info_list_features, encoder_features))
+            accumulated_np.extend(sample_info_list_features)
             pbar.update()
         pbar.close()
@@ -324,6 +375,12 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
             file_ouput_data, accumulated_patch_output = detach_items_of_uid(
                 accumulated_patch_output, file_idx, image_info[1]
             )
+            file_ouput_data1, accumulated_np = detach_items_of_uid(
+                accumulated_np, file_idx, image_info[1]
+            )
+            # print("file_ouput_data", len(file_ouput_data))
+
+            # print("file_ouput_data1", len(file_ouput_data1))
 
             # * detach this into func and multiproc dispatch it
             src_pos = image_info[2]  # src top left corner within padded image
@@ -350,10 +407,12 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
                 "type_colour": self.type_info_dict,
                 "line_thickness": 2,
             }
+
             func_args = (
                 self.post_proc_func,
                 post_proc_kwargs,
                 file_ouput_data,
+                file_ouput_data1,
                 file_info,
                 overlay_kwargs,
             )
@@ -373,16 +432,15 @@ def detach_items_of_uid(items_list, uid, nr_expected_items):
             for future in as_completed(future_list):
                 # TODO: way to retrieve which file crashed ?
                 # ! silent crash, cancel all and raise error
-                if future.exception() is not None:
-                    log_info("Silent Crash")
-                    # ! cancel somehow leads to cascade error later
-                    # ! so just poll it then crash once all future
-                    # ! acquired for now
-                    # for future in future_list:
-                    #     future.cancel()
-                    # break
-                else:
-                    file_path = proc_callback(future.result())
-                    log_info("Done Assembling %s" % file_path)
+                # if future.exception() is not None:
+                #     log_info("Silent Crash")
+                #     # ! cancel somehow leads to cascade error later
+                #     # ! so just poll it then crash once all future
+                #     # ! acquired for now
+                #     # for future in future_list:
+                #     #     future.cancel()
+                #     # break
+                # else:
+                file_path = proc_callback(future.result())
+                log_info("Done Assembling %s" % file_path)
         return
-
diff --git a/models/hovernet/net_desc.py b/models/hovernet/net_desc.py
index 745f201b..57ef35a6 100755
--- a/models/hovernet/net_desc.py
+++ b/models/hovernet/net_desc.py
@@ -129,6 +129,8 @@ def forward(self, imgs):
         d[1] = crop_op(d[1], [36, 36])
 
         out_dict = OrderedDict()
+        out_dict["encoder_features"] = d3
+
         for branch_name, branch_desc in self.decoder.items():
             u3 = self.upsample2x(d[-1]) + d[-2]
             u3 = branch_desc[0](u3)
diff --git a/models/hovernet/run_desc.py b/models/hovernet/run_desc.py
index 026873c3..9624c7c0 100755
--- a/models/hovernet/run_desc.py
+++ b/models/hovernet/run_desc.py
@@ -191,10 +191,12 @@ def infer_step(batch_data, model):
         type_map = torch.argmax(type_map, dim=-1, keepdim=True)
         type_map = type_map.type(torch.float32)
         pred_dict["tp"] = type_map
+    pred_output1 = pred_dict['encoder_features']
+    del pred_dict['encoder_features']
     pred_output = torch.cat(list(pred_dict.values()), -1)
-
+    # print("pred_output", pred_dict)
     # * It's up to the user to define the protocol to process the raw output per step!
-    return pred_output.cpu().numpy()
+    return pred_output.cpu().numpy(), pred_output1.cpu().numpy()
 
 
 ####
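Note on consuming the new output (not part of the patch): `proc_callback` now writes one `encoder_features/<image_name>.npy` file per image via `np.save`. The sketch below shows one way to read the features back; `output_dir` is a placeholder for the `--output_dir` used at inference time, and the channels-last `(num_patches, H, W, C)` layout is an assumption based on `np.concatenate(patch_data1, axis=0)` stacking one feature block per patch along axis 0.

    import glob
    import os

    import numpy as np

    output_dir = "output"  # placeholder: the --output_dir used for the run

    for path in sorted(glob.glob(os.path.join(output_dir, "encoder_features", "*.npy"))):
        img_name = os.path.splitext(os.path.basename(path))[0]
        features = np.load(path)  # stacked encoder features for one image
        # Average over the spatial axes to get one descriptor per patch
        # (assumes channels-last layout; adjust the axes if your build differs).
        per_patch = features.mean(axis=(1, 2))
        print(img_name, features.shape, "->", per_patch.shape)

Since `infer_step` now returns a `(pred_output, pred_output1)` tuple instead of a single array, any other caller of `run_step` (for example the WSI path in `infer/wsi.py`, if it routes through the same `infer_step`) must be updated to unpack two values, as `process_file_list` does above.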