[TensorRT] support YOLOv8 with TensorRT backend (#419)
* test yolox code, TODO: clean code

* add tensorrt yolox test code

* update var name to trt engine model path

* update transform func to avoid pointers not being freed

* update transform func to avoid pointers not being freed

* update infer code to use new transform func

* add yolov8 in namespace

* implement yolov8 code and test code

* add yolov8 compile

* update var name in test code

* add TRTYoloV8 reference

* update code to fix bug
wangzijian1010 authored Jul 25, 2024
1 parent a21114e commit 841eaf2
Showing 6 changed files with 249 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/lite/CMakeLists.txt
@@ -101,4 +101,5 @@ add_lite_executable(lite_yolov6 cv)
add_lite_executable(lite_face_parsing_bisenet cv)
add_lite_executable(lite_face_parsing_bisenet_dyn cv)
add_lite_executable(lite_yolov8face cv)
add_lite_executable(lite_yolov8 cv)

43 changes: 43 additions & 0 deletions examples/lite/cv/test_lite_yolov8.cpp
@@ -0,0 +1,43 @@
//
// Created by ai-test1 on 24-7-8.
//

#include "lite/lite.h"



static void test_tensorrt()
{
#ifdef ENABLE_TENSORRT
std::string engine_path = "../../../examples/hub/trt/yolov8n_fp32.engine";
std::string test_img_path = "../../../examples/lite/resources/test_lite_yolov5_1.jpg";
std::string save_img_path = "../../../examples/logs/test_lite_yolov8_trt_1.jpg";

lite::trt::cv::detection::YOLOV8 *yolov8 = new lite::trt::cv::detection::YOLOV8(engine_path);

cv::Mat test_image = cv::imread(test_img_path);

std::vector<lite::types::Boxf> detected_boxes;

yolov8->detect(test_image,detected_boxes,0.5f,0.4f);

std::cout<<"trt yolov8 detect done!"<<std::endl;
lite::utils::draw_boxes_inplace(test_image, detected_boxes);
cv::imwrite(save_img_path, test_image);

delete yolov8;
#endif
}

static void test_lite()
{
test_tensorrt();
}



int main(__unused int argc, __unused char *argv[])
{
test_lite();
return 0;
}
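
The test above assumes a prebuilt TensorRT engine at examples/hub/trt/yolov8n_fp32.engine; that file is not produced by this commit. One common (hypothetical here) way to build such an engine from a YOLOv8n ONNX export is TensorRT's trtexec tool, e.g. trtexec --onnx=yolov8n.onnx --saveEngine=yolov8n_fp32.engine, where the ONNX filename is illustrative and not shipped with this change.
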
3 changes: 3 additions & 0 deletions lite/models.h
@@ -125,6 +125,7 @@
#include "lite/trt/cv/trt_yolofacev8.h"
#include "lite/trt/cv/trt_yolov5.h"
#include "lite/trt/cv/trt_yolox.h"
#include "lite/trt/cv/trt_yolov8.h"
#endif

// ENABLE_MNN
@@ -678,6 +679,7 @@ namespace lite{
{
typedef trtcv::TRTYoloFaceV8 _TRT_YOLOFaceNet;
typedef trtcv::TRTYoloV5 _TRT_YOLOv5;
typedef trtcv::TRTYoloV8 _TRT_YOLOv8;
typedef trtcv::TRTYoloX _TRT_YoloX;
namespace classification
{
@@ -686,6 +688,7 @@ namespace lite{
namespace detection
{
typedef _TRT_YOLOv5 YOLOV5;
typedef _TRT_YOLOv8 YOLOV8;
typedef _TRT_YoloX YoloX;
}
namespace face
1 change: 1 addition & 0 deletions lite/trt/core/trt_core.h
@@ -13,6 +13,7 @@ namespace trtcv{
class LITE_EXPORTS TRTYoloFaceV8; // [1] * reference: https://github.com/derronqi/yolov8-face
class LITE_EXPORTS TRTYoloV5; // [2] * reference: https://github.com/ultralytics/yolov5
class LITE_EXPORTS TRTYoloX; // [3] * reference: https://github.com/Megvii-BaseDetection/YOLOX
class LITE_EXPORTS TRTYoloV8; // [4] * reference: https://github.com/ultralytics/ultralytics/tree/main
}

namespace trtcv{
137 changes: 137 additions & 0 deletions lite/trt/cv/trt_yolov8.cpp
@@ -0,0 +1,137 @@
//
// Created by wangzijian on 7/24/24.
//

#include "trt_yolov8.h"
using trtcv::TRTYoloV8;


void TRTYoloV8::nms(std::vector<types::Boxf> &input, std::vector<types::Boxf> &output,
float iou_threshold, unsigned int topk, unsigned int nms_type)
{
if (nms_type == NMS::BLEND) lite::utils::blending_nms(input, output, iou_threshold, topk);
else if (nms_type == NMS::OFFSET) lite::utils::offset_nms(input, output, iou_threshold, topk);
else lite::utils::hard_nms(input, output, iou_threshold, topk);
}

void TRTYoloV8::generate_bboxes(std::vector<types::Boxf> &bbox_collection, float* output, float score_threshold,
int img_height, int img_width) {
auto pred_dims = output_node_dims[0];
const unsigned int num_anchors = pred_dims[2]; // 8400
const unsigned int num_classes = pred_dims[1] - 4; // 80

float x_factor = float(img_width) / input_node_dims[3];
float y_factor = float(img_height) / input_node_dims[2];
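// note: preprocess() below does a plain (non-letterbox) resize, so these per-axis
// factors map boxes from network-input coordinates straight back to the original image.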

bbox_collection.clear();
unsigned int count = 0;

for (unsigned int i = 0; i < num_anchors; ++i) {

std::vector<float> class_scores(num_classes);
for (unsigned int j = 0; j < num_classes; ++j) {
class_scores[j] = output[(4 + j) * num_anchors + i];
}

auto max_it = std::max_element(class_scores.begin(), class_scores.end());
float max_cls_conf = *max_it;
unsigned int label = std::distance(class_scores.begin(), max_it);

float conf = max_cls_conf;
if (conf < score_threshold) continue;

float cx = output[0 * num_anchors + i];
float cy = output[1 * num_anchors + i];
float w = output[2 * num_anchors + i];
float h = output[3 * num_anchors + i];

float x1 = (cx - w / 2.f) * x_factor;
float y1 = (cy - h / 2.f) * y_factor;

w = w * x_factor;
h = h * y_factor;

float x2 = x1 + w ;
float y2 = y1 + h;

types::Boxf box;
box.x1 = std::max(0.f, x1);
box.y1 = std::max(0.f, y1);
box.x2 = std::min(x2, (float) img_width - 1.f);
box.y2 = std::min(y2, (float) img_height - 1.f);
box.score = conf;
box.label = label;
box.label_text = class_names[label];
box.flag = true;
bbox_collection.push_back(box);

count += 1;
if (count > max_nms)
break;
}

#if LITETRT_DEBUG
std::cout << "detected num_anchors: " << num_anchors << "\n";
std::cout << "generate_bboxes num: " << bbox_collection.size() << "\n";
#endif
}

void TRTYoloV8::preprocess(cv::Mat &input_image) {

// Convert color space from BGR to RGB
cv::cvtColor(input_image, input_image, cv::COLOR_BGR2RGB);

// Resize image
cv::resize(input_image, input_image, cv::Size(input_node_dims[3], input_node_dims[2]), 0, 0, cv::INTER_LINEAR); // cv::Size is (width, height); dims are NCHW, so [3] = W and [2] = H

// Normalize image
input_image.convertTo(input_image, CV_32F, scale_val, mean_val);
}


void TRTYoloV8::detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_boxes, float score_threshold,
float iou_threshold, unsigned int topk, unsigned int nms_type) {

if (mat.empty()) return;
int img_height = static_cast<int>(mat.rows);
int img_width = static_cast<int>(mat.cols);

// resize & unscale
cv::Mat mat_rs = mat.clone();

preprocess(mat_rs);

//1. make the input
std::vector<float> input;
trtcv::utils::transform::create_tensor(mat_rs,input,input_node_dims,trtcv::utils::transform::CHW);

//2. infer
cudaMemcpyAsync(buffers[0], input.data(), input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3] * sizeof(float),
cudaMemcpyHostToDevice, stream);
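// buffers[0] / buffers[1] are assumed to be the device-side input / output bindings
// managed by BasicTRTHandler; with the stream synchronizations below the call is
// effectively synchronous.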

cudaStreamSynchronize(stream);

bool status = trt_context->enqueueV3(stream);
cudaStreamSynchronize(stream);
if (!status){
std::cerr << "Failed to infer by TensorRT." << std::endl;
return;
}

cudaStreamSynchronize(stream);

// get the first output dim
auto pred_dims = output_node_dims[0];

std::vector<float> output(pred_dims[0] * pred_dims[1] * pred_dims[2]);

cudaMemcpyAsync(output.data(), buffers[1], pred_dims[0] * pred_dims[1] * pred_dims[2] * sizeof(float),
cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);

std::vector<types::Boxf> bbox_collection;
generate_bboxes(bbox_collection,output.data(),score_threshold,img_height,img_width);
nms(bbox_collection, detected_boxes, iou_threshold, topk, nms_type);

}
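
For context, generate_bboxes above assumes the engine's single output tensor has shape [1, 4 + num_classes, num_anchors] (e.g. [1, 84, 8400] for a COCO YOLOv8 model), laid out channel-major, so channel c of anchor i lives at output[c * num_anchors + i]. The following is a minimal, self-contained sketch of that indexing with dummy values and hypothetical tiny sizes (no TensorRT or OpenCV dependency); it is illustrative only, not part of the commit:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
    // Hypothetical tiny head: 2 classes, 3 anchors -> shape [1, 4 + 2, 3], channel-major.
    const unsigned int num_classes = 2, num_anchors = 3;
    // Channels: cx, cy, w, h, class-0 score, class-1 score; each channel holds num_anchors values.
    std::vector<float> output = {
        10.f, 20.f, 30.f,   // cx per anchor
        12.f, 22.f, 32.f,   // cy per anchor
         4.f,  4.f,  4.f,   // w  per anchor
         6.f,  6.f,  6.f,   // h  per anchor
        0.1f, 0.9f, 0.3f,   // class 0 score per anchor
        0.2f, 0.1f, 0.8f};  // class 1 score per anchor

    for (unsigned int i = 0; i < num_anchors; ++i) {
        // Same indexing as generate_bboxes: channel c of anchor i is output[c * num_anchors + i].
        float cx = output[0 * num_anchors + i];
        float cy = output[1 * num_anchors + i];
        float w  = output[2 * num_anchors + i];
        float h  = output[3 * num_anchors + i];

        std::vector<float> scores(num_classes);
        for (unsigned int j = 0; j < num_classes; ++j)
            scores[j] = output[(4 + j) * num_anchors + i];
        auto max_it = std::max_element(scores.begin(), scores.end());
        unsigned int label = static_cast<unsigned int>(std::distance(scores.begin(), max_it));

        std::cout << "anchor " << i << ": center (" << cx << ", " << cy << "), size "
                  << w << " x " << h << ", best class " << label
                  << ", score " << *max_it << "\n";
    }
    return 0;
}
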

64 changes: 64 additions & 0 deletions lite/trt/cv/trt_yolov8.h
@@ -0,0 +1,64 @@
//
// Created by wangzijian on 7/24/24.
//

#ifndef LITE_AI_TOOLKIT_TRT_YOLOV8_H
#define LITE_AI_TOOLKIT_TRT_YOLOV8_H


#include "lite/trt/core/trt_core.h"
#include "lite/utils.h"
#include "lite/trt/core/trt_utils.h"
#include <algorithm>

namespace trtcv {
class LITE_EXPORTS TRTYoloV8 : public BasicTRTHandler {
public:
explicit TRTYoloV8(const std::string &_trt_model_path, unsigned int _num_threads = 1) :
BasicTRTHandler(_trt_model_path, _num_threads) {};

~TRTYoloV8() override = default;


private:
static constexpr const float mean_val = 0.f;
static constexpr const float scale_val = 1.0 / 255.f;
const char *class_names[80] = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
"cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
"scissors", "teddy bear", "hair drier", "toothbrush"
};
enum NMS {
HARD = 0, BLEND = 1, OFFSET = 2
};
static constexpr const unsigned int max_nms = 30000;

private:

void preprocess(cv::Mat &input_image);

void normalized(cv::Mat &input_image);

void generate_bboxes(
std::vector<types::Boxf> &bbox_collection,
float *output,
float score_threshold, int img_height,
int img_width); // rescale & exclude

void nms(std::vector<types::Boxf> &input, std::vector<types::Boxf> &output,
float iou_threshold, unsigned int topk, unsigned int nms_type);

public:
void detect(const cv::Mat &mat, std::vector<types::Boxf> &detected_boxes,
float score_threshold = 0.25f, float iou_threshold = 0.45f,
unsigned int topk = 100, unsigned int nms_type = NMS::OFFSET);
};

}

#endif //LITE_AI_TOOLKIT_TRT_YOLOV8_H
