Enable print properties of compiled model in genai API (#1289)
When the environment variable OPENVINO_LOG_LEVEL is set to a value greater than ov::log::Level::WARNING, the GenAI API prints the properties of the compiled model.

When the device is CPU, the properties of the compiled model are as
follows:
Model: Stateful LLM model
  NETWORK_NAME: Model0
  OPTIMAL_NUMBER_OF_INFER_REQUESTS: 1
  NUM_STREAMS: 1
  INFERENCE_NUM_THREADS: 48
  PERF_COUNT: NO
  INFERENCE_PRECISION_HINT: bf16
  PERFORMANCE_HINT: LATENCY
  EXECUTION_MODE_HINT: PERFORMANCE
  PERFORMANCE_HINT_NUM_REQUESTS: 0
  ENABLE_CPU_PINNING: YES
  SCHEDULING_CORE_TYPE: ANY_CORE
  MODEL_DISTRIBUTION_POLICY:
  ENABLE_HYPER_THREADING: NO
  EXECUTION_DEVICES: CPU
  CPU_DENORMALS_OPTIMIZATION: NO
  LOG_LEVEL: LOG_NONE
  CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1
  DYNAMIC_QUANTIZATION_GROUP_SIZE: 32
  KV_CACHE_PRECISION: f16
  AFFINITY: CORE
EXECUTION_DEVICES:
 CPU: Intel(R) Xeon(R) Platinum 8468


[stable_diffusion_compiled_model_log.txt](https://github.com/user-attachments/files/18120641/stable_diffusion_compiled_model_log.txt)

---------

Co-authored-by: Ilya Lavrenov <[email protected]>
wgzintel and ilya-lavrenov authored Dec 17, 2024
1 parent 79f64a6 commit b31b6a1
Showing 24 changed files with 152 additions and 29 deletions.
1 change: 1 addition & 0 deletions .github/workflows/llm_bench-python.yml
@@ -61,6 +61,7 @@ jobs:
SRC_DIR: ${{ github.workspace }}
LLM_BENCH_PYPATH: ${{ github.workspace }}/tools/llm_bench
WWB_PATH: ${{ github.workspace }}/tools/who_what_benchmark
OPENVINO_LOG_LEVEL: 3

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
4 changes: 4 additions & 0 deletions src/README.md
@@ -403,3 +403,7 @@ For information on how OpenVINO™ GenAI works, refer to the [How It Works Secti
## Supported Models

For a list of supported models, refer to the [Supported Models Section](./docs/SUPPORTED_MODELS.md).

## Debug Log

For information on how to use the debug log, refer to [Debug Log](./docs/DEBUG_LOG.md).
4 changes: 3 additions & 1 deletion src/cpp/src/continuous_batching_impl.cpp
@@ -46,7 +46,9 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
const ov::AnyMap& properties,
const DeviceConfig& device_config,
ov::Core& core) {
ov::InferRequest infer_request = core.compile_model(model, device_config.get_device(), properties).create_infer_request();
auto compiled_model = core.compile_model(model, device_config.get_device(), properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
ov::InferRequest infer_request = compiled_model.create_infer_request();

// setup KV caches
m_cache_manager = std::make_shared<CacheManager>(device_config, core);
2 changes: 2 additions & 0 deletions src/cpp/src/image_generation/models/autoencoder_kl.cpp
@@ -212,12 +212,14 @@ AutoencoderKL& AutoencoderKL::compile(const std::string& device, const ov::AnyMa

if (m_encoder_model) {
ov::CompiledModel encoder_compiled_model = core.compile_model(m_encoder_model, device, properties);
ov::genai::utils::print_compiled_model_properties(encoder_compiled_model, "Auto encoder KL encoder model");
m_encoder_request = encoder_compiled_model.create_infer_request();
// release the original model
m_encoder_model.reset();
}

ov::CompiledModel decoder_compiled_model = core.compile_model(m_decoder_model, device, properties);
ov::genai::utils::print_compiled_model_properties(decoder_compiled_model, "Auto encoder KL decoder model");
m_decoder_request = decoder_compiled_model.create_infer_request();
// release the original model
m_decoder_model.reset();
1 change: 1 addition & 0 deletions src/cpp/src/image_generation/models/clip_text_model.cpp
@@ -97,6 +97,7 @@ CLIPTextModel& CLIPTextModel::compile(const std::string& device, const ov::AnyMa
} else {
compiled_model = core.compile_model(m_model, device, properties);
}
ov::genai::utils::print_compiled_model_properties(compiled_model, "Clip Text model");
m_request = compiled_model.create_infer_request();
// release the original model
m_model.reset();
@@ -88,6 +88,7 @@ CLIPTextModelWithProjection& CLIPTextModelWithProjection::compile(const std::str
} else {
compiled_model = core.compile_model(m_model, device, properties);
}
ov::genai::utils::print_compiled_model_properties(compiled_model, "Clip Text with projection model");
m_request = compiled_model.create_infer_request();
// release the original model
m_model.reset();
@@ -108,6 +108,7 @@ FluxTransformer2DModel& FluxTransformer2DModel::reshape(int batch_size,
FluxTransformer2DModel& FluxTransformer2DModel::compile(const std::string& device, const ov::AnyMap& properties) {
OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot re-compile already compiled model");
ov::CompiledModel compiled_model = utils::singleton_core().compile_model(m_model, device, properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Flux Transformer 2D model");
m_request = compiled_model.create_infer_request();
// release the original model
m_model.reset();
@@ -105,6 +105,7 @@ SD3Transformer2DModel& SD3Transformer2DModel::reshape(int batch_size,
SD3Transformer2DModel& SD3Transformer2DModel::compile(const std::string& device, const ov::AnyMap& properties) {
OPENVINO_ASSERT(m_model, "Model has been already compiled. Cannot re-compile already compiled model");
ov::CompiledModel compiled_model = utils::singleton_core().compile_model(m_model, device, properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "SD3 Transformer 2D model");
m_request = compiled_model.create_infer_request();
// release the original model
m_model.reset();
1 change: 1 addition & 0 deletions src/cpp/src/image_generation/models/t5_encoder_model.cpp
@@ -63,6 +63,7 @@ T5EncoderModel& T5EncoderModel::compile(const std::string& device, const ov::Any
ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model;
compiled_model = core.compile_model(m_model, device, properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "T5 encoder model");
m_request = compiled_model.create_infer_request();
// release the original model
m_model.reset();
@@ -20,6 +20,7 @@ class UNet2DConditionModel::UNetInferenceDynamic : public UNet2DConditionModel::
ov::Core core = utils::singleton_core();

ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "UNet 2D Condition dynamic model");
m_request = compiled_model.create_infer_request();
}

@@ -40,6 +40,7 @@ class UNet2DConditionModel::UNetInferenceStaticBS1 : public UNet2DConditionModel

ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "UNet 2D Condition batch-1 model");

for (int i = 0; i < m_native_batch_size; i++)
{
8 changes: 6 additions & 2 deletions src/cpp/src/llm_pipeline.cpp
@@ -77,17 +77,21 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
const ov::genai::GenerationConfig& generation_config
) : LLMPipelineImplBase(tokenizer, generation_config) {
ov::Core core;
ov::CompiledModel compiled_model;
auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
utils::slice_matmul_statefull_model(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

if (auto filtered_plugin_config = extract_adapters_from_properties(plugin_config, &m_generation_config.adapters)) {
m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device); // TODO: Make the prefix name configurable
m_model_runner = core.compile_model(model, device, *filtered_plugin_config).create_infer_request();
compiled_model = core.compile_model(model, device, *filtered_plugin_config);
m_model_runner = compiled_model.create_infer_request();
} else {
m_model_runner = core.compile_model(model, device, plugin_config).create_infer_request();
compiled_model = core.compile_model(model, device, plugin_config);
m_model_runner = compiled_model.create_infer_request();
}
ov::genai::utils::print_compiled_model_properties(compiled_model, "Stateful LLM model");

// If eos_token_id was not provided, take value
if (m_generation_config.eos_token_id == -1)
13 changes: 8 additions & 5 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -777,12 +777,15 @@ void StaticLLMPipeline::setupAndCompileModels(
set_npuw_cache_dir(prefill_config);
set_npuw_cache_dir(generate_config);

m_kvcache_request = core.compile_model(
auto kv_compiled_model = core.compile_model(
kvcache_model, device, generate_config
).create_infer_request();
m_prefill_request = core.compile_model(
prefill_model, device, prefill_config
).create_infer_request();
);
ov::genai::utils::print_compiled_model_properties(kv_compiled_model, "Static LLM kv compiled model");
m_kvcache_request = kv_compiled_model.create_infer_request();

auto prefill_compiled_model = core.compile_model(prefill_model, device, prefill_config);
m_prefill_request = prefill_compiled_model.create_infer_request();
ov::genai::utils::print_compiled_model_properties(prefill_compiled_model, "Static LLM prefill compiled model");
}

void StaticLLMPipeline::setupAndImportModels(
4 changes: 3 additions & 1 deletion src/cpp/src/lora_adapter.cpp
@@ -637,7 +637,9 @@ class InferRequestSignatureCache {

ov::Core core = ov::genai::utils::singleton_core();
auto model = std::make_shared<ov::Model>(request_results, request_parameters);
rwb.request = core.compile_model(model, device).create_infer_request();
auto compiled_model = core.compile_model(model, device);
ov::genai::utils::print_compiled_model_properties(compiled_model, "Infer Request Signature Cache");
rwb.request = compiled_model.create_infer_request();
requests.emplace(signature, rwb);
}

2 changes: 2 additions & 0 deletions src/cpp/src/tokenizer.cpp
@@ -203,6 +203,7 @@ class Tokenizer::TokenizerImpl {
manager.register_pass<MakeCombineSegmentsSatateful>();
manager.run_passes(ov_tokenizer);
m_tokenizer = core.compile_model(ov_tokenizer, device, properties);
ov::genai::utils::print_compiled_model_properties(m_tokenizer, "OV Tokenizer");

m_ireq_queue_tokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
m_tokenizer.get_property(ov::optimal_number_of_infer_requests),
@@ -216,6 +217,7 @@ class Tokenizer::TokenizerImpl {
manager_detok.register_pass<MakeVocabDecoderSatateful>();
manager_detok.run_passes(ov_detokenizer);
m_detokenizer = core.compile_model(ov_detokenizer, device, properties);
ov::genai::utils::print_compiled_model_properties(m_detokenizer, "OV Detokenizer");

m_ireq_queue_detokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
m_detokenizer.get_property(ov::optimal_number_of_infer_requests),
37 changes: 37 additions & 0 deletions src/cpp/src/utils.cpp
@@ -381,6 +381,43 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se
}
}

void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title) {
// Specify the name of the environment variable
const char* env_var_name = "OPENVINO_LOG_LEVEL";
const char* env_var_value = std::getenv(env_var_name);

// Check if the environment variable was found
if (env_var_value != nullptr && atoi(env_var_value) > static_cast<int>(ov::log::Level::WARNING)) {
// output of the actual settings that the device selected
auto supported_properties = compiled_Model.get_property(ov::supported_properties);
std::cout << "Model: " << model_title << std::endl;
for (const auto& cfg : supported_properties) {
if (cfg == ov::supported_properties)
continue;
auto prop = compiled_Model.get_property(cfg);
if (cfg == ov::device::properties) {
auto devices_properties = prop.as<ov::AnyMap>();
for (auto& item : devices_properties) {
std::cout << " " << item.first << ": " << std::endl;
for (auto& item2 : item.second.as<ov::AnyMap>()) {
std::cout << " " << item2.first << ": " << item2.second.as<std::string>() << std::endl;
}
}
} else {
std::cout << " " << cfg << ": " << prop.as<std::string>() << std::endl;
}
}

ov::Core core;
std::vector<std::string> exeTargets;
exeTargets = compiled_Model.get_property(ov::execution_devices);
std::cout << "EXECUTION_DEVICES:" << std::endl;
for (const auto& device : exeTargets) {
std::cout << " " << device << ": " << core.get_property(device, ov::device::full_name) << std::endl;
}
}
}

} // namespace utils
} // namespace genai
} // namespace ov
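For reference, a minimal sketch of how pipeline code is expected to call this helper; the model path, device, and wrapper function below are placeholders and not part of this change, and output appears only when OPENVINO_LOG_LEVEL is greater than ov::log::Level::WARNING:

```cpp
// Hypothetical caller, for illustration only: compile a model and dump the
// properties the plugin actually selected before creating an infer request.
#include "utils.hpp"

#include "openvino/runtime/core.hpp"

ov::InferRequest compile_and_log(ov::Core& core) {
    // "model.xml" and "CPU" are placeholder values.
    ov::CompiledModel compiled_model = core.compile_model("model.xml", "CPU");
    // No-op unless OPENVINO_LOG_LEVEL > ov::log::Level::WARNING.
    ov::genai::utils::print_compiled_model_properties(compiled_model, "Example model");
    return compiled_model.create_infer_request();
}
```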
2 changes: 2 additions & 0 deletions src/cpp/src/utils.hpp
@@ -104,6 +104,8 @@ size_t get_seq_len_axis(std::shared_ptr<const ov::Model> model);

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);

void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title);

} // namespace utils
} // namespace genai
} // namespace ov
1 change: 1 addition & 0 deletions src/cpp/src/visual_language/embedding_model.cpp
@@ -26,6 +26,7 @@ EmbeddingsModel::EmbeddingsModel(const std::filesystem::path& model_dir,
merge_postprocess(m_model, scale_emb);

ov::CompiledModel compiled_model = core.compile_model(m_model, device, properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "text embeddings model");
m_request = compiled_model.create_infer_request();
}

7 changes: 4 additions & 3 deletions src/cpp/src/visual_language/inputs_embedder.cpp
@@ -259,9 +259,10 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {
const std::string& device,
const ov::AnyMap device_config) :
IInputsEmbedder(vlm_config, model_dir, device, device_config) {
m_resampler = utils::singleton_core().compile_model(
model_dir / "openvino_resampler_model.xml", device, device_config
).create_infer_request();
auto compiled_model =
utils::singleton_core().compile_model(model_dir / "openvino_resampler_model.xml", device, device_config);
ov::genai::utils::print_compiled_model_properties(compiled_model, "VLM resampler model");
m_resampler = compiled_model.create_infer_request();

m_pos_embed_cache = get_2d_sincos_pos_embed(m_vlm_config.hidden_size, {70, 70});
}
2 changes: 1 addition & 1 deletion src/cpp/src/visual_language/pipeline.cpp
@@ -92,7 +92,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
auto compiled_language_model = utils::singleton_core().compile_model(
models_dir / "openvino_language_model.xml", device, properties
);

ov::genai::utils::print_compiled_model_properties(compiled_language_model, "VLM language model");
auto language_model = compiled_language_model.get_runtime_model();
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(language_model);

10 changes: 6 additions & 4 deletions src/cpp/src/visual_language/vision_encoder.cpp
@@ -648,10 +648,12 @@ ov::Tensor get_pixel_values_internvl(const ov::Tensor& image, const ProcessorCon

VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const VLMModelType model_type, const std::string& device, const ov::AnyMap device_config) :
model_type(model_type) {
m_vision_encoder = utils::singleton_core().compile_model(model_dir / "openvino_vision_embeddings_model.xml", device, device_config).create_infer_request();
m_processor_config = utils::from_config_json_if_exists<ProcessorConfig>(
model_dir, "preprocessor_config.json"
);
auto compiled_model = utils::singleton_core().compile_model(model_dir / "openvino_vision_embeddings_model.xml",
device,
device_config);
ov::genai::utils::print_compiled_model_properties(compiled_model, "VLM vision embeddings model");
m_vision_encoder = compiled_model.create_infer_request();
m_processor_config = utils::from_config_json_if_exists<ProcessorConfig>(model_dir, "preprocessor_config.json");
}

VisionEncoder::VisionEncoder(
21 changes: 12 additions & 9 deletions src/cpp/src/whisper_pipeline.cpp
@@ -56,15 +56,18 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
auto [core_properties, compile_properties] = ov::genai::utils::split_core_compile_config(properties);
core.set_property(core_properties);

m_models.encoder =
core.compile_model((models_path / "openvino_encoder_model.xml").string(), device, compile_properties)
.create_infer_request();
m_models.decoder =
core.compile_model((models_path / "openvino_decoder_model.xml").string(), device, compile_properties)
.create_infer_request();
m_models.decoder_with_past =
core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, compile_properties)
.create_infer_request();
ov::CompiledModel compiled_model;
compiled_model =
core.compile_model((models_path / "openvino_encoder_model.xml").string(), device, compile_properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();
compiled_model =
core.compile_model((models_path / "openvino_decoder_model.xml").string(), device, compile_properties);
ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();
compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, compile_properties);
m_models.decoder_with_past = compiled_model.create_infer_request();
ov::genai::utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");

// If eos_token_id was not provided, take value
if (m_generation_config.eos_token_id == -1) {
13 changes: 10 additions & 3 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -555,9 +555,16 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
preprocess_decoder(decoder_model);
preprocess_decoder(decoder_with_past_model);

m_models.encoder = core.compile_model(encoder_model, "NPU").create_infer_request();
m_models.decoder = core.compile_model(decoder_model, "NPU").create_infer_request();
m_models.decoder_with_past = core.compile_model(decoder_with_past_model, "NPU").create_infer_request();
ov::CompiledModel compiled_model;
compiled_model = core.compile_model(encoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
m_models.encoder = compiled_model.create_infer_request();
compiled_model = core.compile_model(decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_models.decoder = compiled_model.create_infer_request();
compiled_model = core.compile_model(decoder_with_past_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
m_models.decoder_with_past = compiled_model.create_infer_request();

// If eos_token_id was not provided, take value
if (m_generation_config.eos_token_id == -1) {
43 changes: 43 additions & 0 deletions src/docs/DEBUG_LOG.md
@@ -0,0 +1,43 @@
## 1. Using Debug Log

There are six log levels, which can be set explicitly or via the ``OPENVINO_LOG_LEVEL`` environment variable:

0 - ``ov::log::Level::NO``
1 - ``ov::log::Level::ERR``
2 - ``ov::log::Level::WARNING``
3 - ``ov::log::Level::INFO``
4 - ``ov::log::Level::DEBUG``
5 - ``ov::log::Level::TRACE``

When the environment variable ``OPENVINO_LOG_LEVEL`` is set to a value greater than ``ov::log::Level::WARNING``, the properties of the compiled model are printed.

For example:

Linux - ``export OPENVINO_LOG_LEVEL=3``
Windows - ``set OPENVINO_LOG_LEVEL=3``

The properties of the compiled model are then printed as follows:
```sh
NETWORK_NAME: Model0
OPTIMAL_NUMBER_OF_INFER_REQUESTS: 1
NUM_STREAMS: 1
INFERENCE_NUM_THREADS: 48
PERF_COUNT: NO
INFERENCE_PRECISION_HINT: bf16
PERFORMANCE_HINT: LATENCY
EXECUTION_MODE_HINT: PERFORMANCE
PERFORMANCE_HINT_NUM_REQUESTS: 0
ENABLE_CPU_PINNING: YES
SCHEDULING_CORE_TYPE: ANY_CORE
MODEL_DISTRIBUTION_POLICY:
ENABLE_HYPER_THREADING: NO
EXECUTION_DEVICES: CPU
CPU_DENORMALS_OPTIMIZATION: NO
LOG_LEVEL: LOG_NONE
CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1
DYNAMIC_QUANTIZATION_GROUP_SIZE: 32
KV_CACHE_PRECISION: f16
AFFINITY: CORE
EXECUTION_DEVICES:
CPU: Intel(R) Xeon(R) Platinum 8468
```
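As a usage illustration, with the variable exported, the printout is produced as a side effect of simply constructing and running a GenAI pipeline. Below is a minimal sketch following the standard LLMPipeline sample pattern; the model directory and prompt are placeholders:

```cpp
// Minimal sketch, assuming OPENVINO_LOG_LEVEL=3 is set in the environment and
// a converted model is available at the placeholder path below.
#include "openvino/genai/llm_pipeline.hpp"

#include <iostream>

int main() {
    // Compiled-model properties are printed while the pipeline is being built.
    ov::genai::LLMPipeline pipe("./TinyLlama-1.1B-Chat-v1.0", "CPU");
    std::cout << pipe.generate("What is OpenVINO?", ov::genai::max_new_tokens(32)) << '\n';
    return 0;
}
```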
