From 4252d270ebf8e0f02ddf93f47421af71ad8a20c6 Mon Sep 17 00:00:00 2001
From: wangzijian <107230768+wangzijian1010@users.noreply.github.com>
Date: Mon, 28 Oct 2024 18:34:08 +0800
Subject: [PATCH] [TRT] support MODNet (#442)

* add modnet trt test code

* modnet trt implement

* update code

* add trt modnet
---
 examples/lite/cv/test_lite_modnet.cpp |  50 +++++++++
 lite/models.h                         |   6 ++
 lite/trt/core/trt_utils.cpp           |  42 ++++++++
 lite/trt/core/trt_utils.h             |   1 +
 lite/trt/cv/trt_modnet.cpp            | 146 ++++++++++++++++++++++++++
 lite/trt/cv/trt_modnet.h              |  34 ++++++
 6 files changed, 279 insertions(+)
 create mode 100644 lite/trt/cv/trt_modnet.cpp
 create mode 100644 lite/trt/cv/trt_modnet.h

diff --git a/examples/lite/cv/test_lite_modnet.cpp b/examples/lite/cv/test_lite_modnet.cpp
index 4eafc33d..88c2894f 100644
--- a/examples/lite/cv/test_lite_modnet.cpp
+++ b/examples/lite/cv/test_lite_modnet.cpp
@@ -94,6 +94,55 @@ static void test_onnxruntime()
 #endif
 }
 
+
+
+static void test_tensorrt()
+{
+#ifdef ENABLE_TENSORRT
+    std::string engine_path = "../../../examples/hub/trt/modnet_fp16.engine";
+    std::string test_img_path = "../../../examples/lite/resources/test_lite_matting_input.jpg";
+    std::string test_bgr_path = "../../../examples/lite/resources/test_lite_matting_bgr.jpg";
+    std::string save_fgr_path = "../../../examples/logs/test_lite_modnet_fgr_trt.jpg";
+    std::string save_pha_path = "../../../examples/logs/test_lite_modnet_pha_trt.jpg";
+    std::string save_merge_path = "../../../examples/logs/test_lite_modnet_merge_trt.jpg";
+    std::string save_swap_path = "../../../examples/logs/test_lite_modnet_swap_trt.jpg";
+
+    // Automatic storage: the engine wrapper is released on every exit path.
+    lite::trt::cv::matting::MODNet modnet(engine_path);
+
+    lite::types::MattingContent content;
+    cv::Mat img_bgr = cv::imread(test_img_path);
+    cv::Mat bgr_mat = cv::imread(test_bgr_path);
+    if (img_bgr.empty() || bgr_mat.empty())
+        std::cerr << "Warning: failed to read test image(s), outputs may be missing." << std::endl;
+
+    // 1. image matting.
+    modnet.detect(img_bgr, content, true, true);
+
+    if (content.flag)
+    {
+        if (!content.fgr_mat.empty()) cv::imwrite(save_fgr_path, content.fgr_mat);
+        if (!content.pha_mat.empty()) cv::imwrite(save_pha_path, content.pha_mat * 255.);
+        if (!content.merge_mat.empty()) cv::imwrite(save_merge_path, content.merge_mat);
+        // 2. swap background: use the predicted foreground when there is one, else the input image.
+        cv::Mat out_mat;
+        if (!content.fgr_mat.empty())
+            lite::utils::swap_background(content.fgr_mat, content.pha_mat, bgr_mat, out_mat, true);
+        else
+            lite::utils::swap_background(img_bgr, content.pha_mat, bgr_mat, out_mat, false);
+
+        if (!out_mat.empty())
+        {
+            cv::imwrite(save_swap_path, out_mat);
+            std::cout << "Saved Swap Image Done!" << std::endl;
+        }
+
+        std::cout << "TensorRT Version MODNet Done!" << std::endl;
+    }
+#endif
+}
+
+
 static void test_mnn()
 {
 #ifdef ENABLE_MNN
@@ -233,11 +282,12 @@ static void test_tnn()
 
 static void test_lite()
 {
   test_default();
   test_onnxruntime();
   test_mnn();
   test_ncnn();
   test_tnn();
+  test_tensorrt();
 }
 
 int main(__unused int argc, __unused char *argv[])
diff --git a/lite/models.h b/lite/models.h
index a5d514d9..60eb82e6 100644
--- a/lite/models.h
+++ b/lite/models.h
@@ -132,6 +132,7 @@
 #include "lite/trt/cv/trt_yolox.h"
 #include "lite/trt/cv/trt_yolov8.h"
 #include "lite/trt/cv/trt_yolov6.h"
+#include "lite/trt/cv/trt_modnet.h"
 #include "lite/trt/cv/trt_yolov5_blazeface.h"
 #include "lite/trt/cv/trt_lightenhance.h"
 #include "lite/trt/cv/trt_realesrgan.h"
@@ -731,9 +732,14 @@ namespace lite{
       typedef trtcv::TRTYOLO5Face _TRT_YOLO5Face;
       typedef trtcv::TRTLightEnhance _TRT_LightEnhance;
       typedef trtcv::TRTRealESRGAN _TRT_RealESRGAN;
+      typedef trtcv::TRTMODNet _TRT_MODNet;
       namespace classification
       {
 
+      }
+      namespace matting
+      {
+        typedef _TRT_MODNet MODNet;
       }
       namespace detection
       {
diff --git a/lite/trt/core/trt_utils.cpp b/lite/trt/core/trt_utils.cpp
index 823446e9..9bc00a80 100644
--- a/lite/trt/core/trt_utils.cpp
+++ b/lite/trt/core/trt_utils.cpp
@@ -83,4 +83,46 @@ void trtcv::utils::transform::trt_generate_latents(std::vector<float> &latents,
     for (size_t i = 0; i < total_size; ++i) {
         latents[i] = dist(gen) * init_noise_sigma;
     }
+}
+
+void trtcv::utils::remove_small_connected_area(cv::Mat &alpha_pred, float threshold) {
+    // Zero every alpha pixel outside the largest 8-connected foreground blob (alpha_pred is read as float rows).
+    cv::Mat gray, binary;
+    alpha_pred.convertTo(gray, CV_8UC1, 255.f);
+    // threshold is a fraction of full scale, e.g. 255 * 0.05 ~ 13
+    unsigned int binary_threshold = (unsigned int) (255.f * threshold);
+    // https://github.com/yucornetto/MGMatting/blob/main/code-base/utils/util.py#L209
+    cv::threshold(gray, binary, binary_threshold, 255, cv::THRESH_BINARY);
+    // morphologyEx with OPEN operation to remove noise first.
+    auto kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(3, 3), cv::Point(-1, -1));
+    cv::morphologyEx(binary, binary, cv::MORPH_OPEN, kernel);
+    // Label the connected components of the mask (8-connectivity, CV_32S labels allocated by OpenCV).
+    cv::Mat labels, stats, centroids;
+    int num_labels = cv::connectedComponentsWithStats(binary, labels, stats, centroids, 8, 4);
+    if (num_labels <= 1) return; // only background was found, nothing to remove.
+    // find max connected area, 0 is background
+    int max_connected_id = 1; // 1,2,...
+    int max_connected_area = stats.at<int>(max_connected_id, cv::CC_STAT_AREA);
+    for (int i = 1; i < num_labels; ++i)
+    {
+        int tmp_connected_area = stats.at<int>(i, cv::CC_STAT_AREA);
+        if (tmp_connected_area > max_connected_area)
+        {
+            max_connected_area = tmp_connected_area;
+            max_connected_id = i;
+        }
+    }
+    const int h = alpha_pred.rows;
+    const int w = alpha_pred.cols;
+    // remove small connected area.
+    for (int i = 0; i < h; ++i)
+    {
+        int *label_row_ptr = labels.ptr<int>(i);
+        float *alpha_row_ptr = alpha_pred.ptr<float>(i);
+        for (int j = 0; j < w; ++j)
+        {
+            if (label_row_ptr[j] != max_connected_id)
+                alpha_row_ptr[j] = 0.f;
+        }
+    }
 }
\ No newline at end of file
diff --git a/lite/trt/core/trt_utils.h b/lite/trt/core/trt_utils.h
index 95329d0b..1bdef7bc 100644
--- a/lite/trt/core/trt_utils.h
+++ b/lite/trt/core/trt_utils.h
@@ -27,6 +27,7 @@ namespace trtcv
             LITE_EXPORTS void trt_generate_latents(std::vector<float>& latents, int batch_size, int unet_channels, int latent_height, int latent_width, float init_noise_sigma);
 
         }
+        LITE_EXPORTS void remove_small_connected_area(cv::Mat &alpha_pred, float threshold);
     }
 }
 
diff --git a/lite/trt/cv/trt_modnet.cpp b/lite/trt/cv/trt_modnet.cpp
new file mode 100644
index 00000000..2a590960
--- /dev/null
+++ b/lite/trt/cv/trt_modnet.cpp
@@ -0,0 +1,146 @@
+//
+// Created by wangzijian on 10/28/24.
+//
+
+#include "trt_modnet.h"
+using trtcv::TRTMODNet;
+
+void TRTMODNet::preprocess(cv::Mat &input_mat) {
+    // Build the tensor image in private buffers so pixels shared with the caller stay untouched.
+    cv::Mat canvas;
+    cv::resize(input_mat, canvas, cv::Size(512, 512));
+    cv::cvtColor(canvas, canvas, cv::COLOR_BGR2RGB);
+    if (canvas.type() != CV_32FC3) canvas.convertTo(canvas, CV_32FC3);
+    // Scalar::all() applies the mean to all three channels; a bare float would only hit channel 0.
+    cv::Mat normalized = (canvas - cv::Scalar::all(mean_val)) * scale_val;
+    input_mat = normalized;
+}
+
+// Runs the engine on a BGR image and fills `content`; fgr/merge are skipped when minimum_post_process is set.
+void TRTMODNet::detect(const cv::Mat &mat, types::MattingContent &content, bool remove_noise, bool minimum_post_process) {
+    if (mat.empty()) return;
+    cv::Mat preprocessed_mat = mat;
+    preprocess(preprocessed_mat);
+
+    const int batch_size = 1;
+    const int channels = 3;
+    const int input_h = preprocessed_mat.rows;
+    const int input_w = preprocessed_mat.cols;
+    const size_t input_size = batch_size * channels * input_h * input_w * sizeof(float);
+    const size_t output_size = batch_size * channels * input_h * input_w * sizeof(float);
+
+    for (auto& buffer : buffers) {
+        if (buffer) {
+            cudaFree(buffer);
+            buffer = nullptr;
+        }
+    }
+    cudaMalloc(&buffers[0], input_size);
+    cudaMalloc(&buffers[1], output_size);
+    if (!buffers[0] || !buffers[1]) {
+        std::cerr << "Failed to allocate CUDA memory" << std::endl;
+        return;
+    }
+
+    input_node_dims = {batch_size, channels, input_h, input_w};
+
+    std::vector<float> input;
+    trtcv::utils::transform::create_tensor(preprocessed_mat,input,input_node_dims,trtcv::utils::transform::CHW);
+
+    //3.infer
+    cudaMemcpyAsync(buffers[0], input.data(), input_size,
+                    cudaMemcpyHostToDevice, stream);
+
+    nvinfer1::Dims MODNetDims;
+    MODNetDims.nbDims = 4;
+    MODNetDims.d[0] = batch_size;
+    MODNetDims.d[1] = channels;
+    MODNetDims.d[2] = input_h;
+    MODNetDims.d[3] = input_w;
+
+    auto input_tensor_name = trt_engine->getIOTensorName(0);
+    auto output_tensor_name = trt_engine->getIOTensorName(1);
+    trt_context->setTensorAddress(input_tensor_name, buffers[0]);
+    trt_context->setTensorAddress(output_tensor_name, buffers[1]);
+    trt_context->setInputShape(input_tensor_name, MODNetDims);
+
+    bool status = trt_context->enqueueV3(stream);
+    if (!status){
+        std::cerr << "Failed to infer by TensorRT." << std::endl;
+        return;
+    }
+
+    std::vector<float> output(batch_size * channels * input_h * input_w);
+    cudaMemcpyAsync(output.data(), buffers[1], output_size,
+                    cudaMemcpyDeviceToHost, stream);
+    // Block until the queued inference and copy-back have finished before reading `output`.
+    cudaStreamSynchronize(stream);
+
+    // post
+    generate_matting(output.data(),mat,content, remove_noise, minimum_post_process);
+}
+
+void TRTMODNet::generate_matting(float *trt_outputs, const cv::Mat &mat, types::MattingContent &content,
+                                 bool remove_noise, bool minimum_post_process) {
+    // Turns the raw network output into an alpha matte sized like `mat`, plus optional fgr/merge images.
+    const unsigned int h = mat.rows;
+    const unsigned int w = mat.cols;
+
+    // Must stay in step with the 512x512 size hard-coded in preprocess().
+    const unsigned int out_h = 512;
+    const unsigned int out_w = 512;
+
+    // Zero-copy view over trt_outputs, read as a single-channel float alpha map.
+    cv::Mat alpha_pred(out_h, out_w, CV_32FC1, trt_outputs);
+    // post process
+    if (remove_noise) trtcv::utils::remove_small_connected_area(alpha_pred,0.05f);
+    // resize alpha
+    if (out_h != h || out_w != w)
+        // resize already allocates new continuous memory for the result.
+        cv::resize(alpha_pred, alpha_pred, cv::Size(w, h));
+    // without a resize the Mat still points at trt_outputs, which the caller
+    // releases after this function returns, so clone into owned memory.
+    else alpha_pred = alpha_pred.clone();
+
+    cv::Mat pmat = alpha_pred; // ref
+    content.pha_mat = pmat; // auto handle the memory inside ocv with smart ref.
+
+    if (!minimum_post_process)
+    {
+        // MODNet only predicts alpha, not a foreground. The fake fgr
+        // and merge mats built below are optional: callers that pass
+        // minimum_post_process = true skip this block, leaving both
+        // mats empty to speed up post processing.
+        cv::Mat mat_copy;
+        mat.convertTo(mat_copy, CV_32FC3);
+        // split into B, G, R planes to weight each one by alpha.
+        std::vector<cv::Mat> mat_channels;
+        cv::split(mat_copy, mat_channels);
+        cv::Mat bmat = mat_channels.at(0);
+        cv::Mat gmat = mat_channels.at(1);
+        cv::Mat rmat = mat_channels.at(2); // ref only, zero-copy.
+        bmat = bmat.mul(pmat);
+        gmat = gmat.mul(pmat);
+        rmat = rmat.mul(pmat);
+        cv::Mat rest = 1.f - pmat;
+        cv::Mat mbmat = bmat.mul(pmat) + rest * 153.f;
+        cv::Mat mgmat = gmat.mul(pmat) + rest * 255.f;
+        cv::Mat mrmat = rmat.mul(pmat) + rest * 120.f;
+        std::vector<cv::Mat> fgr_channel_mats, merge_channel_mats;
+        fgr_channel_mats.push_back(bmat);
+        fgr_channel_mats.push_back(gmat);
+        fgr_channel_mats.push_back(rmat);
+        merge_channel_mats.push_back(mbmat);
+        merge_channel_mats.push_back(mgmat);
+        merge_channel_mats.push_back(mrmat);
+
+        cv::merge(fgr_channel_mats, content.fgr_mat);
+        cv::merge(merge_channel_mats, content.merge_mat);
+
+        content.fgr_mat.convertTo(content.fgr_mat, CV_8UC3);
+        content.merge_mat.convertTo(content.merge_mat, CV_8UC3);
+    }
+
+    content.flag = true;
+
+}
\ No newline at end of file
diff --git a/lite/trt/cv/trt_modnet.h b/lite/trt/cv/trt_modnet.h
new file mode 100644
index 00000000..d6d2851e
--- /dev/null
+++ b/lite/trt/cv/trt_modnet.h
@@ -0,0 +1,34 @@
+//
+// Created by wangzijian on 10/28/24.
+//
+
+#ifndef LITE_AI_TOOLKIT_TRT_MODNET_H
+#define LITE_AI_TOOLKIT_TRT_MODNET_H
+
+#include "lite/trt/core/trt_core.h"
+#include "lite/trt/core/trt_utils.h"
+
+namespace trtcv{
+    class LITE_EXPORTS TRTMODNet : public BasicTRTHandler{
+    public:
+        explicit TRTMODNet(const std::string& _trt_model_path,unsigned int _num_threads = 1):
+                BasicTRTHandler(_trt_model_path, _num_threads)
+        {}
+    private:
+        static constexpr const float mean_val = 127.5f; // applied to every RGB channel
+        static constexpr const float scale_val = 1.f / 127.5f;
+    private:
+        void preprocess(cv::Mat &input_mat);
+
+        void generate_matting(float *trt_outputs,
+                              const cv::Mat &mat, types::MattingContent &content,
+                              bool remove_noise = false, bool minimum_post_process = false);
+    public:
+        void detect(const cv::Mat &mat, types::MattingContent &content, bool remove_noise = false,
+                    bool minimum_post_process = false);
+    };
+}
+
+
+
+#endif //LITE_AI_TOOLKIT_TRT_MODNET_H