From 557521d4a297c4dd1444f2dd1a09fba97937e70c Mon Sep 17 00:00:00 2001 From: wangzijian <107230768+wangzijian1010@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:29:58 +0800 Subject: [PATCH] [TRT] Support RealESRGAN (#441) * trt version realesrgan implement * update code --- examples/lite/cv/test_lite_realesrgan.cpp | 21 +++++ lite/models.h | 6 ++ lite/trt/cv/trt_realesrgan.cpp | 109 ++++++++++++++++++++++ lite/trt/cv/trt_realesrgan.h | 27 ++++++ 4 files changed, 163 insertions(+) create mode 100644 lite/trt/cv/trt_realesrgan.cpp create mode 100644 lite/trt/cv/trt_realesrgan.h diff --git a/examples/lite/cv/test_lite_realesrgan.cpp b/examples/lite/cv/test_lite_realesrgan.cpp index 0d1b8fa7..b16ff1b1 100644 --- a/examples/lite/cv/test_lite_realesrgan.cpp +++ b/examples/lite/cv/test_lite_realesrgan.cpp @@ -20,9 +20,30 @@ static void test_default() } +static void test_tensorrt() +{ +#ifdef ENABLE_TENSORRT + std::string engine_path = "../../../examples/hub/trt/RealESRGAN_x4plus_fp16.engine"; + std::string test_img_path = "../../../examples/lite/resources/test_lite_realesrgan.jpg"; + std::string save_img_path = "../../../examples/logs/test_lite_realesrgan_trt.jpg"; + + lite::trt::cv::upscale::RealESRGAN *realesrgan = new lite::trt::cv::upscale::RealESRGAN (engine_path); + + cv::Mat test_image = cv::imread(test_img_path); + + realesrgan->detect(test_image,save_img_path); + + std::cout<<"trt upscale enhance done!"< input; + trtcv::utils::transform::create_tensor(preprocessed_mat, input, input_node_dims, trtcv::utils::transform::CHW); + + cudaError_t status = cudaMemcpyAsync(buffers[0], input.data(), input_size, + cudaMemcpyHostToDevice, stream); + if (status != cudaSuccess) { + std::cerr << "Input copy failed: " << cudaGetErrorString(status) << std::endl; + return; + } + cudaStreamSynchronize(stream); + + nvinfer1::Dims ESRGANDims; + ESRGANDims.nbDims = 4; + ESRGANDims.d[0] = batch_size; + ESRGANDims.d[1] = channels; + ESRGANDims.d[2] = input_h; + ESRGANDims.d[3] = input_w; + + auto input_tensor_name = trt_engine->getIOTensorName(0); + auto output_tensor_name = trt_engine->getIOTensorName(1); + trt_context->setTensorAddress(input_tensor_name, buffers[0]); + trt_context->setTensorAddress(output_tensor_name, buffers[1]); + + trt_context->setInputShape(input_tensor_name, ESRGANDims); + + bool infer_status = trt_context->enqueueV3(stream); + if (!infer_status) { + std::cerr << "TensorRT inference failed!" << std::endl; + return; + } + cudaStreamSynchronize(stream); + + const size_t total_output_elements = batch_size * channels * input_h * 4 * input_w * 4; + std::vector output(total_output_elements); + + status = cudaMemcpyAsync(output.data(), buffers[1], output_size, + cudaMemcpyDeviceToHost, stream); + if (status != cudaSuccess) { + std::cerr << "Output copy failed: " << cudaGetErrorString(status) << std::endl; + return; + } + cudaStreamSynchronize(stream); + + postprocess(output.data(), output_path); +} + +void TRTRealESRGAN::postprocess(float *trt_outputs, const std::string &output_path) { + const int out_h = ori_input_height * 4; + const int out_w = ori_input_width * 4; + const int channel_step = out_h * out_w; + cv::Mat bmat(out_h, out_w, CV_32FC1, trt_outputs); + cv::Mat gmat(out_h, out_w, CV_32FC1, trt_outputs + channel_step); + cv::Mat rmat(out_h, out_w, CV_32FC1, trt_outputs + 2 * channel_step); + bmat *= 255.f; + gmat *= 255.f; + rmat *= 255.f; + std::vector channel_mats = {rmat, gmat, bmat}; + cv::Mat dstimg; + cv::merge(channel_mats,dstimg); + dstimg.convertTo(dstimg, CV_8UC3); + cv::imwrite(output_path,dstimg); +} + + diff --git a/lite/trt/cv/trt_realesrgan.h b/lite/trt/cv/trt_realesrgan.h new file mode 100644 index 00000000..89e6dd48 --- /dev/null +++ b/lite/trt/cv/trt_realesrgan.h @@ -0,0 +1,27 @@ + +// +// Created by wangzijian on 10/25/24. +// + +#ifndef LITE_AI_TOOLKIT_TRT_REALESRGAN_H +#define LITE_AI_TOOLKIT_TRT_REALESRGAN_H +#include "lite/trt/core/trt_core.h" +#include "lite/trt/core/trt_utils.h" + +namespace trtcv{ + class LITE_EXPORTS TRTRealESRGAN : public BasicTRTHandler{ + public: + explicit TRTRealESRGAN(const std::string& _trt_model_path,unsigned int _num_threads = 1): + BasicTRTHandler(_trt_model_path, _num_threads){}; + + private: + int ori_input_width; + int ori_input_height; + void preprocess(const cv::Mat& frame,cv::Mat &output_mat); + void postprocess(float *trt_outputs,const std::string &output_path); + public: + void detect(const cv::Mat &input_mat,const std::string &output_path); + }; +} + +#endif //LITE_AI_TOOLKIT_TRT_REALESRGAN_H \ No newline at end of file