Commit
[TensorRT] [ORT] support Stable Diffusion text2img Pipeline (#432)
* update cpp ddimscheduler.h
* add reference github repo
* update the clip output to [batch, 77, 768], and change the clip model onnx input name
* update ort clip code
* update trt clip code
* ort unet code init
* add exec name, avoid multiple definitions
* add unet test code
* add unet code
* add unet code in models
* update unet code #TODO fix nan error
* delete useless func
* update unet to produce correct results; currently only a single inference pass, still needs to loop over time_steps
* add vae test code
* add vae model in models.h
* add vae test code
* update unet code #TODO clean code
* update unet.h code #TODO clean code
* add vae implementation code #TODO clean code
* add tensorrt vae implementation code #TODO clean code
* add tensorrt vae implementation code #TODO clean code
* modify unet code to use the cpp version of the ddim scheduler #TODO clean code
* add tensorrt vae definition
* initialize a ddim sampler first; the ort one can also be used
* update trt unet in models.h
* trt unet code implementation
* target link ddim.so
* update trt unet code
* update trt unet test code
* update trt vae test code
* update trt unet code to fix bug
* update code to fp16
* update code
* Added dynamic library file for ddim scheduler
* Update vae code to remove useless code and functions
* update utils code to add a new func: load bin into vector
* Update the test code of VAE
* clean up trt clip code
* clean trt unet code
* add new funcs: save vector to bin, and random latents generator
* update trt test code #TODO clean code
* add ort and trt pipeline in model
* ort pipeline implementation
* trt pipeline implementation
* new inference interface for pipeline
* new inference interface for pipeline
* new inference interface for pipeline
* new inference interface for pipeline
* add pipeline code, including ort and trt
* update test code
* update test code
* add pipeline code
* clean up the code
* clean up the code
* update onnx path, engine path and save path
* combine all sd test files into one pipeline file
* delete vae unet clip test code
* update pipeline code, opt json file
* delay model init in low video memory mode
* update code
* update .so filename
* update test code
* update code, low vram mode
* update code
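The list above builds the pieces of a text2img pipeline: a CLIP text encoder, a UNet driven by the cpp DDIM scheduler, and a VAE decoder. The sketch below is only an illustration of how those stages typically chain together with the DDIMScheduler API added in this commit; run_clip, run_unet, and run_vae_decoder are hypothetical stand-ins for the ORT/TRT model calls (classifier-free guidance with the negative prompt is omitted for brevity).

// Illustrative only: CLIP -> DDIM/UNet denoising loop -> VAE decode.
// run_clip / run_unet / run_vae_decoder are hypothetical stubs, not APIs from this commit;
// only Scheduler::DDIMScheduler (declared in the header further below) is taken from the diff.
#include <random>
#include <string>
#include <vector>
#include "ddimscheduler.hpp"   // DDIM scheduler header added in this commit (include path may differ)

// Hypothetical stand-ins for the backend-specific (ORT or TensorRT) model executions.
static std::vector<float> run_clip(const std::string & /*prompt*/)
{ return std::vector<float>(1 * 77 * 768, 0.f); }          // [batch, 77, 768] text embedding
static std::vector<float> run_unet(const std::vector<float> &latents,
                                   const std::vector<float> & /*text_embedding*/, int /*timestep*/)
{ return std::vector<float>(latents.size(), 0.f); }        // predicted noise, same shape as latents
static void run_vae_decoder(const std::vector<float> & /*latents*/, const std::string & /*save_path*/) {}

static void text2img_sketch(const std::string &prompt, const std::string &save_path,
                            const std::string &scheduler_config)
{
    Scheduler::DDIMScheduler scheduler(scheduler_config);
    scheduler.set_timesteps(30);                            // e.g. 30 denoising steps
    std::vector<int> timesteps;
    scheduler.get_timesteps(timesteps);

    std::vector<float> text_embedding = run_clip(prompt);

    // Start from Gaussian noise scaled by the scheduler's initial sigma.
    const std::vector<int> latent_shape = {1, 4, 64, 64};
    std::vector<float> latents(1 * 4 * 64 * 64);
    std::mt19937 rng(0);
    std::normal_distribution<float> gauss(0.f, 1.f);
    for (auto &v : latents) v = gauss(rng) * scheduler.get_init_noise_sigma();

    for (int t : timesteps)                                 // DDIM denoising loop
    {
        std::vector<float> noise_pred = run_unet(latents, text_embedding, t);
        std::vector<float> prev_sample;
        scheduler.step(noise_pred, latent_shape, latents, latent_shape, prev_sample, t);
        latents = prev_sample;                              // carry the denoised latents forward
    }
    run_vae_decoder(latents, save_path);                    // decode latents into the saved image
}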
1 parent fae5d3d, commit 6256f13
Showing 30 changed files with 1,454 additions and 121 deletions.
New file: test program for the ORT and TensorRT Stable Diffusion pipelines.

@@ -0,0 +1,69 @@
//
// Created by wangzijian on 8/31/24.
//
#include "lite/lite.h"

static void test_default()
{
    std::string clip_onnx_path = "../../../examples/hub/onnx/sd/clip_model.onnx";
    std::string unet_onnx_path = "../../../examples/hub/onnx/sd/unet_model.onnx";
    std::string vae_onnx_path = "../../../examples/hub/onnx/sd/vae_model.onnx";

    auto *pipeline = new lite::onnxruntime::sd::pipeline::Pipeline(clip_onnx_path, unet_onnx_path,
                                                                   vae_onnx_path,
                                                                   1);

    std::string prompt = "1girl with red hair,blue eyes,smile, looking at viewer";
    std::string negative_prompt = "";
    std::string save_path = "../../../examples/logs/output_merge.png";
    std::string scheduler_config_path = "../../../lite/ort/sd/scheduler_config.json";

    pipeline->inference(prompt, negative_prompt, save_path, scheduler_config_path);

    delete pipeline;
}

static void test_trt_pipeline()
{
    // record the start time
    std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now();

    std::string clip_engine_path = "../../../examples/hub/trt/clip_text_model_fp16.engine";
    std::string unet_engine_path = "../../../examples/hub/trt/unet_fp16.engine";
    std::string vae_engine_path = "../../../examples/hub/trt/vae_model_fp16.engine";

    auto *pipeline = new lite::trt::sd::pipeline::PipeLine(
            clip_engine_path, unet_engine_path, vae_engine_path
    );

    std::string prompt = "1girl with red hair,blue eyes,smile, looking at viewer";
    std::string negative_prompt = "";
    std::string save_path = "../../../examples/logs/output_merge_tensorrt.png";
    std::string scheduler_config_path = "../../../lite/ort/sd/scheduler_config.json";
    pipeline->inference(prompt, negative_prompt, save_path, scheduler_config_path);

    // record the end time and print the elapsed time
    std::chrono::steady_clock::time_point end_time = std::chrono::steady_clock::now();
    std::chrono::duration<double> elapsed_seconds = end_time - start_time;
    std::cout << "Elapsed time: " << elapsed_seconds.count() << " seconds" << std::endl;

    delete pipeline;
}

static void test_lite()
{
    test_trt_pipeline();

    // test_default();
}

int main()
{
    test_lite();
    return 0;
}
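A side note on the test code above, not part of the commit: both tests create the pipeline with raw new/delete, so an exception thrown inside inference() would leak the object. The same call, reusing the variables from test_default(), can be written with std::unique_ptr so cleanup is guaranteed:

#include <memory>

// Drop-in alternative inside test_default(): the pipeline is released automatically,
// even if inference() throws. Constructor arguments are unchanged.
auto pipeline = std::make_unique<lite::onnxruntime::sd::pipeline::Pipeline>(
        clip_onnx_path, unet_onnx_path, vae_onnx_path, 1);
pipeline->inference(prompt, negative_prompt, save_path, scheduler_config_path);
// no explicit delete needed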
New file: C++ DDIM scheduler header.

@@ -0,0 +1,60 @@
//
// Created by TalkUHulk on 2024/4/25.
//

// reference: https://github.com/TalkUHulk/ddim_scheduler_cpp.git

#ifndef DDIM_SCHEDULER_CPP_DDIMSCHEDULER_HPP
#define DDIM_SCHEDULER_CPP_DDIMSCHEDULER_HPP

#include <iostream>
#include <vector>
#include <string>

namespace Scheduler {

#if defined(_MSC_VER)
#if defined(BUILDING_AIENGINE_DLL)
#define DDIM_PUBLIC __declspec(dllexport)
#elif defined(USING_AIENGINE_DLL)
#define DDIM_PUBLIC __declspec(dllimport)
#else
#define DDIM_PUBLIC
#endif
#else
#define DDIM_PUBLIC __attribute__((visibility("default")))
#endif

    struct DDIMMeta;

    class DDIM_PUBLIC DDIMScheduler {
    private:
        DDIMMeta *meta_ptr = nullptr;
        int num_inference_steps = 0;

    public:
        explicit DDIMScheduler(const std::string &config);

        ~DDIMScheduler();

        // Sets the discrete timesteps used for the diffusion chain (to be run before inference).
        int set_timesteps(int num_inference_steps);

        void get_timesteps(std::vector<int> &dst);

        float get_init_noise_sigma() const;

        int step(std::vector<float> &model_output, const std::vector<int> &model_output_size,
                 std::vector<float> &sample, const std::vector<int> &sample_size,
                 std::vector<float> &prev_sample,
                 int timestep, float eta = 0.0, bool use_clipped_model_output = false);

        int add_noise(std::vector<float> &sample, const std::vector<int> &sample_size,
                      std::vector<float> &noise, const std::vector<int> &noise_size, int timesteps,
                      std::vector<float> &noisy_samples);

    private:
        float get_variance(int timestep, int prev_timestep);
    };
}

#endif //DDIM_SCHEDULER_CPP_DDIMSCHEDULER_HPP
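For reference, a minimal standalone driver of the scheduler API declared above; the config path reuses the scheduler_config.json referenced in the test file, while the step count, shapes, and seed are illustrative assumptions rather than values from the commit.

// Minimal sketch of driving the DDIMScheduler API: set_timesteps, get_timesteps,
// get_init_noise_sigma, and add_noise (forward noising, as img2img-style pipelines use).
#include <iostream>
#include <random>
#include <vector>
#include "ddimscheduler.hpp"   // header shown above (include path may differ)

int main()
{
    Scheduler::DDIMScheduler scheduler("../../../lite/ort/sd/scheduler_config.json");
    scheduler.set_timesteps(30);                // 30 discrete DDIM steps (illustrative)

    std::vector<int> timesteps;
    scheduler.get_timesteps(timesteps);
    std::cout << "first timestep: " << timesteps.front()
              << ", init noise sigma: " << scheduler.get_init_noise_sigma() << std::endl;

    // Forward-noise a latent at the first timestep.
    std::vector<int> shape = {1, 4, 64, 64};
    std::vector<float> sample(1 * 4 * 64 * 64, 0.f);
    std::vector<float> noise(sample.size());
    std::mt19937 rng(42);
    std::normal_distribution<float> gauss(0.f, 1.f);
    for (auto &n : noise) n = gauss(rng);

    std::vector<float> noisy;
    scheduler.add_noise(sample, shape, noise, shape, timesteps.front(), noisy);
    std::cout << "noisy sample size: " << noisy.size() << std::endl;
    return 0;
}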