diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp
index a126cbc477..d6b93d8640 100644
--- a/src/cpp/src/llm_pipeline_static.cpp
+++ b/src/cpp/src/llm_pipeline_static.cpp
@@ -721,6 +721,7 @@ std::shared_ptr StatefulLLMPipeline::setupAndCompileModel(
     const uint32_t kMaxPromptLen = pop_int_and_cast(pipeline_config, "MAX_PROMPT_LEN").value_or(1024u);
     const uint32_t kMinResponseLen = pop_int_and_cast(pipeline_config, "MIN_RESPONSE_LEN").value_or(128u);
 
+    m_max_prompt_len = kMaxPromptLen;
     m_kvcache_total = kMaxPromptLen + kMinResponseLen;
 
     std::string generate_hint = pop_or_default(pipeline_config, "GENERATE_HINT", "FAST_COMPILE");
@@ -852,8 +853,13 @@ EncodedResults StatefulLLMPipeline::generate(
     results.scores[0] = 0u;
     results.tokens.resize(1u);
 
-    // TODO: Check if there is enough space in KV-cache to process input prompt
+    // NB: Check if there is enough space in KV-cache to process input prompt
     auto prompt_len = input_ids.get_size();
+    if (prompt_len > m_max_prompt_len) {
+        OPENVINO_THROW("Static Stateful LLM pipeline may only process prompts up to "
+                       + std::to_string(m_max_prompt_len) + " tokens. "
+                       + "Set the \"MAX_PROMPT_LEN\" config option to increase the limit.");
+    }
 
     ov::Tensor position_ids{ov::element::i64, input_ids.get_shape()};
     utils::initialize_position_ids(position_ids, attention_mask);
@@ -877,7 +883,7 @@ EncodedResults StatefulLLMPipeline::generate(
     auto padded_sequence_len = padded_logits.get_shape()[1];
     if (padded_sequence_len > 1) {
         // If SliceOut is not applied:
-        logits = make_tensor_slice(padded_logits, 1, padded_sequence_len - input_ids.get_size(), padded_sequence_len);
+        logits = make_tensor_slice(padded_logits, 1, padded_sequence_len - prompt_len, padded_sequence_len);
     }
     int64_t output_sequence_len = logits.get_shape().at(1);
 
@@ -887,7 +893,6 @@ EncodedResults StatefulLLMPipeline::generate(
     sequence_group->schedule_tokens(output_sequence_len);
 
     // NB: Controls what tokens are ready to be pushed into the streamer
-    // Set max_new_tokens here via get_max_new_token(prompt)
     GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
         sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());
 
@@ -1323,7 +1328,7 @@ EncodedResults StatelessLLMPipeline::generate(
     // NB: Check if there is enough space in KV-cache to process input prompt
     auto prompt_len = input_ids.get_size();
     if (prompt_len > m_kvcache_desc.max_prompt_size) {
-        OPENVINO_THROW("Static LLM pipeline may only process prompts up to "
+        OPENVINO_THROW("Static Stateless LLM pipeline may only process prompts up to "
                        + std::to_string(m_kvcache_desc.max_prompt_size) + " tokens. "
                        + "Set the \"MAX_PROMPT_LEN\" config option to increase the limit.");
     }
diff --git a/src/cpp/src/llm_pipeline_static.hpp b/src/cpp/src/llm_pipeline_static.hpp
index 20a6904f93..0138797a24 100644
--- a/src/cpp/src/llm_pipeline_static.hpp
+++ b/src/cpp/src/llm_pipeline_static.hpp
@@ -75,6 +75,7 @@ class StatefulLLMPipeline : public LLMPipelineImplBase {
    void finish_chat() override;

private:
+    uint32_t m_max_prompt_len = 0u;
    uint32_t m_kvcache_total = 0u;
    ov::InferRequest m_request;
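For reference, a minimal usage sketch of the limit this patch enforces. It assumes the public ov::genai::LLMPipeline constructor that takes an ov::AnyMap of pipeline options; only "MAX_PROMPT_LEN" and "MIN_RESPONSE_LEN" come from this diff, while the model path, device, and option values are illustrative placeholders.

    #include "openvino/genai/llm_pipeline.hpp"

    int main() {
        // Options consumed by setupAndCompileModel(): the KV-cache is sized as
        // MAX_PROMPT_LEN + MIN_RESPONSE_LEN, and with this patch a prompt longer
        // than MAX_PROMPT_LEN makes generate() throw instead of overflowing it.
        ov::AnyMap pipeline_config = {
            {"MAX_PROMPT_LEN", 2048},
            {"MIN_RESPONSE_LEN", 256},
        };
        ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "NPU", pipeline_config);
        auto result = pipe.generate("Why is the sky blue?", ov::genai::max_new_tokens(100));
        return 0;
    }

With this change the Stateful pipeline fails fast with the same kind of error message the Stateless pipeline already produced, instead of silently processing an oversized prompt.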