From c903a5ff352b1dd6e09215a25909f1fb69856ffb Mon Sep 17 00:00:00 2001 From: Tanya Date: Mon, 2 Dec 2024 12:21:40 +0530 Subject: [PATCH 1/5] encoder features tile.py --- infer/tile.py | 122 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 90 insertions(+), 32 deletions(-) diff --git a/infer/tile.py b/infer/tile.py index 3fa60dbe..f66bbd9e 100755 --- a/infer/tile.py +++ b/infer/tile.py @@ -1,6 +1,7 @@ import logging import multiprocessing from multiprocessing import Lock, Pool +import pickle as pk multiprocessing.set_start_method("spawn", True) # ! must be at top for VScode debugging import argparse @@ -45,13 +46,13 @@ #### def _prepare_patching(img, window_size, mask_size, return_src_top_corner=False): """Prepare patch information for tile processing. - + Args: img: original input image window_size: input patch size mask_size: output patch size return_src_top_corner: whether to return coordiante information for top left corner of img - + """ win_size = window_size @@ -96,10 +97,15 @@ def get_last_steps(length, msk_size, step_size): #### def _post_process_patches( - post_proc_func, post_proc_kwargs, patch_info, image_info, overlay_kwargs, + post_proc_func, + post_proc_kwargs, + patch_info, + patch_info1, + image_info, + overlay_kwargs, ): """Apply post processing to patches. - + Args: post_proc_func: post processing function to use post_proc_kwargs: keyword arguments used in post processing function @@ -112,6 +118,9 @@ def _post_process_patches( patch_info = sorted(patch_info, key=lambda x: [x[0][0], x[0][1]]) patch_info, patch_data = zip(*patch_info) + patch_info1 = sorted(patch_info1, key=lambda x: [x[0][0], x[0][1]]) + patch_info1, patch_data1 = zip(*patch_info1) + src_shape = image_info["src_shape"] src_image = image_info["src_image"] @@ -121,12 +130,20 @@ def _post_process_patches( nr_row = max([x[2] for x in patch_info]) + 1 nr_col = max([x[3] for x in patch_info]) + 1 + pred_map = np.concatenate(patch_data, axis=0) + pred_map1 = np.concatenate(patch_data1, axis=0) + # print("pred_map0", pred_map.shape) + pred_map = np.reshape(pred_map, (nr_row, nr_col) + patch_shape) + pred_map = np.transpose(pred_map, axes) + # print("pred_map2", pred_map.shape) + pred_map = np.reshape( pred_map, (patch_shape[0] * nr_row, patch_shape[1] * nr_col, ch) ) + # crop back to original shape pred_map = np.squeeze(pred_map[: src_shape[0], : src_shape[1]]) @@ -140,7 +157,14 @@ def _post_process_patches( src_image.copy(), inst_info_dict, **overlay_kwargs ) - return image_info["name"], pred_map, pred_inst, inst_info_dict, overlaid_img + return ( + image_info["name"], + pred_map, + pred_inst, + inst_info_dict, + overlaid_img, + pred_map1, + ) class InferManager(base.InferManager): @@ -159,36 +183,45 @@ def process_file_list(self, run_args): patterning = lambda x: re.sub("([\[\]])", "[\\1]", x) file_path_list = glob.glob(patterning("%s/*" % self.input_dir)) file_path_list.sort() # ensure same order - assert len(file_path_list) > 0, 'Not Detected Any Files From Path' - - rm_n_mkdir(self.output_dir + '/json/') - rm_n_mkdir(self.output_dir + '/mat/') - rm_n_mkdir(self.output_dir + '/overlay/') + assert len(file_path_list) > 0, "Not Detected Any Files From Path" + + rm_n_mkdir(self.output_dir + "/json/") + rm_n_mkdir(self.output_dir + "/mat/") + rm_n_mkdir(self.output_dir + "/overlay/") + rm_n_mkdir(self.output_dir + "/encoder_features/") + if self.save_qupath: rm_n_mkdir(self.output_dir + "/qupath/") def proc_callback(results): """Post processing callback. 
- + Output format is implicit assumption, taken from `_post_process_patches` """ - img_name, pred_map, pred_inst, inst_info_dict, overlaid_img = results + ( + img_name, + pred_map, + pred_inst, + inst_info_dict, + overlaid_img, + extract_features, + ) = results nuc_val_list = list(inst_info_dict.values()) # need singleton to make matlab happy - nuc_uid_list = np.array(list(inst_info_dict.keys()))[:,None] - nuc_type_list = np.array([v["type"] for v in nuc_val_list])[:,None] + nuc_uid_list = np.array(list(inst_info_dict.keys()))[:, None] + nuc_type_list = np.array([v["type"] for v in nuc_val_list])[:, None] nuc_coms_list = np.array([v["centroid"] for v in nuc_val_list]) mat_dict = { - "inst_map" : pred_inst, - "inst_uid" : nuc_uid_list, + "inst_map": pred_inst, + "inst_uid": nuc_uid_list, "inst_type": nuc_type_list, - "inst_centroid": nuc_coms_list + "inst_centroid": nuc_coms_list, } - if self.nr_types is None: # matlab does not have None type array - mat_dict.pop("inst_type", None) + if self.nr_types is None: # matlab does not have None type array + mat_dict.pop("inst_type", None) if self.save_raw_map: mat_dict["raw_map"] = pred_map @@ -196,8 +229,18 @@ def proc_callback(results): sio.savemat(save_path, mat_dict) save_path = "%s/overlay/%s.png" % (self.output_dir, img_name) + + # print("pred_map2", pred_map.shape) + cv2.imwrite(save_path, cv2.cvtColor(overlaid_img, cv2.COLOR_RGB2BGR)) + save_path = "%s/encoder_features/%s.npy" % (self.output_dir, img_name) + print("extract_features", extract_features.shape) + with open(save_path, "wb") as f: + np.save(f, extract_features) + # pk.dump(extract_features, f) + # torch.save(extract_features,save_path) + if self.save_qupath: nuc_val_list = list(inst_info_dict.values()) nuc_type_list = np.array([v["type"] for v in nuc_val_list]) @@ -303,17 +346,25 @@ def detach_items_of_uid(items_list, uid, nr_expected_items): ) accumulated_patch_output = [] + accumulated_np = [] for batch_idx, batch_data in enumerate(dataloader): sample_data_list, sample_info_list = batch_data - sample_output_list = self.run_step(sample_data_list) + sample_output_list, encoder_features = self.run_step(sample_data_list) + + sample_info_list_features = sample_info_list.numpy() sample_info_list = sample_info_list.numpy() curr_batch_size = sample_output_list.shape[0] sample_output_list = np.split( sample_output_list, curr_batch_size, axis=0 ) + encoder_features = np.split(encoder_features, curr_batch_size, axis=0) sample_info_list = np.split(sample_info_list, curr_batch_size, axis=0) sample_output_list = list(zip(sample_info_list, sample_output_list)) accumulated_patch_output.extend(sample_output_list) + + sample_info_list_features = np.split(sample_info_list_features, curr_batch_size, axis=0) + sample_info_list_features = list(zip(sample_info_list_features, encoder_features)) + accumulated_np.extend(sample_info_list_features) pbar.update() pbar.close() @@ -324,6 +375,12 @@ def detach_items_of_uid(items_list, uid, nr_expected_items): file_ouput_data, accumulated_patch_output = detach_items_of_uid( accumulated_patch_output, file_idx, image_info[1] ) + file_ouput_data1, accumulated_np = detach_items_of_uid( + accumulated_np, file_idx, image_info[1] + ) + # print("file_ouput_data", len(file_ouput_data)) + + # print("file_ouput_data1",len(file_ouput_data1)) # * detach this into func and multiproc dispatch it src_pos = image_info[2] # src top left corner within padded image @@ -350,10 +407,12 @@ def detach_items_of_uid(items_list, uid, nr_expected_items): "type_colour": 
self.type_info_dict, "line_thickness": 2, } + func_args = ( self.post_proc_func, post_proc_kwargs, file_ouput_data, + file_ouput_data1, file_info, overlay_kwargs, ) @@ -373,16 +432,15 @@ def detach_items_of_uid(items_list, uid, nr_expected_items): for future in as_completed(future_list): # TODO: way to retrieve which file crashed ? # ! silent crash, cancel all and raise error - if future.exception() is not None: - log_info("Silent Crash") - # ! cancel somehow leads to cascade error later - # ! so just poll it then crash once all future - # ! acquired for now - # for future in future_list: - # future.cancel() - # break - else: - file_path = proc_callback(future.result()) - log_info("Done Assembling %s" % file_path) + # if future.exception() is not None: + # log_info("Silent Crash") + # # ! cancel somehow leads to cascade error later + # # ! so just poll it then crash once all future + # # ! acquired for now + # # for future in future_list: + # # future.cancel() + # # break + # else: + file_path = proc_callback(future.result()) + log_info("Done Assembling %s" % file_path) return - From f5e2465ad0da972e1314b39eee35101ee2cdbbef Mon Sep 17 00:00:00 2001 From: Tanya Date: Mon, 2 Dec 2024 12:22:13 +0530 Subject: [PATCH 2/5] encoder features run_desc.py --- models/hovernet/run_desc.py | 479 +++++++++++------------------------- 1 file changed, 145 insertions(+), 334 deletions(-) diff --git a/models/hovernet/run_desc.py b/models/hovernet/run_desc.py index 026873c3..57ef35a6 100755 --- a/models/hovernet/run_desc.py +++ b/models/hovernet/run_desc.py @@ -1,344 +1,155 @@ +import math +from collections import OrderedDict + import numpy as np -import matplotlib.pyplot as plt import torch +import torch.nn as nn import torch.nn.functional as F -from misc.utils import center_pad_to_shape, cropping_center -from .utils import crop_to_shape, dice_loss, mse_loss, msge_loss, xentropy_loss - -from collections import OrderedDict - -#### -def train_step(batch_data, run_info): - # TODO: synchronize the attach protocol - run_info, state_info = run_info - loss_func_dict = { - "bce": xentropy_loss, - "dice": dice_loss, - "mse": mse_loss, - "msge": msge_loss, - } - # use 'ema' to add for EMA calculation, must be scalar! 
- result_dict = {"EMA": {}} - track_value = lambda name, value: result_dict["EMA"].update({name: value}) - - #### - model = run_info["net"]["desc"] - optimizer = run_info["net"]["optimizer"] - - #### - imgs = batch_data["img"] - true_np = batch_data["np_map"] - true_hv = batch_data["hv_map"] - - imgs = imgs.to("cuda").type(torch.float32) # to NCHW - imgs = imgs.permute(0, 3, 1, 2).contiguous() - - # HWC - true_np = true_np.to("cuda").type(torch.int64) - true_hv = true_hv.to("cuda").type(torch.float32) - - true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) - true_dict = { - "np": true_np_onehot, - "hv": true_hv, - } - - if model.module.nr_types is not None: - true_tp = batch_data["tp_map"] - true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64) - true_tp_onehot = F.one_hot(true_tp, num_classes=model.module.nr_types) - true_tp_onehot = true_tp_onehot.type(torch.float32) - true_dict["tp"] = true_tp_onehot - - #### - model.train() - model.zero_grad() # not rnn so not accumulate - - pred_dict = model(imgs) - pred_dict = OrderedDict( - [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] - ) - pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1) - if model.module.nr_types is not None: - pred_dict["tp"] = F.softmax(pred_dict["tp"], dim=-1) - - #### - loss = 0 - loss_opts = run_info["net"]["extra_info"]["loss"] - for branch_name in pred_dict.keys(): - for loss_name, loss_weight in loss_opts[branch_name].items(): - loss_func = loss_func_dict[loss_name] - loss_args = [true_dict[branch_name], pred_dict[branch_name]] - if loss_name == "msge": - loss_args.append(true_np_onehot[..., 1]) - term_loss = loss_func(*loss_args) - track_value("loss_%s_%s" % (branch_name, loss_name), term_loss.cpu().item()) - loss += loss_weight * term_loss - - track_value("overall_loss", loss.cpu().item()) - # * gradient update - - # torch.set_printoptions(precision=10) - loss.backward() - optimizer.step() - #### - - # pick 2 random sample from the batch for visualization - sample_indices = torch.randint(0, true_np.shape[0], (2,)) - - imgs = (imgs[sample_indices]).byte() # to uint8 - imgs = imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy() - - pred_dict["np"] = pred_dict["np"][..., 1] # return pos only - pred_dict = { - k: v[sample_indices].detach().cpu().numpy() for k, v in pred_dict.items() - } - - true_dict["np"] = true_np - true_dict = { - k: v[sample_indices].detach().cpu().numpy() for k, v in true_dict.items() - } - - # * Its up to user to define the protocol to process the raw output per step! 
- result_dict["raw"] = { # protocol for contents exchange within `raw` - "img": imgs, - "np": (true_dict["np"], pred_dict["np"]), - "hv": (true_dict["hv"], pred_dict["hv"]), - } - return result_dict - +from .net_utils import (DenseBlock, Net, ResidualBlock, TFSamepaddingLayer, + UpSample2x) +from .utils import crop_op, crop_to_shape #### -def valid_step(batch_data, run_info): - run_info, state_info = run_info - #### - model = run_info["net"]["desc"] - model.eval() # infer mode - - #### - imgs = batch_data["img"] - true_np = batch_data["np_map"] - true_hv = batch_data["hv_map"] - - imgs_gpu = imgs.to("cuda").type(torch.float32) # to NCHW - imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous() - - # HWC - true_np = torch.squeeze(true_np).type(torch.int64) - true_hv = torch.squeeze(true_hv).type(torch.float32) - - true_dict = { - "np": true_np, - "hv": true_hv, - } - - if model.module.nr_types is not None: - true_tp = batch_data["tp_map"] - true_tp = torch.squeeze(true_tp).type(torch.int64) - true_dict["tp"] = true_tp - - # -------------------------------------------------------------- - with torch.no_grad(): # dont compute gradient - pred_dict = model(imgs_gpu) - pred_dict = OrderedDict( - [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] - ) - pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1] - if model.module.nr_types is not None: - type_map = F.softmax(pred_dict["tp"], dim=-1) - type_map = torch.argmax(type_map, dim=-1, keepdim=False) - type_map = type_map.type(torch.float32) - pred_dict["tp"] = type_map - - # * Its up to user to define the protocol to process the raw output per step! - result_dict = { # protocol for contents exchange within `raw` - "raw": { - "imgs": imgs.numpy(), - "true_np": true_dict["np"].numpy(), - "true_hv": true_dict["hv"].numpy(), - "prob_np": pred_dict["np"].cpu().numpy(), - "pred_hv": pred_dict["hv"].cpu().numpy(), - } - } - if model.module.nr_types is not None: - result_dict["raw"]["true_tp"] = true_dict["tp"].numpy() - result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy() - return result_dict +class HoVerNet(Net): + """Initialise HoVer-Net.""" + + def __init__(self, input_ch=3, nr_types=None, freeze=False, mode='original'): + super().__init__() + self.mode = mode + self.freeze = freeze + self.nr_types = nr_types + self.output_ch = 3 if nr_types is None else 4 + + assert mode == 'original' or mode == 'fast', \ + 'Unknown mode `%s` for HoVerNet %s. Only support `original` or `fast`.' 
% mode + + module_list = [ + ("/", nn.Conv2d(input_ch, 64, 7, stride=1, padding=0, bias=False)), + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + if mode == 'fast': # prepend the padding for `fast` mode + module_list = [("pad", TFSamepaddingLayer(ksize=7, stride=1))] + module_list + + self.conv0 = nn.Sequential(OrderedDict(module_list)) + self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1) + self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2) + self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2) + self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2) + + self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False) + + def create_decoder_branch(out_ch=2, ksize=5): + module_list = [ + ("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)), + ("convf", nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),), + ] + u3 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)), + ("convf", nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),), + ] + u2 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)), + ("conva", nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),), + ] + u1 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ("conv", nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),), + ] + u0 = nn.Sequential(OrderedDict(module_list)) + + decoder = nn.Sequential( + OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0),]) + ) + return decoder + + ksize = 5 if mode == 'original' else 3 + if nr_types is None: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("np", create_decoder_branch(ksize=ksize,out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize,out_ch=2)), + ] + ) + ) + else: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("tp", create_decoder_branch(ksize=ksize, out_ch=nr_types)), + ("np", create_decoder_branch(ksize=ksize, out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize, out_ch=2)), + ] + ) + ) + + self.upsample2x = UpSample2x() + # TODO: pytorch still require the channel eventhough its ignored + self.weights_init() + + def forward(self, imgs): + + imgs = imgs / 255.0 # to 0-1 range to match XY + + if self.training: + d0 = self.conv0(imgs) + d0 = self.d0(d0, self.freeze) + with torch.set_grad_enabled(not self.freeze): + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + else: + d0 = self.conv0(imgs) + d0 = self.d0(d0) + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + + # TODO: switch to `crop_to_shape` ? 
+ if self.mode == 'original': + d[0] = crop_op(d[0], [184, 184]) + d[1] = crop_op(d[1], [72, 72]) + else: + d[0] = crop_op(d[0], [92, 92]) + d[1] = crop_op(d[1], [36, 36]) + + out_dict = OrderedDict() + out_dict["encoder_features"] = d3 + + for branch_name, branch_desc in self.decoder.items(): + u3 = self.upsample2x(d[-1]) + d[-2] + u3 = branch_desc[0](u3) + + u2 = self.upsample2x(u3) + d[-3] + u2 = branch_desc[1](u2) + + u1 = self.upsample2x(u2) + d[-4] + u1 = branch_desc[2](u1) + + u0 = branch_desc[3](u1) + out_dict[branch_name] = u0 + + return out_dict #### -def infer_step(batch_data, model): - - #### - patch_imgs = batch_data - - patch_imgs_gpu = patch_imgs.to("cuda").type(torch.float32) # to NCHW - patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() - - #### - model.eval() # infer mode - - # -------------------------------------------------------------- - with torch.no_grad(): # dont compute gradient - pred_dict = model(patch_imgs_gpu) - pred_dict = OrderedDict( - [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] - ) - pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:] - if "tp" in pred_dict: - type_map = F.softmax(pred_dict["tp"], dim=-1) - type_map = torch.argmax(type_map, dim=-1, keepdim=True) - type_map = type_map.type(torch.float32) - pred_dict["tp"] = type_map - pred_output = torch.cat(list(pred_dict.values()), -1) - - # * Its up to user to define the protocol to process the raw output per step! - return pred_output.cpu().numpy() - - -#### -def viz_step_output(raw_data, nr_types=None): - """ - `raw_data` will be implicitly provided in the similar format as the - return dict from train/valid step, but may have been accumulated across N running step - """ - - imgs = raw_data["img"] - true_np, pred_np = raw_data["np"] - true_hv, pred_hv = raw_data["hv"] - if nr_types is not None: - true_tp, pred_tp = raw_data["tp"] - - aligned_shape = [list(imgs.shape), list(true_np.shape), list(pred_np.shape)] - aligned_shape = np.min(np.array(aligned_shape), axis=0)[1:3] - - cmap = plt.get_cmap("jet") - - def colorize(ch, vmin, vmax): - """ - Will clamp value value outside the provided range to vmax and vmin - """ - ch = np.squeeze(ch.astype("float32")) - ch[ch > vmax] = vmax # clamp value - ch[ch < vmin] = vmin - ch = (ch - vmin) / (vmax - vmin + 1.0e-16) - # take RGB from RGBA heat map - ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") - # ch_cmap = center_pad_to_shape(ch_cmap, aligned_shape) - return ch_cmap - - viz_list = [] - for idx in range(imgs.shape[0]): - # img = center_pad_to_shape(imgs[idx], aligned_shape) - img = cropping_center(imgs[idx], aligned_shape) - - true_viz_list = [img] - # cmap may randomly fails if of other types - true_viz_list.append(colorize(true_np[idx], 0, 1)) - true_viz_list.append(colorize(true_hv[idx][..., 0], -1, 1)) - true_viz_list.append(colorize(true_hv[idx][..., 1], -1, 1)) - if nr_types is not None: # TODO: a way to pass through external info - true_viz_list.append(colorize(true_tp[idx], 0, nr_types)) - true_viz_list = np.concatenate(true_viz_list, axis=1) - - pred_viz_list = [img] - # cmap may randomly fails if of other types - pred_viz_list.append(colorize(pred_np[idx], 0, 1)) - pred_viz_list.append(colorize(pred_hv[idx][..., 0], -1, 1)) - pred_viz_list.append(colorize(pred_hv[idx][..., 1], -1, 1)) - if nr_types is not None: - pred_viz_list.append(colorize(pred_tp[idx], 0, nr_types)) - pred_viz_list = np.concatenate(pred_viz_list, axis=1) - - viz_list.append(np.concatenate([true_viz_list, pred_viz_list], 
axis=0)) - viz_list = np.concatenate(viz_list, axis=0) - return viz_list - - -#### -from itertools import chain - - -def proc_valid_step_output(raw_data, nr_types=None): - # TODO: add auto populate from main state track list - track_dict = {"scalar": {}, "image": {}} - - def track_value(name, value, vtype): - return track_dict[vtype].update({name: value}) - - def _dice_info(true, pred, label): - true = np.array(true == label, np.int32) - pred = np.array(pred == label, np.int32) - inter = (pred * true).sum() - total = (pred + true).sum() - return inter, total - - over_inter = 0 - over_total = 0 - over_correct = 0 - prob_np = raw_data["prob_np"] - true_np = raw_data["true_np"] - for idx in range(len(raw_data["true_np"])): - patch_prob_np = prob_np[idx] - patch_true_np = true_np[idx] - patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32) - inter, total = _dice_info(patch_true_np, patch_pred_np, 1) - correct = (patch_pred_np == patch_true_np).sum() - over_inter += inter - over_total += total - over_correct += correct - nr_pixels = len(true_np) * np.size(true_np[0]) - acc_np = over_correct / nr_pixels - dice_np = 2 * over_inter / (over_total + 1.0e-8) - track_value("np_acc", acc_np, "scalar") - track_value("np_dice", dice_np, "scalar") - - # * TP statistic - if nr_types is not None: - pred_tp = raw_data["pred_tp"] - true_tp = raw_data["true_tp"] - for type_id in range(0, nr_types): - over_inter = 0 - over_total = 0 - for idx in range(len(raw_data["true_np"])): - patch_pred_tp = pred_tp[idx] - patch_true_tp = true_tp[idx] - inter, total = _dice_info(patch_true_tp, patch_pred_tp, type_id) - over_inter += inter - over_total += total - dice_tp = 2 * over_inter / (over_total + 1.0e-8) - track_value("tp_dice_%d" % type_id, dice_tp, "scalar") - - # * HV regression statistic - pred_hv = raw_data["pred_hv"] - true_hv = raw_data["true_hv"] - - over_squared_error = 0 - for idx in range(len(raw_data["true_np"])): - patch_pred_hv = pred_hv[idx] - patch_true_hv = true_hv[idx] - squared_error = patch_pred_hv - patch_true_hv - squared_error = squared_error * squared_error - over_squared_error += squared_error.sum() - mse = over_squared_error / nr_pixels - track_value("hv_mse", mse, "scalar") - - # * - imgs = raw_data["imgs"] - selected_idx = np.random.randint(0, len(imgs), size=(8,)).tolist() - imgs = np.array([imgs[idx] for idx in selected_idx]) - true_np = np.array([true_np[idx] for idx in selected_idx]) - true_hv = np.array([true_hv[idx] for idx in selected_idx]) - prob_np = np.array([prob_np[idx] for idx in selected_idx]) - pred_hv = np.array([pred_hv[idx] for idx in selected_idx]) - viz_raw_data = {"img": imgs, "np": (true_np, prob_np), "hv": (true_hv, pred_hv)} - - if nr_types is not None: - true_tp = np.array([true_tp[idx] for idx in selected_idx]) - pred_tp = np.array([pred_tp[idx] for idx in selected_idx]) - viz_raw_data["tp"] = (true_tp, pred_tp) - viz_fig = viz_step_output(viz_raw_data, nr_types) - track_dict["image"]["output"] = viz_fig +def create_model(mode=None, **kwargs): + if mode not in ['original', 'fast']: + assert "Unknown Model Mode %s" % mode + return HoVerNet(mode=mode, **kwargs) - return track_dict From eb9017a039ab3d379ef37433d777d5a8ffb3c241 Mon Sep 17 00:00:00 2001 From: Tanya Date: Mon, 2 Dec 2024 12:39:12 +0530 Subject: [PATCH 3/5] encoder_features net_desc.py --- models/hovernet/net_desc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/models/hovernet/net_desc.py b/models/hovernet/net_desc.py index 745f201b..57ef35a6 100755 --- a/models/hovernet/net_desc.py +++ 
b/models/hovernet/net_desc.py @@ -129,6 +129,8 @@ def forward(self, imgs): d[1] = crop_op(d[1], [36, 36]) out_dict = OrderedDict() + out_dict["encoder_features"] = d3 + for branch_name, branch_desc in self.decoder.items(): u3 = self.upsample2x(d[-1]) + d[-2] u3 = branch_desc[0](u3) From a914a253663325a4e8593754dd09d4d267c60076 Mon Sep 17 00:00:00 2001 From: Tanya Date: Mon, 2 Dec 2024 12:40:03 +0530 Subject: [PATCH 4/5] Update run_desc.py --- models/hovernet/run_desc.py | 479 +++++++++++++++++++++++++----------- 1 file changed, 334 insertions(+), 145 deletions(-) diff --git a/models/hovernet/run_desc.py b/models/hovernet/run_desc.py index 57ef35a6..026873c3 100755 --- a/models/hovernet/run_desc.py +++ b/models/hovernet/run_desc.py @@ -1,155 +1,344 @@ -import math -from collections import OrderedDict - import numpy as np +import matplotlib.pyplot as plt import torch -import torch.nn as nn import torch.nn.functional as F -from .net_utils import (DenseBlock, Net, ResidualBlock, TFSamepaddingLayer, - UpSample2x) -from .utils import crop_op, crop_to_shape +from misc.utils import center_pad_to_shape, cropping_center +from .utils import crop_to_shape, dice_loss, mse_loss, msge_loss, xentropy_loss + +from collections import OrderedDict + +#### +def train_step(batch_data, run_info): + # TODO: synchronize the attach protocol + run_info, state_info = run_info + loss_func_dict = { + "bce": xentropy_loss, + "dice": dice_loss, + "mse": mse_loss, + "msge": msge_loss, + } + # use 'ema' to add for EMA calculation, must be scalar! + result_dict = {"EMA": {}} + track_value = lambda name, value: result_dict["EMA"].update({name: value}) + + #### + model = run_info["net"]["desc"] + optimizer = run_info["net"]["optimizer"] + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs = imgs.to("cuda").type(torch.float32) # to NCHW + imgs = imgs.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = true_np.to("cuda").type(torch.int64) + true_hv = true_hv.to("cuda").type(torch.float32) + + true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) + true_dict = { + "np": true_np_onehot, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64) + true_tp_onehot = F.one_hot(true_tp, num_classes=model.module.nr_types) + true_tp_onehot = true_tp_onehot.type(torch.float32) + true_dict["tp"] = true_tp_onehot + + #### + model.train() + model.zero_grad() # not rnn so not accumulate + + pred_dict = model(imgs) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1) + if model.module.nr_types is not None: + pred_dict["tp"] = F.softmax(pred_dict["tp"], dim=-1) + + #### + loss = 0 + loss_opts = run_info["net"]["extra_info"]["loss"] + for branch_name in pred_dict.keys(): + for loss_name, loss_weight in loss_opts[branch_name].items(): + loss_func = loss_func_dict[loss_name] + loss_args = [true_dict[branch_name], pred_dict[branch_name]] + if loss_name == "msge": + loss_args.append(true_np_onehot[..., 1]) + term_loss = loss_func(*loss_args) + track_value("loss_%s_%s" % (branch_name, loss_name), term_loss.cpu().item()) + loss += loss_weight * term_loss + + track_value("overall_loss", loss.cpu().item()) + # * gradient update + + # torch.set_printoptions(precision=10) + loss.backward() + optimizer.step() + #### + + # pick 2 random sample from the batch for 
visualization + sample_indices = torch.randint(0, true_np.shape[0], (2,)) + + imgs = (imgs[sample_indices]).byte() # to uint8 + imgs = imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy() + + pred_dict["np"] = pred_dict["np"][..., 1] # return pos only + pred_dict = { + k: v[sample_indices].detach().cpu().numpy() for k, v in pred_dict.items() + } + + true_dict["np"] = true_np + true_dict = { + k: v[sample_indices].detach().cpu().numpy() for k, v in true_dict.items() + } + + # * Its up to user to define the protocol to process the raw output per step! + result_dict["raw"] = { # protocol for contents exchange within `raw` + "img": imgs, + "np": (true_dict["np"], pred_dict["np"]), + "hv": (true_dict["hv"], pred_dict["hv"]), + } + return result_dict + #### -class HoVerNet(Net): - """Initialise HoVer-Net.""" - - def __init__(self, input_ch=3, nr_types=None, freeze=False, mode='original'): - super().__init__() - self.mode = mode - self.freeze = freeze - self.nr_types = nr_types - self.output_ch = 3 if nr_types is None else 4 - - assert mode == 'original' or mode == 'fast', \ - 'Unknown mode `%s` for HoVerNet %s. Only support `original` or `fast`.' % mode - - module_list = [ - ("/", nn.Conv2d(input_ch, 64, 7, stride=1, padding=0, bias=False)), - ("bn", nn.BatchNorm2d(64, eps=1e-5)), - ("relu", nn.ReLU(inplace=True)), - ] - if mode == 'fast': # prepend the padding for `fast` mode - module_list = [("pad", TFSamepaddingLayer(ksize=7, stride=1))] + module_list - - self.conv0 = nn.Sequential(OrderedDict(module_list)) - self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1) - self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2) - self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2) - self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2) - - self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False) - - def create_decoder_branch(out_ch=2, ksize=5): - module_list = [ - ("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)), - ("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)), - ("convf", nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),), - ] - u3 = nn.Sequential(OrderedDict(module_list)) - - module_list = [ - ("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)), - ("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)), - ("convf", nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),), - ] - u2 = nn.Sequential(OrderedDict(module_list)) - - module_list = [ - ("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)), - ("conva", nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),), - ] - u1 = nn.Sequential(OrderedDict(module_list)) - - module_list = [ - ("bn", nn.BatchNorm2d(64, eps=1e-5)), - ("relu", nn.ReLU(inplace=True)), - ("conv", nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),), - ] - u0 = nn.Sequential(OrderedDict(module_list)) - - decoder = nn.Sequential( - OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0),]) - ) - return decoder - - ksize = 5 if mode == 'original' else 3 - if nr_types is None: - self.decoder = nn.ModuleDict( - OrderedDict( - [ - ("np", create_decoder_branch(ksize=ksize,out_ch=2)), - ("hv", create_decoder_branch(ksize=ksize,out_ch=2)), - ] - ) - ) - else: - self.decoder = nn.ModuleDict( - OrderedDict( - [ - ("tp", create_decoder_branch(ksize=ksize, out_ch=nr_types)), - ("np", create_decoder_branch(ksize=ksize, out_ch=2)), - ("hv", create_decoder_branch(ksize=ksize, out_ch=2)), - ] - ) - ) - - self.upsample2x = 
UpSample2x() - # TODO: pytorch still require the channel eventhough its ignored - self.weights_init() - - def forward(self, imgs): - - imgs = imgs / 255.0 # to 0-1 range to match XY - - if self.training: - d0 = self.conv0(imgs) - d0 = self.d0(d0, self.freeze) - with torch.set_grad_enabled(not self.freeze): - d1 = self.d1(d0) - d2 = self.d2(d1) - d3 = self.d3(d2) - d3 = self.conv_bot(d3) - d = [d0, d1, d2, d3] - else: - d0 = self.conv0(imgs) - d0 = self.d0(d0) - d1 = self.d1(d0) - d2 = self.d2(d1) - d3 = self.d3(d2) - d3 = self.conv_bot(d3) - d = [d0, d1, d2, d3] - - # TODO: switch to `crop_to_shape` ? - if self.mode == 'original': - d[0] = crop_op(d[0], [184, 184]) - d[1] = crop_op(d[1], [72, 72]) - else: - d[0] = crop_op(d[0], [92, 92]) - d[1] = crop_op(d[1], [36, 36]) - - out_dict = OrderedDict() - out_dict["encoder_features"] = d3 - - for branch_name, branch_desc in self.decoder.items(): - u3 = self.upsample2x(d[-1]) + d[-2] - u3 = branch_desc[0](u3) - - u2 = self.upsample2x(u3) + d[-3] - u2 = branch_desc[1](u2) - - u1 = self.upsample2x(u2) + d[-4] - u1 = branch_desc[2](u1) - - u0 = branch_desc[3](u1) - out_dict[branch_name] = u0 - - return out_dict +def valid_step(batch_data, run_info): + run_info, state_info = run_info + #### + model = run_info["net"]["desc"] + model.eval() # infer mode + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs_gpu = imgs.to("cuda").type(torch.float32) # to NCHW + imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = torch.squeeze(true_np).type(torch.int64) + true_hv = torch.squeeze(true_hv).type(torch.float32) + + true_dict = { + "np": true_np, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).type(torch.int64) + true_dict["tp"] = true_tp + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1] + if model.module.nr_types is not None: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=False) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + + # * Its up to user to define the protocol to process the raw output per step! 
+ result_dict = { # protocol for contents exchange within `raw` + "raw": { + "imgs": imgs.numpy(), + "true_np": true_dict["np"].numpy(), + "true_hv": true_dict["hv"].numpy(), + "prob_np": pred_dict["np"].cpu().numpy(), + "pred_hv": pred_dict["hv"].cpu().numpy(), + } + } + if model.module.nr_types is not None: + result_dict["raw"]["true_tp"] = true_dict["tp"].numpy() + result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy() + return result_dict #### -def create_model(mode=None, **kwargs): - if mode not in ['original', 'fast']: - assert "Unknown Model Mode %s" % mode - return HoVerNet(mode=mode, **kwargs) +def infer_step(batch_data, model): + + #### + patch_imgs = batch_data + + patch_imgs_gpu = patch_imgs.to("cuda").type(torch.float32) # to NCHW + patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() + + #### + model.eval() # infer mode + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(patch_imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:] + if "tp" in pred_dict: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=True) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + pred_output = torch.cat(list(pred_dict.values()), -1) + + # * Its up to user to define the protocol to process the raw output per step! + return pred_output.cpu().numpy() + + +#### +def viz_step_output(raw_data, nr_types=None): + """ + `raw_data` will be implicitly provided in the similar format as the + return dict from train/valid step, but may have been accumulated across N running step + """ + + imgs = raw_data["img"] + true_np, pred_np = raw_data["np"] + true_hv, pred_hv = raw_data["hv"] + if nr_types is not None: + true_tp, pred_tp = raw_data["tp"] + + aligned_shape = [list(imgs.shape), list(true_np.shape), list(pred_np.shape)] + aligned_shape = np.min(np.array(aligned_shape), axis=0)[1:3] + + cmap = plt.get_cmap("jet") + + def colorize(ch, vmin, vmax): + """ + Will clamp value value outside the provided range to vmax and vmin + """ + ch = np.squeeze(ch.astype("float32")) + ch[ch > vmax] = vmax # clamp value + ch[ch < vmin] = vmin + ch = (ch - vmin) / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + # ch_cmap = center_pad_to_shape(ch_cmap, aligned_shape) + return ch_cmap + + viz_list = [] + for idx in range(imgs.shape[0]): + # img = center_pad_to_shape(imgs[idx], aligned_shape) + img = cropping_center(imgs[idx], aligned_shape) + + true_viz_list = [img] + # cmap may randomly fails if of other types + true_viz_list.append(colorize(true_np[idx], 0, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 0], -1, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 1], -1, 1)) + if nr_types is not None: # TODO: a way to pass through external info + true_viz_list.append(colorize(true_tp[idx], 0, nr_types)) + true_viz_list = np.concatenate(true_viz_list, axis=1) + + pred_viz_list = [img] + # cmap may randomly fails if of other types + pred_viz_list.append(colorize(pred_np[idx], 0, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 0], -1, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 1], -1, 1)) + if nr_types is not None: + pred_viz_list.append(colorize(pred_tp[idx], 0, nr_types)) + pred_viz_list = np.concatenate(pred_viz_list, axis=1) + + 
viz_list.append(np.concatenate([true_viz_list, pred_viz_list], axis=0)) + viz_list = np.concatenate(viz_list, axis=0) + return viz_list + + +#### +from itertools import chain + + +def proc_valid_step_output(raw_data, nr_types=None): + # TODO: add auto populate from main state track list + track_dict = {"scalar": {}, "image": {}} + + def track_value(name, value, vtype): + return track_dict[vtype].update({name: value}) + + def _dice_info(true, pred, label): + true = np.array(true == label, np.int32) + pred = np.array(pred == label, np.int32) + inter = (pred * true).sum() + total = (pred + true).sum() + return inter, total + + over_inter = 0 + over_total = 0 + over_correct = 0 + prob_np = raw_data["prob_np"] + true_np = raw_data["true_np"] + for idx in range(len(raw_data["true_np"])): + patch_prob_np = prob_np[idx] + patch_true_np = true_np[idx] + patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32) + inter, total = _dice_info(patch_true_np, patch_pred_np, 1) + correct = (patch_pred_np == patch_true_np).sum() + over_inter += inter + over_total += total + over_correct += correct + nr_pixels = len(true_np) * np.size(true_np[0]) + acc_np = over_correct / nr_pixels + dice_np = 2 * over_inter / (over_total + 1.0e-8) + track_value("np_acc", acc_np, "scalar") + track_value("np_dice", dice_np, "scalar") + + # * TP statistic + if nr_types is not None: + pred_tp = raw_data["pred_tp"] + true_tp = raw_data["true_tp"] + for type_id in range(0, nr_types): + over_inter = 0 + over_total = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_tp = pred_tp[idx] + patch_true_tp = true_tp[idx] + inter, total = _dice_info(patch_true_tp, patch_pred_tp, type_id) + over_inter += inter + over_total += total + dice_tp = 2 * over_inter / (over_total + 1.0e-8) + track_value("tp_dice_%d" % type_id, dice_tp, "scalar") + + # * HV regression statistic + pred_hv = raw_data["pred_hv"] + true_hv = raw_data["true_hv"] + + over_squared_error = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_hv = pred_hv[idx] + patch_true_hv = true_hv[idx] + squared_error = patch_pred_hv - patch_true_hv + squared_error = squared_error * squared_error + over_squared_error += squared_error.sum() + mse = over_squared_error / nr_pixels + track_value("hv_mse", mse, "scalar") + + # * + imgs = raw_data["imgs"] + selected_idx = np.random.randint(0, len(imgs), size=(8,)).tolist() + imgs = np.array([imgs[idx] for idx in selected_idx]) + true_np = np.array([true_np[idx] for idx in selected_idx]) + true_hv = np.array([true_hv[idx] for idx in selected_idx]) + prob_np = np.array([prob_np[idx] for idx in selected_idx]) + pred_hv = np.array([pred_hv[idx] for idx in selected_idx]) + viz_raw_data = {"img": imgs, "np": (true_np, prob_np), "hv": (true_hv, pred_hv)} + + if nr_types is not None: + true_tp = np.array([true_tp[idx] for idx in selected_idx]) + pred_tp = np.array([pred_tp[idx] for idx in selected_idx]) + viz_raw_data["tp"] = (true_tp, pred_tp) + viz_fig = viz_step_output(viz_raw_data, nr_types) + track_dict["image"]["output"] = viz_fig + return track_dict From a6c7ff0c805acb1ed681bfc0b9d4c8eb5a2b15f9 Mon Sep 17 00:00:00 2001 From: Tanya Date: Tue, 3 Dec 2024 11:43:19 +0530 Subject: [PATCH 5/5] encoder_features run_desc.py --- models/hovernet/run_desc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/models/hovernet/run_desc.py b/models/hovernet/run_desc.py index 026873c3..9624c7c0 100755 --- a/models/hovernet/run_desc.py +++ b/models/hovernet/run_desc.py @@ -191,10 +191,12 @@ def infer_step(batch_data, 
model):
     type_map = torch.argmax(type_map, dim=-1, keepdim=True)
     type_map = type_map.type(torch.float32)
     pred_dict["tp"] = type_map
+    pred_output1 = pred_dict['encoder_features']
+    del pred_dict['encoder_features']
     pred_output = torch.cat(list(pred_dict.values()), -1)
-
+    # encoder features were split off above so the cat only stacks the np/hv/tp branch outputs
     # * Its up to user to define the protocol to process the raw output per step!
-    return pred_output.cpu().numpy()
+    return pred_output.cpu().numpy(), pred_output1.cpu().numpy()
 
 
 ####
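
Usage note: with these patches applied, `infer_step` returns a `(branch_output, encoder_features)` tuple, and `process_file_list` writes one `<image_name>.npy` per input image under `<output_dir>/encoder_features/` — the stacked per-patch encoder feature maps (`pred_map1` in `_post_process_patches`, i.e. the `d3` bottleneck output saved via `np.save`). Below is a minimal sketch of reading those exports back, assuming inference has already been run and the NHWC layout produced by the permute in `infer_step`; the `output_dir` value and the mean-pooling step are illustrative, not part of the patches:

import glob
import os

import numpy as np

output_dir = "output"  # assumed: the output directory used at inference time

for path in sorted(glob.glob(os.path.join(output_dir, "encoder_features", "*.npy"))):
    img_name = os.path.splitext(os.path.basename(path))[0]
    feats = np.load(path)  # (nr_patches, h, w, ch): per-patch encoder feature maps
    # collapse the spatial grid to one descriptor per patch, e.g. for clustering
    pooled = feats.reshape(feats.shape[0], -1, feats.shape[-1]).mean(axis=1)
    print(img_name, feats.shape, "->", pooled.shape)

Pooling to one vector per patch is only one option; the raw 4-D maps can equally be fed to a downstream head that expects spatial features.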