Skip to content

Commit

Permalink
[TensorRT] support YOLOv6 with TensorRT backend (#420)
Browse files Browse the repository at this point in the history
* add yolov6 tensorrt test code

* add TRTYolov6 in namespace and models

* implement yolov6 code

* update name
  • Loading branch information
wangzijian1010 authored Jul 26, 2024
1 parent 841eaf2 commit 5103236
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 0 deletions.
24 changes: 24 additions & 0 deletions examples/lite/cv/test_lite_yolov6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,29 @@ static void test_onnxruntime()
#endif
}

static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
std::string engine_path = "../../../examples//hub/trt/yolov6s_fp32.engine";
std::string test_img_path = "../../../examples/lite/resources/test_lite_efficientdet.png";
std::string save_img_path = "../../../examples//logs/test_lite_yolov6_2_trt.jpg";

// 1. Test TensorRT Engine
lite::trt::cv::detection::YOLOV6 *yolov6 = new lite::trt::cv::detection::YOLOV6(engine_path);
std::vector<lite::types::Boxf> detected_boxes;
cv::Mat img_bgr = cv::imread(test_img_path);
yolov6->detect(img_bgr, detected_boxes);

lite::utils::draw_boxes_inplace(img_bgr, detected_boxes);

cv::imwrite(save_img_path, img_bgr);

std::cout << "Default Version Detected Boxes Num: " << detected_boxes.size() << std::endl;

delete yolov6;
#endif
}

static void test_mnn()
{
#ifdef ENABLE_MNN
Expand Down Expand Up @@ -136,6 +159,7 @@ static void test_lite()
test_mnn();
test_ncnn();
test_tnn();
test_tensorrt();
}

int main(__unused int argc, __unused char *argv[])
Expand Down
3 changes: 3 additions & 0 deletions lite/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
#include "lite/trt/cv/trt_yolov5.h"
#include "lite/trt/cv/trt_yolox.h"
#include "lite/trt/cv/trt_yolov8.h"
#include "lite/trt/cv/trt_yolov6.h"
#endif

// ENABLE_MNN
Expand Down Expand Up @@ -681,6 +682,7 @@ namespace lite{
typedef trtcv::TRTYoloV5 _TRT_YOLOv5;
typedef trtcv::TRTYoloV8 _TRT_YOLOv8;
typedef trtcv::TRTYoloX _TRT_YoloX;
typedef trtcv::TRTYoloV6 _TRT_YOLOv6;
namespace classification
{

Expand All @@ -690,6 +692,7 @@ namespace lite{
typedef _TRT_YOLOv5 YOLOV5;
typedef _TRT_YOLOv8 YOLOV8;
typedef _TRT_YoloX YoloX;
typedef _TRT_YOLOv6 YOLOV6;
}
namespace face
{
Expand Down
1 change: 1 addition & 0 deletions lite/trt/core/trt_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace trtcv{
class LITE_EXPORTS TRTYoloV5; // [2] * reference: https://github.com/ultralytics/yolov5
class LITE_EXPORTS TRTYoloX; // [3] * reference: https://github.com/Megvii-BaseDetection/YOLOX
class LITE_EXPORTS TRTYoloV8; // [4] * reference: https://github.com/ultralytics/ultralytics/tree/main
class LITE_EXPORTS TRTYoloV6; // [5] * reference: https://github.com/meituan/YOLOv6
}

namespace trtcv{
Expand Down
176 changes: 176 additions & 0 deletions lite/trt/cv/trt_yolov6.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
//
// Created by wangzijian on 7/26/24.
//

#include "trt_yolov6.h"
using trtcv::TRTYoloV6;


void TRTYoloV6::resize_unscale(const cv::Mat &mat, cv::Mat &mat_rs, int target_height, int target_width,
trtcv::TRTYoloV6::YOLOv6ScaleParams &scale_params) {
if (mat.empty()) return;
int img_height = static_cast<int>(mat.rows);
int img_width = static_cast<int>(mat.cols);

mat_rs = cv::Mat(target_height, target_width, CV_8UC3,
cv::Scalar(114, 114, 114));
// scale ratio (new / old) new_shape(h,w)
float w_r = (float) target_width / (float) img_width;
float h_r = (float) target_height / (float) img_height;
float r = std::min(w_r, h_r);
// compute padding
int new_unpad_w = static_cast<int>((float) img_width * r); // floor
int new_unpad_h = static_cast<int>((float) img_height * r); // floor
int pad_w = target_width - new_unpad_w; // >=0
int pad_h = target_height - new_unpad_h; // >=0

int dw = pad_w / 2;
int dh = pad_h / 2;

// resize with unscaling
cv::Mat new_unpad_mat;
// cv::Mat new_unpad_mat = mat.clone(); // may not need clone.
cv::resize(mat, new_unpad_mat, cv::Size(new_unpad_w, new_unpad_h));
new_unpad_mat.copyTo(mat_rs(cv::Rect(dw, dh, new_unpad_w, new_unpad_h)));

// record scale params.
scale_params.r = r;
scale_params.dw = dw;
scale_params.dh = dh;
scale_params.new_unpad_w = new_unpad_w;
scale_params.new_unpad_h = new_unpad_h;
scale_params.flag = true;
}

void TRTYoloV6::nms(std::vector<types::Boxf> &input, std::vector<types::Boxf> &output, float iou_threshold,
unsigned int topk, unsigned int nms_type) {
if (nms_type == NMS::BLEND) lite::utils::blending_nms(input, output, iou_threshold, topk);
else if (nms_type == NMS::OFFSET) lite::utils::offset_nms(input, output, iou_threshold, topk);
else lite::utils::hard_nms(input, output, iou_threshold, topk);
}


void TRTYoloV6::normalized(cv::Mat &input_mat) {
cv::cvtColor(input_mat,input_mat,cv::COLOR_BGR2RGB);
input_mat.convertTo(input_mat,CV_32FC3,1.f / 255.f,0.f);
}

void TRTYoloV6::generate_bboxes(const trtcv::TRTYoloV6::YOLOv6ScaleParams &scale_params,
std::vector<types::Boxf> &bbox_collection, float *output, float score_threshold,
int img_height, int img_width) {
auto pred_dims = output_node_dims[0];
const unsigned int num_anchors = pred_dims.at(1); // n = ?
const unsigned int num_classes = pred_dims.at(2) - 5;

float r_ = scale_params.r;
int dw_ = scale_params.dw;
int dh_ = scale_params.dh;

bbox_collection.clear();
unsigned int count = 0;
for (unsigned int i = 0; i < num_anchors; ++i)
{
float obj_conf = output[i * pred_dims.at(2) + 4];
if (obj_conf < score_threshold) continue; // filter first.

float cls_conf = output[i * pred_dims.at(2) + 5];
unsigned int label = 0;
for (unsigned int j = 0; j < num_classes; ++j)
{
float tmp_conf = output[i * pred_dims.at(2) + 5 + j];
if (tmp_conf > cls_conf)
{
cls_conf = tmp_conf;
label = j;
}
}
float conf = obj_conf * cls_conf; // cls_conf (0.,1.)
if (conf < score_threshold) continue; // filter

float cx = output[i * pred_dims.at(2)];
float cy = output[i * pred_dims.at(2) + 1];
float w = output[i * pred_dims.at(2) + 2];
float h = output[i * pred_dims.at(2) + 3];
float x1 = ((cx - w / 2.f) - (float) dw_) / r_;
float y1 = ((cy - h / 2.f) - (float) dh_) / r_;
float x2 = ((cx + w / 2.f) - (float) dw_) / r_;
float y2 = ((cy + h / 2.f) - (float) dh_) / r_;

types::Boxf box;
box.x1 = std::max(0.f, x1);
box.y1 = std::max(0.f, y1);
box.x2 = std::min(x2, (float) img_width - 1.f);
box.y2 = std::min(y2, (float) img_height - 1.f);
box.score = conf;
box.label = label;
box.label_text = class_names[label];
box.flag = true;
bbox_collection.push_back(box);

count += 1; // limit boxes for nms.
if (count > max_nms)
break;
}

#if LITETRT_DEBUG
std::cout << "detected num_anchors: " << num_anchors << "\n";
std::cout << "generate_bboxes num: " << bbox_collection.size() << "\n";
#endif



}


void TRTYoloV6::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_boxes, float score_threshold,
float iou_threshold, unsigned int topk, unsigned int nms_type) {

if (mat.empty()) return;
// this->transform(mat);
const int input_height = input_node_dims.at(2);
const int input_width = input_node_dims.at(3);
int img_height = static_cast<int>(mat.rows);
int img_width = static_cast<int>(mat.cols);

// resize & unscale
cv::Mat mat_rs;
YOLOv6ScaleParams scale_params;
resize_unscale(mat, mat_rs, input_height, input_width, scale_params);

normalized(mat_rs);

// 1.make input
std::vector<float> input;
trtcv::utils::transform::create_tensor(mat_rs,input,input_node_dims,trtcv::utils::transform::CHW);

//2. infer
cudaMemcpyAsync(buffers[0], input.data(), input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);


bool status = trt_context->enqueueV3(stream);
cudaStreamSynchronize(stream);
if (!status){
std::cerr << "Failed to infer by TensorRT." << std::endl;
return;
}

// Synchronize the stream to ensure all operations are complete
cudaStreamSynchronize(stream);
// get the first output dim
auto pred_dims = output_node_dims[0];

std::vector<float> output(pred_dims[0] * pred_dims[1] * pred_dims[2]);

cudaMemcpyAsync(output.data(), buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);

//3. generate the boxes
std::vector<types::Boxf> bbox_collection;
generate_bboxes(scale_params, bbox_collection, output.data(), score_threshold, img_height, img_width);
nms(bbox_collection, detected_boxes, iou_threshold, topk, nms_type);

}

80 changes: 80 additions & 0 deletions lite/trt/cv/trt_yolov6.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//
// Created by wangzijian on 7/20/24.
//

#ifndef LITE_AI_TOOLKIT_TRT_YOLOV6_H
#define LITE_AI_TOOLKIT_TRT_YOLOV6_H

#include "lite/trt/core/trt_core.h"
#include "lite/utils.h"
#include "lite/trt/core/trt_utils.h"

namespace trtcv
{
class LITE_EXPORTS TRTYoloV6 : public BasicTRTHandler
{
public:
explicit TRTYoloV6(const std::string &_trt_model_path, unsigned int _num_threads = 1) :
BasicTRTHandler(_trt_model_path, _num_threads)
{};

~TRTYoloV6() override = default;

private:
// nested classes
typedef struct
{
float r;
int dw;
int dh;
int new_unpad_w;
int new_unpad_h;
bool flag;
} YOLOv6ScaleParams;

private:
const char *class_names[80] = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
"cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush"
};
enum NMS
{
HARD = 0, BLEND = 1, OFFSET = 2
};
static constexpr const unsigned int max_nms = 30000;

private:

void resize_unscale(const cv::Mat &mat,
cv::Mat &mat_rs,
int target_height,
int target_width,
YOLOv6ScaleParams &scale_params);

void normalized(cv::Mat &input_mat);

void generate_bboxes(const YOLOv6ScaleParams &scale_params,
std::vector<types::Boxf> &bbox_collection,
float* output,
float score_threshold, int img_height,
int img_width); // r rescale & exclude


void nms(std::vector<types::Boxf> &input, std::vector<types::Boxf> &output,
float iou_threshold, unsigned int topk, unsigned int nms_type);

public:
void detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_boxes,
float score_threshold = 0.25f, float iou_threshold = 0.45f,
unsigned int topk = 100, unsigned int nms_type = NMS::OFFSET);
};
}
#endif //LITE_AI_TOOLKIT_TRT_YOLOV6_H

0 comments on commit 5103236

Please sign in to comment.