Skip to content

Commit

Permalink
[TRT] Support RealESRGAN (#441)
Browse files Browse the repository at this point in the history
* trt version realesrgan implement

* update code
  • Loading branch information
wangzijian1010 authored Oct 25, 2024
1 parent 2458241 commit 557521d
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 0 deletions.
21 changes: 21 additions & 0 deletions examples/lite/cv/test_lite_realesrgan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,30 @@ static void test_default()
}


static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
std::string engine_path = "../../../examples/hub/trt/RealESRGAN_x4plus_fp16.engine";
std::string test_img_path = "../../../examples/lite/resources/test_lite_realesrgan.jpg";
std::string save_img_path = "../../../examples/logs/test_lite_realesrgan_trt.jpg";

lite::trt::cv::upscale::RealESRGAN *realesrgan = new lite::trt::cv::upscale::RealESRGAN (engine_path);

cv::Mat test_image = cv::imread(test_img_path);

realesrgan->detect(test_image,save_img_path);

std::cout<<"trt upscale enhance done!"<<std::endl;

delete realesrgan;
#endif
}


static void test_lite()
{
test_default();
test_tensorrt();
}

int main(__unused int argc, __unused char *argv[])
Expand Down
6 changes: 6 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
#include "lite/trt/cv/trt_yolov6.h"
#include "lite/trt/cv/trt_yolov5_blazeface.h"
#include "lite/trt/cv/trt_lightenhance.h"
#include "lite/trt/cv/trt_realesrgan.h"
#include "lite/trt/sd/trt_clip.h"
#include "lite/trt/sd/trt_vae.h"
#include "lite/trt/sd/trt_unet.h"
Expand Down Expand Up @@ -729,6 +730,7 @@ namespace lite{
typedef trtcv::TRTYoloV6 _TRT_YOLOv6;
typedef trtcv::TRTYOLO5Face _TRT_YOLO5Face;
typedef trtcv::TRTLightEnhance _TRT_LightEnhance;
typedef trtcv::TRTRealESRGAN _TRT_RealESRGAN;
namespace classification
{

Expand All @@ -752,6 +754,10 @@ namespace lite{
{
typedef _TRT_LightEnhance LightEnhance;
}
namespace upscale
{
typedef _TRT_RealESRGAN RealESRGAN;
}
}

namespace sd
Expand Down
109 changes: 109 additions & 0 deletions lite/trt/cv/trt_realesrgan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//
// Created by wangzijian on 10/25/24.
//

#include "trt_realesrgan.h"
using trtcv::TRTRealESRGAN;

void TRTRealESRGAN::preprocess(const cv::Mat &frame, cv::Mat &output_mat) {
cv::cvtColor(frame,output_mat,cv::COLOR_BGR2RGB);
output_mat.convertTo(output_mat,CV_32FC3,1 / 255.f);
}


void TRTRealESRGAN::detect(const cv::Mat &input_mat, const std::string &output_path) {
if (input_mat.empty()) return;

ori_input_width = input_mat.cols;
ori_input_height = input_mat.rows;

cv::Mat preprocessed_mat;
preprocess(input_mat, preprocessed_mat);

const int batch_size = 1;
const int channels = 3;
const int input_h = preprocessed_mat.rows;
const int input_w = preprocessed_mat.cols;
const size_t input_size = batch_size * channels * input_h * input_w * sizeof(float);
const size_t output_size = batch_size * channels * input_h * 4 * input_w * 4 * sizeof(float);

for (auto& buffer : buffers) {
if (buffer) {
cudaFree(buffer);
buffer = nullptr;
}
}

cudaMalloc(&buffers[0], input_size);
cudaMalloc(&buffers[1], output_size);
if (!buffers[0] || !buffers[1]) {
std::cerr << "Failed to allocate CUDA memory" << std::endl;
return;
}

input_node_dims = {batch_size, channels, input_h, input_w};

std::vector<float> input;
trtcv::utils::transform::create_tensor(preprocessed_mat, input, input_node_dims, trtcv::utils::transform::CHW);

cudaError_t status = cudaMemcpyAsync(buffers[0], input.data(), input_size,
cudaMemcpyHostToDevice, stream);
if (status != cudaSuccess) {
std::cerr << "Input copy failed: " << cudaGetErrorString(status) << std::endl;
return;
}
cudaStreamSynchronize(stream);

nvinfer1::Dims ESRGANDims;
ESRGANDims.nbDims = 4;
ESRGANDims.d[0] = batch_size;
ESRGANDims.d[1] = channels;
ESRGANDims.d[2] = input_h;
ESRGANDims.d[3] = input_w;

auto input_tensor_name = trt_engine->getIOTensorName(0);
auto output_tensor_name = trt_engine->getIOTensorName(1);
trt_context->setTensorAddress(input_tensor_name, buffers[0]);
trt_context->setTensorAddress(output_tensor_name, buffers[1]);

trt_context->setInputShape(input_tensor_name, ESRGANDims);

bool infer_status = trt_context->enqueueV3(stream);
if (!infer_status) {
std::cerr << "TensorRT inference failed!" << std::endl;
return;
}
cudaStreamSynchronize(stream);

const size_t total_output_elements = batch_size * channels * input_h * 4 * input_w * 4;
std::vector<float> output(total_output_elements);

status = cudaMemcpyAsync(output.data(), buffers[1], output_size,
cudaMemcpyDeviceToHost, stream);
if (status != cudaSuccess) {
std::cerr << "Output copy failed: " << cudaGetErrorString(status) << std::endl;
return;
}
cudaStreamSynchronize(stream);

postprocess(output.data(), output_path);
}

void TRTRealESRGAN::postprocess(float *trt_outputs, const std::string &output_path) {
const int out_h = ori_input_height * 4;
const int out_w = ori_input_width * 4;
const int channel_step = out_h * out_w;
cv::Mat bmat(out_h, out_w, CV_32FC1, trt_outputs);
cv::Mat gmat(out_h, out_w, CV_32FC1, trt_outputs + channel_step);
cv::Mat rmat(out_h, out_w, CV_32FC1, trt_outputs + 2 * channel_step);
bmat *= 255.f;
gmat *= 255.f;
rmat *= 255.f;
std::vector<cv::Mat> channel_mats = {rmat, gmat, bmat};
cv::Mat dstimg;
cv::merge(channel_mats,dstimg);
dstimg.convertTo(dstimg, CV_8UC3);
cv::imwrite(output_path,dstimg);
}


27 changes: 27 additions & 0 deletions lite/trt/cv/trt_realesrgan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

//
// Created by wangzijian on 10/25/24.
//

#ifndef LITE_AI_TOOLKIT_TRT_REALESRGAN_H
#define LITE_AI_TOOLKIT_TRT_REALESRGAN_H
#include "lite/trt/core/trt_core.h"
#include "lite/trt/core/trt_utils.h"

namespace trtcv{
class LITE_EXPORTS TRTRealESRGAN : public BasicTRTHandler{
public:
explicit TRTRealESRGAN(const std::string& _trt_model_path,unsigned int _num_threads = 1):
BasicTRTHandler(_trt_model_path, _num_threads){};

private:
int ori_input_width;
int ori_input_height;
void preprocess(const cv::Mat& frame,cv::Mat &output_mat);
void postprocess(float *trt_outputs,const std::string &output_path);
public:
void detect(const cv::Mat &input_mat,const std::string &output_path);
};
}

#endif //LITE_AI_TOOLKIT_TRT_REALESRGAN_H

0 comments on commit 557521d

Please sign in to comment.