
Commit

Small refactoring
AsyaPronina committed Jan 10, 2025
1 parent 91b6891 commit e9ff820
Showing 2 changed files with 10 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -721,6 +721,7 @@ std::shared_ptr<ov::CompiledModel> StatefulLLMPipeline::setupAndCompileModel(
 
     const uint32_t kMaxPromptLen = pop_int_and_cast(pipeline_config, "MAX_PROMPT_LEN").value_or(1024u);
     const uint32_t kMinResponseLen = pop_int_and_cast(pipeline_config, "MIN_RESPONSE_LEN").value_or(128u);
+    m_max_prompt_len = kMaxPromptLen;
     m_kvcache_total = kMaxPromptLen + kMinResponseLen;
     std::string generate_hint = pop_or_default<std::string>(pipeline_config, "GENERATE_HINT", "FAST_COMPILE");
 
@@ -852,8 +853,13 @@ EncodedResults StatefulLLMPipeline::generate(
     results.scores[0] = 0u;
     results.tokens.resize(1u);
 
-    // TODO: Check if there is enough space in KV-cache to process input prompt
+    // NB: Check if there is enough space in KV-cache to process input prompt
     auto prompt_len = input_ids.get_size();
+    if (prompt_len > m_max_prompt_len) {
+        OPENVINO_THROW("Static Stateful LLM pipeline may only process prompts up to "
+                       + std::to_string(m_max_prompt_len) + " tokens. "
+                       + "Set the \"MAX_PROMPT_LEN\" config option to increase the limit.");
+    }
 
     ov::Tensor position_ids{ov::element::i64, input_ids.get_shape()};
     utils::initialize_position_ids(position_ids, attention_mask);
@@ -877,7 +883,7 @@ EncodedResults StatefulLLMPipeline::generate(
     auto padded_sequence_len = padded_logits.get_shape()[1];
     if (padded_sequence_len > 1) {
         // If SliceOut is not applied:
-        logits = make_tensor_slice(padded_logits, 1, padded_sequence_len - input_ids.get_size(), padded_sequence_len);
+        logits = make_tensor_slice(padded_logits, 1, padded_sequence_len - prompt_len, padded_sequence_len);
     }
     int64_t output_sequence_len = logits.get_shape().at(1);
 
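For context, `make_tensor_slice` is not part of this diff; the sketch below shows what it is assumed to do, namely return a zero-copy view over `[start_pos, end_pos)` along one dimension via the `ov::Tensor` ROI constructor, so the hunk above keeps only the logits for the last `prompt_len` positions of the padded sequence. When SliceOut is applied, the compiled model is presumed to emit only the last position, so `padded_sequence_len == 1` and no host-side slice is needed.

```cpp
#include <openvino/runtime/tensor.hpp>

// Sketch only: the real helper in llm_pipeline_static.cpp may differ in details.
// Returns a zero-copy view over [start_pos, end_pos) along dimension `dim`.
ov::Tensor make_tensor_slice(ov::Tensor tensor, size_t dim,
                             size_t start_pos, size_t end_pos) {
    ov::Shape start_shape(tensor.get_shape().size(), 0u);
    start_shape[dim] = start_pos;
    ov::Shape end_shape = tensor.get_shape();
    end_shape[dim] = end_pos;
    return ov::Tensor(tensor, start_shape, end_shape);  // ROI constructor, no copy
}
```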
@@ -887,7 +893,6 @@ EncodedResults StatefulLLMPipeline::generate(
     sequence_group->schedule_tokens(output_sequence_len);
 
     // NB: Controls what tokens are ready to be pushed into the streamer
-    // Set max_new_tokens here via get_max_new_token(prompt)
     GenerationHandle handle = std::make_shared<GenerationHandleImpl>(
         sequence_group->get_generation_stream(), sequence_group->get_sampling_parameters());
 
@@ -1323,7 +1328,7 @@ EncodedResults StatelessLLMPipeline::generate(
     // NB: Check if there is enough space in KV-cache to process input prompt
     auto prompt_len = input_ids.get_size();
     if (prompt_len > m_kvcache_desc.max_prompt_size) {
-        OPENVINO_THROW("Static LLM pipeline may only process prompts up to "
+        OPENVINO_THROW("Static Stateless LLM pipeline may only process prompts up to "
                        + std::to_string(m_kvcache_desc.max_prompt_size) + " tokens. "
                        + "Set the \"MAX_PROMPT_LEN\" config option to increase the limit.");
     }
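Taken together, these hunks give the stateful pipeline the same prompt-length guard the stateless variant already had: the KV-cache is pre-allocated as MAX_PROMPT_LEN + MIN_RESPONSE_LEN at compile time, so longer prompts must be rejected up front. A minimal usage sketch, assuming the usual `ov::genai::LLMPipeline(models_path, device, properties)` constructor; the path, device, and values below are illustrative, not from this commit:

```cpp
#include <openvino/genai/llm_pipeline.hpp>
#include <iostream>

int main() {
    // Illustrative values: the KV-cache will hold 2048 + 256 token positions.
    ov::AnyMap properties{
        {"MAX_PROMPT_LEN", 2048},   // prompts longer than this now throw
        {"MIN_RESPONSE_LEN", 256},
    };
    // "NPU" is where the static pipeline is typically selected.
    ov::genai::LLMPipeline pipe("/path/to/exported/model", "NPU", properties);
    ov::genai::DecodedResults result = pipe.generate("Why is the sky blue?");
    std::cout << result.texts.at(0) << std::endl;
    return 0;
}
```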
1 change: 1 addition & 0 deletions src/cpp/src/llm_pipeline_static.hpp
@@ -75,6 +75,7 @@ class StatefulLLMPipeline : public LLMPipelineImplBase {
     void finish_chat() override;
 
 private:
+    uint32_t m_max_prompt_len = 0u;
     uint32_t m_kvcache_total = 0u;
     ov::InferRequest m_request;
 
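On the caller side, an oversized prompt now fails fast instead of overrunning the pre-allocated KV-cache. A sketch of what a caller might observe, assuming `OPENVINO_THROW` raises `ov::Exception` as elsewhere in OpenVINO, and continuing the `pipe` from the sketch above:

```cpp
#include <openvino/genai/llm_pipeline.hpp>
#include <openvino/core/except.hpp>
#include <iostream>
#include <string>

void try_generate(ov::genai::LLMPipeline& pipe, const std::string& prompt) {
    try {
        ov::genai::DecodedResults result = pipe.generate(prompt);
        std::cout << result.texts.at(0) << std::endl;
    } catch (const ov::Exception& e) {
        // The message names the token limit and suggests raising "MAX_PROMPT_LEN".
        std::cerr << "Prompt rejected: " << e.what() << std::endl;
    }
}
```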
