From bde3a3ddfc51091287b1f79855dd6ff1aadc8f16 Mon Sep 17 00:00:00 2001 From: sbalandi Date: Mon, 13 Jan 2025 12:49:57 +0000 Subject: [PATCH 1/3] Automatically apply chat template in non-chat scenarios --- README.md | 1 - src/README.md | 2 ++ src/cpp/src/icontinuous_batching.cpp | 12 ++++++++- src/cpp/src/llm_pipeline_stateful.cpp | 25 +++++++++++++++++-- src/cpp/src/llm_pipeline_static.cpp | 20 +++++++++++++-- .../src/visual_language/inputs_embedder.cpp | 13 +++++++++- 6 files changed, 66 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index cea1e358bc..221a81c6c3 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,6 @@ from PIL import Image # Choose GPU instead of CPU in the line below to run the model on Intel integrated or discrete GPU pipe = openvino_genai.VLMPipeline("./InternVL2-1B", "CPU") -pipe.start_chat() image = Image.open("dog.jpg") image_data = np.array(image.getdata()).reshape(1, image.size[1], image.size[0], 3).astype(np.uint8) diff --git a/src/README.md b/src/README.md index af4953f98a..42d8aa9dde 100644 --- a/src/README.md +++ b/src/README.md @@ -73,6 +73,8 @@ output: 'it is made up of carbon atoms. The carbon atoms are arranged in a linear pattern, which gives the yellow color. The arrangement of carbon atoms in' ``` +>**Note**: The chat_template from tokenizer_config.json will be automatically applied to the prompt at the generation stage. If you want to disable it, you can do it by calling pipe.get_tokenizer().set_chat_template(""). + A simple chat in Python: ```python import openvino_genai as ov_genai diff --git a/src/cpp/src/icontinuous_batching.cpp b/src/cpp/src/icontinuous_batching.cpp index 8fbb9619ea..03cfbc89e9 100644 --- a/src/cpp/src/icontinuous_batching.cpp +++ b/src/cpp/src/icontinuous_batching.cpp @@ -55,7 +55,17 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate( timer.start(); for (const std::string& prompt : prompts) { const auto encode_start = std::chrono::steady_clock::now(); - input_ids.push_back(m_tokenizer.encode(prompt).input_ids); + ov::Tensor encoded_inputs; + try { + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + constexpr bool add_generation_prompt = true; + auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + encoded_inputs = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids; + } catch (const std::exception& error) { + // in case when chat_template was not found in tokenizer_config.json or set + encoded_inputs = m_tokenizer.encode(prompt).input_ids; + } + input_ids.push_back(encoded_inputs); tokenization_durations.emplace_back(PerfMetrics::get_microsec(std::chrono::steady_clock::now() - encode_start)); } timer.end(); diff --git a/src/cpp/src/llm_pipeline_stateful.cpp b/src/cpp/src/llm_pipeline_stateful.cpp index 8451709092..83e5b3c9ea 100644 --- a/src/cpp/src/llm_pipeline_stateful.cpp +++ b/src/cpp/src/llm_pipeline_stateful.cpp @@ -88,7 +88,19 @@ DecodedResults StatefulLLMPipeline::generate( if (auto input_vector = std::get_if>(&inputs)) { OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts"); - encoded_input = m_tokenizer.encode(*input_vector); + std::vector templated_input_vector; + for (auto& input : *input_vector) { + try { + ChatHistory history({{{"role", "user"}, {"content", input}}}); + constexpr bool add_generation_prompt = true; + auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + templated_input_vector.push_back(templated_prompt); + } catch (const 
std::exception& error) { + // in case when chat_template was not found in tokenizer_config.json or set + templated_input_vector.push_back(input); + } + } + encoded_input = m_tokenizer.encode(templated_input_vector, ov::genai::add_special_tokens(false)); } else if (auto input_prompt = std::get_if(&inputs)) { std::string& prompt = *input_prompt; @@ -157,7 +169,16 @@ DecodedResults StatefulLLMPipeline::generate( // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied } else { - encoded_input = m_tokenizer.encode(prompt); + std::string& prompt = *input_prompt; + try { + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + constexpr bool add_generation_prompt = true; + auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + encoded_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); + } catch (const std::exception& error) { + // in case when chat_template was not found in tokenizer_config.json or set + encoded_input = m_tokenizer.encode(prompt); + } } } diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index b29bec3b4a..49f955ed41 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -784,7 +784,15 @@ DecodedResults StatefulLLMPipeline::generate( // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)); } else { - tokenized_input = m_tokenizer.encode(prompt); + try { + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + constexpr bool add_generation_prompt = true; + auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); + } catch (const std::exception& error) { + // in case when chat_template was not found in tokenizer_config.json or set + tokenized_input = m_tokenizer.encode(prompt); + } } auto encode_stop_time = std::chrono::steady_clock::now(); @@ -1252,7 +1260,15 @@ DecodedResults StatelessLLMPipeline::generate( // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)); } else { - tokenized_input = m_tokenizer.encode(prompt); + try { + ChatHistory history({{{"role", "user"}, {"content", prompt}}}); + constexpr bool add_generation_prompt = true; + auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); + } catch (const std::exception& error) { + // in case when chat_template was not found in tokenizer_config.json or set + tokenized_input = m_tokenizer.encode(prompt); + } } auto encode_stop_time = std::chrono::steady_clock::now(); diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index 9f8718f14c..b4e2b1cad1 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -217,8 +217,19 @@ class InputsEmbedder::IInputsEmbedder { m_tokenized_history.clear(); std::copy_n(new_chat_tokens.data(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history)); } else { + std::string templated_prompt; + ChatHistory history({{{"role", "user"}, {"content", 
prompt}}}); + constexpr bool add_generation_prompt = true; + + try { + templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); + } catch (const std::exception& error) { + // Use fallback chat template if it was not found in tokenizer_config.json + templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt, chat_template_fallback); + } + auto start_tokenizer_time = std::chrono::steady_clock::now(); - encoded_input_ids = m_tokenizer.encode(prompt).input_ids; + encoded_input_ids = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids; auto end_tokenizer_time = std::chrono::steady_clock::now(); metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); m_tokenized_history.clear(); From 11dec948416a9fdd06e177effb928da019b1e257 Mon Sep 17 00:00:00 2001 From: sbalandi Date: Mon, 13 Jan 2025 21:59:40 +0000 Subject: [PATCH 2/3] update test and docs --- .github/workflows/causal_lm_cpp.yml | 42 +++++++++++++++---- samples/cpp/text_generation/README.md | 2 +- samples/python/text_generation/README.md | 2 +- src/README.md | 2 +- .../include/openvino/genai/llm_pipeline.hpp | 4 ++ src/cpp/include/openvino/genai/tokenizer.hpp | 3 ++ .../genai/visual_language/pipeline.hpp | 8 ++++ src/cpp/src/icontinuous_batching.cpp | 4 +- src/cpp/src/llm_pipeline_stateful.cpp | 10 ++--- src/cpp/src/llm_pipeline_static.cpp | 8 ++-- src/cpp/src/tokenizer.cpp | 8 ++++ .../src/visual_language/inputs_embedder.cpp | 8 ++-- tests/python_tests/common.py | 10 ++++- 13 files changed, 84 insertions(+), 27 deletions(-) diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml index 5fc4617f2c..e001b85509 100644 --- a/.github/workflows/causal_lm_cpp.yml +++ b/.github/workflows/causal_lm_cpp.yml @@ -119,7 +119,10 @@ jobs: with open('pred.txt', 'r') as file: predictions = file.read() tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0') - tokenized = tokenizer('Why is the Sun yellow?', return_tensors='pt') + prompt = 'Why is the Sun yellow?' 
+ if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(prompt, return_tensors='pt') for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -135,7 +138,10 @@ jobs: with open('pred.txt', 'r') as file: predictions = file.read() tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0') - tokenized = tokenizer('69', return_tensors='pt') + prompt = '69' + if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(prompt, return_tensors='pt') for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -151,7 +157,10 @@ jobs: with open('pred.txt', 'r') as file: predictions = file.read() tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0') - tokenized = tokenizer('Hi', return_tensors='pt') + prompt = 'Hi' + if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(prompt, return_tensors='pt') for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -167,7 +176,10 @@ jobs: with open('pred.txt', 'r') as file: predictions = file.read() tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0') - tokenized = tokenizer('return 0', return_tensors='pt') + prompt = 'return 0' + if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(prompt, return_tensors='pt') for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -183,7 +195,10 @@ jobs: with open('pred.txt', 'r', errors='ignore') as file: predictions = file.read() tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0') - tokenized = tokenizer('你好! 
你好嗎?', return_tensors='pt') + prompt = '你好! 你好嗎?' + if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(prompt, return_tensors='pt') for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref.replace('�', '')) @@ -205,6 +220,8 @@ jobs: '你好! 你好嗎?' ] for prompt in prompts: + if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) tokenized = tokenizer(prompt, return_tensors='pt') for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) @@ -254,7 +271,10 @@ jobs: echo import transformers > ref.py echo predictions = open('cpp.txt', 'r').read() >> ref.py echo tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', trust_remote_code=True) >> ref.py - echo tokenized = tokenizer('69', return_tensors='pt') >> ref.py + echo prompt = '69' + echo if tokenizer.chat_template: + echo prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + echo tokenized = tokenizer(prompt, return_tensors='pt') >> ref.py echo for beam in transformers.AutoModelForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', trust_remote_code=True).generate(**tokenized, max_new_tokens=100, do_sample=False): >> ref.py echo ref = tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) >> ref.py echo idx = predictions.find(ref) >> ref.py @@ -559,7 +579,10 @@ jobs: with open('pred_greedy.txt', 'r') as file: predictions = file.read() tokenizer = transformers.AutoTokenizer.from_pretrained('microsoft/phi-1_5') - tokenized = tokenizer('Alan Turing was a', return_tensors='pt') + prompt = 'Alan Turing was a' + if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(prompt, return_tensors='pt') for output in transformers.AutoModelForCausalLM.from_pretrained('microsoft/phi-1_5').generate(**tokenized, max_length=100, do_sample=False): ref = tokenizer.decode(output[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -614,7 +637,10 @@ jobs: with open('pred_greedy.txt', 'r') as file: predictions = file.read() tokenizer = transformers.AutoTokenizer.from_pretrained('ikala/redpajama-3b-chat') - tokenized = tokenizer('Alan Turing was a', return_tensors='pt') + prompt = 'Alan Turing was a' + if tokenizer.chat_template: + prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) + tokenized = tokenizer(prompt, return_tensors='pt') for output in 
transformers.AutoModelForCausalLM.from_pretrained('ikala/redpajama-3b-chat').generate(**tokenized, max_length=100, do_sample=False): ref = tokenizer.decode(output[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) diff --git a/samples/cpp/text_generation/README.md b/samples/cpp/text_generation/README.md index d9e5bd8d22..c05fa25f9c 100644 --- a/samples/cpp/text_generation/README.md +++ b/samples/cpp/text_generation/README.md @@ -62,7 +62,7 @@ Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat ./chat_sample ``` #### Missing chat template -If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model. +If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model or update it using call `pipe.get_tokenizer().set_chat_template(new_chat_template)`. The following template can be used as a default, but it may not work properly with every model: ``` "chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}", diff --git a/samples/python/text_generation/README.md b/samples/python/text_generation/README.md index 9940904cfb..db2f6b0d5f 100644 --- a/samples/python/text_generation/README.md +++ b/samples/python/text_generation/README.md @@ -62,7 +62,7 @@ Recommended models: meta-llama/Llama-2-7b-chat-hf, TinyLlama/TinyLlama-1.1B-Chat python chat_sample.py model_dir ``` #### Missing chat template -If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model. +If you encounter an exception indicating a missing "chat template" when launching the `ov::genai::LLMPipeline` in chat mode, it likely means the model was not tuned for chat functionality. To work this around, manually add the chat template to tokenizer_config.json of your model or update it using call `pipe.get_tokenizer().set_chat_template(new_chat_template)`. The following template can be used as a default, but it may not work properly with every model: ``` "chat_template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n<|im_start|>assistant\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>\n'}}{% endif %}{% endfor %}", diff --git a/src/README.md b/src/README.md index 42d8aa9dde..c2ed8c2a60 100644 --- a/src/README.md +++ b/src/README.md @@ -73,7 +73,7 @@ output: 'it is made up of carbon atoms. The carbon atoms are arranged in a linear pattern, which gives the yellow color. The arrangement of carbon atoms in' ``` ->**Note**: The chat_template from tokenizer_config.json will be automatically applied to the prompt at the generation stage. 
If you want to disable it, you can do it by calling pipe.get_tokenizer().set_chat_template(""). +>**Note**: The chat_template from tokenizer_config.json or from tokenizer/detokenizer model will be automatically applied to the prompt at the generation stage. If you want to disable it, you can do it by calling pipe.get_tokenizer().set_chat_template(""). A simple chat in Python: ```python diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index e7a7c40f9b..e1619551d2 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -177,6 +177,8 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline { * @param generation_config optional GenerationConfig * @param streamer optional streamer * @return DecodedResults decoded resulting text + * chat_template will be applied to the prompt, run pipe.get_tokenizer().set_chat_template(custom_chat_template) to update it. + * Use custom_chat_template = "" to disable it for non-chat mode. */ DecodedResults generate( StringInputs inputs, @@ -191,6 +193,8 @@ class OPENVINO_GENAI_EXPORTS LLMPipeline { * @param inputs input prompt or a vector of prompts * @param properties properties * @return DecodedResults decoded resulting text + * chat_template will be applied to the prompt, run pipe.get_tokenizer().set_chat_template(custom_chat_template) to update it. + * Use custom_chat_template = "" to disable it for non-chat mode. */ template util::EnableIfAllStringAny generate( diff --git a/src/cpp/include/openvino/genai/tokenizer.hpp b/src/cpp/include/openvino/genai/tokenizer.hpp index 548e4dc332..e4b754ea16 100644 --- a/src/cpp/include/openvino/genai/tokenizer.hpp +++ b/src/cpp/include/openvino/genai/tokenizer.hpp @@ -221,6 +221,9 @@ class OPENVINO_GENAI_EXPORTS Tokenizer { /// @param chat_template The new template to override with. void set_chat_template(const std::string& chat_template); + // get information about a chat template to check its status, for example whether it is empty + std::string get_chat_template() const; + // information about , tokens should be public, // they are used at least in StreamerBase descendants int64_t get_bos_token_id() const; diff --git a/src/cpp/include/openvino/genai/visual_language/pipeline.hpp b/src/cpp/include/openvino/genai/visual_language/pipeline.hpp index 43f8a9b8b3..edf2b24517 100644 --- a/src/cpp/include/openvino/genai/visual_language/pipeline.hpp +++ b/src/cpp/include/openvino/genai/visual_language/pipeline.hpp @@ -98,6 +98,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline { /// @param generation_config A config to follow for text generation. /// @param streamer A streamer to acquire intermediate result. /// @return A string generated by a model. + /// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it. + /// Use custom_chat_template="" to disable it for non-chat mode. VLMDecodedResults generate( const std::string& prompt, const std::vector& rgbs, @@ -111,6 +113,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline { /// @param generation_config A config to follow for text generation. /// @param streamer A streamer to acquire intermediate result. /// @return A string generated by a model. + /// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it. + /// Use custom_chat_template="" to disable it for non-chat mode. 
VLMDecodedResults generate( const std::string& prompt, const ov::Tensor& rgb, @@ -124,6 +128,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline { /// for its members, StreamerVariant a single image or multiple /// images. /// @return A string generated by a model. + /// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it. + /// Use custom_chat_template="" to disable it for non-chat mode. VLMDecodedResults generate( const std::string& prompt, const ov::AnyMap& config_map @@ -137,6 +143,8 @@ class OPENVINO_GENAI_EXPORTS VLMPipeline { /// @param ...properties ov::Property instances to be combined into /// ov::AnyMap. /// @return A string generated by a model. + /// chat_template will be applied to the prompt, run pipe.set_chat_template(custom_chat_template) to update it. + /// Use custom_chat_template="" to disable it for non-chat mode. template util::EnableIfAllStringAny generate( const std::string& prompt, diff --git a/src/cpp/src/icontinuous_batching.cpp b/src/cpp/src/icontinuous_batching.cpp index 03cfbc89e9..467f521a12 100644 --- a/src/cpp/src/icontinuous_batching.cpp +++ b/src/cpp/src/icontinuous_batching.cpp @@ -56,12 +56,12 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate( for (const std::string& prompt : prompts) { const auto encode_start = std::chrono::steady_clock::now(); ov::Tensor encoded_inputs; - try { + if (!m_tokenizer.get_chat_template().empty()) { ChatHistory history({{{"role", "user"}, {"content", prompt}}}); constexpr bool add_generation_prompt = true; auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); encoded_inputs = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids; - } catch (const std::exception& error) { + } else { // in case when chat_template was not found in tokenizer_config.json or set encoded_inputs = m_tokenizer.encode(prompt).input_ids; } diff --git a/src/cpp/src/llm_pipeline_stateful.cpp b/src/cpp/src/llm_pipeline_stateful.cpp index 83e5b3c9ea..eab8bf4a3f 100644 --- a/src/cpp/src/llm_pipeline_stateful.cpp +++ b/src/cpp/src/llm_pipeline_stateful.cpp @@ -90,12 +90,12 @@ DecodedResults StatefulLLMPipeline::generate( OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts"); std::vector templated_input_vector; for (auto& input : *input_vector) { - try { + if (!m_tokenizer.get_chat_template().empty()) { ChatHistory history({{{"role", "user"}, {"content", input}}}); constexpr bool add_generation_prompt = true; auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); templated_input_vector.push_back(templated_prompt); - } catch (const std::exception& error) { + } else { // in case when chat_template was not found in tokenizer_config.json or set templated_input_vector.push_back(input); } @@ -116,7 +116,7 @@ DecodedResults StatefulLLMPipeline::generate( m_history.push_back({{"role", "user"}, {"content", prompt}}); constexpr bool add_generation_prompt = true; - auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); // Do not add special tokens in chat scenario to be aligned with HF. 
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false)); auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)); @@ -170,12 +170,12 @@ DecodedResults StatefulLLMPipeline::generate( // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied } else { std::string& prompt = *input_prompt; - try { + if (!m_tokenizer.get_chat_template().empty()) { ChatHistory history({{{"role", "user"}, {"content", prompt}}}); constexpr bool add_generation_prompt = true; auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); encoded_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); - } catch (const std::exception& error) { + } else { // in case when chat_template was not found in tokenizer_config.json or set encoded_input = m_tokenizer.encode(prompt); } diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 49f955ed41..52232d8259 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -784,12 +784,12 @@ DecodedResults StatefulLLMPipeline::generate( // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)); } else { - try { + if (!m_tokenizer.get_chat_template().empty()) { ChatHistory history({{{"role", "user"}, {"content", prompt}}}); constexpr bool add_generation_prompt = true; auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); - } catch (const std::exception& error) { + } else { // in case when chat_template was not found in tokenizer_config.json or set tokenized_input = m_tokenizer.encode(prompt); } @@ -1260,12 +1260,12 @@ DecodedResults StatelessLLMPipeline::generate( // for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false)); } else { - try { + if (!m_tokenizer.get_chat_template().empty()) { ChatHistory history({{{"role", "user"}, {"content", prompt}}}); constexpr bool add_generation_prompt = true; auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)); - } catch (const std::exception& error) { + } else { // in case when chat_template was not found in tokenizer_config.json or set tokenized_input = m_tokenizer.encode(prompt); } diff --git a/src/cpp/src/tokenizer.cpp b/src/cpp/src/tokenizer.cpp index 03e3b0a5d3..30cb77dd2f 100644 --- a/src/cpp/src/tokenizer.cpp +++ b/src/cpp/src/tokenizer.cpp @@ -587,6 +587,10 @@ class Tokenizer::TokenizerImpl { void set_chat_template(const std::string& chat_template) { m_chat_template = patch_chat_template(chat_template); } + + std::string get_chat_template() { + return m_chat_template; + } }; Tokenizer::Tokenizer(const std::filesystem::path& tokenizer_path, const ov::AnyMap& properties) { @@ -690,6 +694,10 @@ std::string Tokenizer::apply_chat_template(ChatHistory history, return m_pimpl->apply_chat_template(history, add_generation_prompt, chat_template); } +std::string Tokenizer::get_chat_template() const { + return m_pimpl->get_chat_template(); +} + void 
Tokenizer::set_chat_template(const std::string& chat_template) { m_pimpl->set_chat_template(chat_template); } diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index b4e2b1cad1..b9655c34d6 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -163,9 +163,9 @@ class InputsEmbedder::IInputsEmbedder { m_history.push_back({{"role", "user"}, {"content", prompt}}); constexpr bool add_generation_prompt = true; std::string new_templated_chat_history; - try { + if (!m_tokenizer.get_chat_template().empty()) { new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); - } catch (const std::exception& error) { + } else { // Use fallback chat template if it was not found in tokenizer_config.json new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt, chat_template_fallback); } @@ -221,9 +221,9 @@ class InputsEmbedder::IInputsEmbedder { ChatHistory history({{{"role", "user"}, {"content", prompt}}}); constexpr bool add_generation_prompt = true; - try { + if (!m_tokenizer.get_chat_template().empty()) { templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt); - } catch (const std::exception& error) { + } else { // Use fallback chat template if it was not found in tokenizer_config.json templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt, chat_template_fallback); } diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py index 2fca58a959..8926c5cd96 100644 --- a/tests/python_tests/common.py +++ b/tests/python_tests/common.py @@ -251,6 +251,8 @@ def run_hugging_face( # process prompt by promp as we have multiple generation configs for prompt, generation_config in zip(prompts, generation_configs): hf_generation_config = convert_to_hf(opt_model.generation_config, generation_config) + if hf_tokenizer.chat_template: + prompt = hf_tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) inputs = hf_tokenizer(prompt, return_tensors="pt") input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] prompt_len = 0 if generation_config.echo else input_ids.numel() @@ -265,8 +267,14 @@ def run_hugging_face( generation_result.m_scores = [score for score in generate_outputs.sequences_scores] generation_results.append(generation_result) else: + processed_prompts = [] + if hf_tokenizer.chat_template: + for prompt in prompts: + processed_prompts.append(hf_tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)) + else: + processed_prompts = prompts # process all prompts as a single batch as we have a single generation config for all prompts - inputs = hf_tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True, padding_side='left') + inputs = hf_tokenizer(processed_prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True, padding_side='left') input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] hf_generation_config = convert_to_hf(opt_model.generation_config, generation_configs) hf_encoded_outputs = opt_model.generate(input_ids, attention_mask=attention_mask, generation_config=hf_generation_config, tokenizer=hf_tokenizer) From f5d74b483b6fa87e36c73fc2f3a2391297810025 Mon Sep 17 00:00:00 2001 From: sbalandi Date: Tue, 14 Jan 2025 14:39:04 +0000 
Subject: [PATCH 3/3] ci fix --- .github/workflows/causal_lm_cpp.yml | 20 ++++++++++---------- tests/python_tests/common.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml index e001b85509..b1d691d6be 100644 --- a/.github/workflows/causal_lm_cpp.yml +++ b/.github/workflows/causal_lm_cpp.yml @@ -122,7 +122,7 @@ jobs: prompt = 'Why is the Sun yellow?' if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -141,7 +141,7 @@ jobs: prompt = '69' if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -160,7 +160,7 @@ jobs: prompt = 'Hi' if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -179,7 +179,7 @@ jobs: prompt = 'return 0' if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -198,7 +198,7 @@ jobs: prompt = '你好! 你好嗎?' 
if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref.replace('�', '')) @@ -222,7 +222,7 @@ jobs: for prompt in prompts: if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for beam in transformers.LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0').generate(**tokenized, num_beam_groups=3, num_beams=15, num_return_sequences=15, diversity_penalty=1.0, max_new_tokens=20, early_stopping=False, length_penalty=1.0, no_repeat_ngram_size=9**9, do_sample=False): ref = ': ' + tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref.replace('�', '')) @@ -272,9 +272,9 @@ jobs: echo predictions = open('cpp.txt', 'r').read() >> ref.py echo tokenizer = transformers.AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', trust_remote_code=True) >> ref.py echo prompt = '69' - echo if tokenizer.chat_template: + echo if tokenizer.chat_template: echo prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - echo tokenized = tokenizer(prompt, return_tensors='pt') >> ref.py + echo tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) >> ref.py echo for beam in transformers.AutoModelForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', trust_remote_code=True).generate(**tokenized, max_new_tokens=100, do_sample=False): >> ref.py echo ref = tokenizer.decode(beam[tokenized['input_ids'].numel():], skip_special_tokens=True) >> ref.py echo idx = predictions.find(ref) >> ref.py @@ -582,7 +582,7 @@ jobs: prompt = 'Alan Turing was a' if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for output in transformers.AutoModelForCausalLM.from_pretrained('microsoft/phi-1_5').generate(**tokenized, max_length=100, do_sample=False): ref = tokenizer.decode(output[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) @@ -640,7 +640,7 @@ jobs: prompt = 'Alan Turing was a' if tokenizer.chat_template: prompt = tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True) - tokenized = tokenizer(prompt, return_tensors='pt') + tokenized = tokenizer(prompt, return_tensors='pt', add_special_tokens=False) for output in transformers.AutoModelForCausalLM.from_pretrained('ikala/redpajama-3b-chat').generate(**tokenized, max_length=100, do_sample=False): ref = 
tokenizer.decode(output[tokenized['input_ids'].numel():], skip_special_tokens=True) idx = predictions.find(ref) diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py index 8926c5cd96..7d7a03cf76 100644 --- a/tests/python_tests/common.py +++ b/tests/python_tests/common.py @@ -274,7 +274,7 @@ def run_hugging_face( else: processed_prompts = prompts # process all prompts as a single batch as we have a single generation config for all prompts - inputs = hf_tokenizer(processed_prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True, padding_side='left') + inputs = hf_tokenizer(processed_prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=False, padding_side='left') input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] hf_generation_config = convert_to_hf(opt_model.generation_config, generation_configs) hf_encoded_outputs = opt_model.generate(input_ids, attention_mask=attention_mask, generation_config=hf_generation_config, tokenizer=hf_tokenizer)
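
A minimal usage sketch of the behaviour this series introduces, assuming a converted TinyLlama-1.1B-Chat-v1.0 model directory (the path is a placeholder) and the openvino_genai Python package; the generate() call and set_chat_template() usage follow the samples and the src/README.md note added above, not an authoritative part of the patch:

```python
import openvino_genai as ov_genai

# Non-chat generation: with this series, a plain prompt is wrapped into a single
# user turn and the tokenizer's chat_template is applied before encoding
# (without adding special tokens again), matching the HF reference path
# tokenizer.apply_chat_template([...], tokenize=False, add_generation_prompt=True)
# used in the updated CI checks and tests.
pipe = ov_genai.LLMPipeline("./TinyLlama-1.1B-Chat-v1.0", "CPU")
print(pipe.generate("Why is the Sun yellow?", max_new_tokens=30))

# Opting out: with an empty chat template the pipelines take the fallback branch
# and encode the raw prompt as in previous releases.
pipe.get_tokenizer().set_chat_template("")
print(pipe.generate("Why is the Sun yellow?", max_new_tokens=30))
```

The same fallback is taken automatically when no chat_template is present in tokenizer_config.json, which is what the `get_chat_template().empty()` checks introduced in PATCH 2/3 implement in the C++ pipelines.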