Fixed streamer without corrupting hieroglyphs #1540

Open: wants to merge 1 commit into base: master
30 changes: 18 additions & 12 deletions samples/python/text_generation/multinomial_causal_lm.py
@@ -31,6 +31,7 @@ def __init__(self, tokenizer):
        self.tokens_cache = []
        self.text_queue = queue.Queue()
        self.print_len = 0
+        self.decoded_lengths = []

    def __iter__(self):
        """
@@ -80,30 +81,35 @@ def put(self, token_id: int) -> bool:

        Returns:
            bool: True if generation should be stopped, False otherwise.
        """
        self.tokens_cache.append(token_id)
        text = self.tokenizer.decode(self.tokens_cache)
+        self.decoded_lengths.append(len(text))

        word = ''
+        delay_n_tokens = 3
        if len(text) > self.print_len and '\n' == text[-1]:
            # Flush the cache after the new line symbol.
            word = text[self.print_len:]
            self.tokens_cache = []
+            self.decoded_lengths = []
            self.print_len = 0
-        elif len(text) >= 3 and text[-1] == chr(65533):
+        elif len(text) > 0 and text[-1] == chr(65533):
            # Don't print incomplete text.
-            pass
-        elif len(text) > self.print_len:
-            # It is possible to have a shorter text after adding a new token.
-            # Print to output only if the text length has increased.
-            word = text[self.print_len:]
-            self.print_len = len(text)
+            self.decoded_lengths[-1] = -1
+        elif len(self.tokens_cache) >= delay_n_tokens:
+            print_until = self.decoded_lengths[-delay_n_tokens]
+            if print_until != -1 and print_until > self.print_len:
+                # It is possible to have a shorter text after adding a new token.
+                # Print to output only if the text length has increased and the text is complete (print_until != -1).
+                word = text[self.print_len:print_until]
+                self.print_len = print_until
        self.put_word(word)

        if self.get_stop_flag():
            # When generation is stopped from streamer then end is not called, need to call it here manually.
            self.end()
            return True # True means stop generation
        else:
            return False # False means continue generation
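
For readers skimming the diff, here is a minimal, self-contained sketch of the delayed-flush idea the sample now implements. It is illustrative only: a toy byte-level decoder stands in for the real tokenizer, and the names (ToyStreamer, DELAY_N_TOKENS) are not part of this PR. Holding back the last few tokens means a position whose tail decodes to the U+FFFD replacement character is never used as a flush boundary, so multi-byte characters such as hieroglyphs reach the output intact.

```python
# Minimal sketch (not from this PR): each "token" is a fragment of a UTF-8 byte sequence,
# so decoding a prefix of the stream can yield an incomplete character (U+FFFD).
DELAY_N_TOKENS = 3  # same delay as in the sample above


class ToyStreamer:
    def __init__(self):
        self.byte_cache = b""      # stands in for tokens_cache
        self.decoded_lengths = []  # decoded text length per "token", -1 if the tail was incomplete
        self.print_len = 0         # number of characters already flushed
        self.output = []

    def put(self, token_bytes: bytes):
        self.byte_cache += token_bytes
        text = self.byte_cache.decode("utf-8", errors="replace")
        self.decoded_lengths.append(len(text))

        if text and text[-1] == "\ufffd":
            # The tail is an incomplete character: never use this position as a flush boundary.
            self.decoded_lengths[-1] = -1
        elif len(self.decoded_lengths) >= DELAY_N_TOKENS:
            # Flush only up to the text length seen DELAY_N_TOKENS steps ago, so later
            # tokens can still complete (or shorten) the tail before it is printed.
            print_until = self.decoded_lengths[-DELAY_N_TOKENS]
            if print_until != -1 and print_until > self.print_len:
                self.output.append(text[self.print_len:print_until])
                self.print_len = print_until

    def end(self):
        text = self.byte_cache.decode("utf-8", errors="replace")
        self.output.append(text[self.print_len:])


# "你好" is six UTF-8 bytes; feeding them in 1-2 byte chunks forces incomplete prefixes.
streamer = ToyStreamer()
for chunk in (b"\xe4", b"\xbd\xa0", b"\xe5\xa5", b"\xbd", b"!"):
    streamer.put(chunk)
streamer.end()
assert "".join(streamer.output) == "你好!"
```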

34 changes: 24 additions & 10 deletions src/cpp/src/text_callback_streamer.cpp
@@ -15,23 +15,36 @@ bool TextCallbackStreamer::put(int64_t token) {
    std::stringstream res;
    m_tokens_cache.push_back(token);
    std::string text = m_tokenizer.decode(m_tokens_cache);
-   if (!text.empty() && '\n' == text.back() && text.size() > print_len) {
+   m_decoded_lengths.push_back(text.length());
+
+   if (!text.empty() && '\n' == text.back() && text.size() > m_printed_len) {
        // Flush the cache after the new line symbol
-       res << std::string_view{text.data() + print_len, text.size() - print_len};
+       res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len};
        m_tokens_cache.clear();
-       print_len = 0;
+       m_decoded_lengths.clear();
+       m_printed_len = 0;
        return on_finalized_subword_callback(res.str());
    }

+   constexpr size_t delay_n_tokens = 3;
    constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error.
    if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) {
+       m_decoded_lengths[m_decoded_lengths.size() - 1] = -1;
        // Don't print incomplete text
        return on_finalized_subword_callback(res.str());
-   } else if (text.size() > print_len) {
+   }
+
+   // In some cases adding the next token can shorten the text,
+   // e.g. when an apostrophe-removing regex fires after new tokens are added,
+   // so printing of the last few tokens is delayed.
+   if (m_tokens_cache.size() < delay_n_tokens) {
+       return on_finalized_subword_callback(res.str());
+   }
+   auto print_until = m_decoded_lengths[m_decoded_lengths.size() - delay_n_tokens];
+   if (print_until != -1 && print_until > m_printed_len) {
        // It is possible to have a shorter text after adding a new token.
-       // Print to output only if the text length has increased.
-       res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
-       print_len = text.size();
+       // Print to output only if the text length has increased and the text is complete (print_until != -1).
+       res << std::string_view{text.data() + m_printed_len, print_until - m_printed_len} << std::flush;
+       m_printed_len = print_until;
    }

    return on_finalized_subword_callback(res.str());
@@ -40,11 +53,12 @@ bool TextCallbackStreamer::put(int64_t token) {
void TextCallbackStreamer::end() {
    std::stringstream res;
    std::string text = m_tokenizer.decode(m_tokens_cache);
-   if (text.size() <= print_len)
-       return ;
-   res << std::string_view{text.data() + print_len, text.size() - print_len} << std::flush;
+   if (text.size() <= m_printed_len)
+       return;
+   res << std::string_view{text.data() + m_printed_len, text.size() - m_printed_len} << std::flush;
    m_tokens_cache.clear();
-   print_len = 0;
+   m_decoded_lengths.clear();
+   m_printed_len = 0;
    on_finalized_subword_callback(res.str());
    return;
}
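
For completeness, a hedged usage sketch of the scenario this C++ change targets, written against the openvino_genai Python bindings used elsewhere in this PR. The model directory below is a placeholder and not part of the PR; the callback protocol (return False to continue) matches the sample above.

```python
# Usage sketch only: assumes the openvino_genai package and an already converted model
# in MODEL_DIR (placeholder path). With the delayed flush in TextCallbackStreamer,
# multi-byte characters such as 你 and 好 are emitted only once fully decoded, so no
# U+FFFD replacement character reaches the console mid-character.
import openvino_genai as ov_genai

MODEL_DIR = "TinyLlama-1.1B-Chat-v1.0"  # placeholder, not part of this PR


def print_subword(subword: str) -> bool:
    print(subword, end="", flush=True)
    return False  # False means continue generation


pipe = ov_genai.LLMPipeline(MODEL_DIR, "CPU")
pipe.generate("你好! 你好嗎?", max_new_tokens=100, streamer=print_subword)
```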
3 changes: 2 additions & 1 deletion src/cpp/src/text_callback_streamer.hpp
@@ -21,7 +21,8 @@ class TextCallbackStreamer: public StreamerBase {
protected:
    Tokenizer m_tokenizer;
    std::vector<int64_t> m_tokens_cache;
-   size_t print_len = 0;
+   std::vector<int64_t> m_decoded_lengths;
+   size_t m_printed_len = 0;
};

} // namespace genai
19 changes: 19 additions & 0 deletions tests/python_tests/common.py
@@ -474,3 +474,22 @@ def get_image_by_link(link):
    image = image.convert('RGB')
    image_data = np.array((np.array(image.getdata()) - 128).astype(np.byte)).reshape(1, 3, image.size[1], image.size[0])
    return Tensor(image_data)
+
+def get_streamer_with_results():
+    # Return a streamer which accumulates results in order to compare them with the results returned from generate().
+    class StreamerWithResults:
+        results: List[str] = []
+        def __init__(self):
+            self.results = []
+
+        def accumulate(self, subword) -> bool:
+            self.results.append(subword)
+            return False
+
+        def get_result_str(self) -> str:
+            return ''.join(self.results)
+
+        def reset(self):
+            self.results = []
+
+    return StreamerWithResults()

Reviewer comment (Contributor) on get_streamer_with_results: as we agreed, we need to add this to run_llm_pipeline_with_ref and run_cb_pipeline_with_ref as well, since they cover more cases (like streaming when stop strings are enabled).
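
As a rough illustration of the reviewer's suggestion above, the sketch below shows one way an accumulating streamer could be threaded through a reference-comparison helper. The real run_llm_pipeline_with_ref and run_cb_pipeline_with_ref signatures are not shown in this diff, so the helper name and parameters here are assumptions rather than the actual API.

```python
# Hypothetical sketch: the real helpers' signatures are not part of this diff.
def generate_with_streamer_check(pipe, prompts, streamer, max_new_tokens=30):
    results = []
    for prompt in prompts:
        streamer.reset()
        res = pipe.generate(prompt, max_new_tokens=max_new_tokens,
                            streamer=streamer.accumulate)
        # The concatenated streamed subwords should reproduce the returned text.
        assert res == streamer.get_result_str()
        results.append(res)
    return results
```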
28 changes: 27 additions & 1 deletion tests/python_tests/test_llm_pipeline.py
@@ -11,7 +11,7 @@
from pathlib import Path
import torch

-from common import run_llm_pipeline_with_ref, convert_to_hf
+from common import run_llm_pipeline_with_ref, convert_to_hf, get_streamer_with_results
from ov_genai_test_utils import (
    get_models_list,
    read_model,
@@ -194,6 +194,17 @@ def test_callback_kwargs_one_string(callback):
    pipe.generate('table is made of', max_new_tokens=10, streamer=callback)


+@pytest.mark.parametrize("streamer", [get_streamer_with_results()])
+@pytest.mark.parametrize("prompt", ['table is made of', 'The Sun is yellow because', '你好! 你好嗎?'])
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_callback_one_string_with_accumulation(prompt, streamer):
+    pipe = read_model(get_models_list()[0])[4]
+    streamer.reset()
+    res = pipe.generate(prompt, max_new_tokens=10, streamer=streamer.accumulate)
+    assert res == streamer.get_result_str()
+
+
@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
@pytest.mark.nightly
@@ -207,6 +218,21 @@ def test_callback_decoding_metallama(model_descr, callback):
    ov_pipe = read_model(model_descr)[4]
    ov_pipe.generate(prompt, max_new_tokens=300, streamer=callback)

+@pytest.mark.parametrize("streamer", [get_streamer_with_results()])
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("model_descr", get_models_list())
+def test_callback_decoding_metallama_with_accumulation(model_descr, streamer):
+    # On Meta-Llama this prompt generates output which can become shorter after adding new tokens.
+    # Test that the streamer handles such cases correctly.
+    prompt = 'I have an interview about product speccing with the company Weekend Health. Give me an example of a question they might ask with regards about a new feature'
+    if model_descr[0] != 'meta-llama/Meta-Llama-3-8B-Instruct':
+        pytest.skip()
+    ov_pipe = read_model(model_descr)[4]
+    streamer.reset()
+    res = ov_pipe.generate(prompt, max_new_tokens=300, streamer=streamer.accumulate)
+    assert res == streamer.get_result_str()
+

@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
@pytest.mark.precommit
10 changes: 8 additions & 2 deletions tests/python_tests/test_vlm_pipeline.py
@@ -46,6 +46,8 @@ def get_ov_model(cache):
@pytest.mark.nightly
def test_vlm_pipeline(cache):
    def streamer(word: str) -> bool:
+        nonlocal result_from_streamer
+        result_from_streamer.append(word)
        return False

    models_path = get_ov_model(cache)
@@ -59,10 +61,14 @@ def streamer(word: str) -> bool:
    ov_pipe = VLMPipeline(models_path, "CPU")
    ov_pipe.start_chat()

-    ov_pipe.generate(prompts[0], images=images, generation_config=generation_config, streamer=streamer)
+    result_from_streamer = []
+    res = ov_pipe.generate(prompts[0], images=images, generation_config=generation_config, streamer=streamer)
+    assert res.texts[0] == ''.join(result_from_streamer)

    for prompt in prompts[1:]:
-        ov_pipe.generate(prompt, generation_config=generation_config, streamer=streamer)
+        result_from_streamer = []
+        res = ov_pipe.generate(prompt, generation_config=generation_config, streamer=streamer)
+        assert res.texts[0] == ''.join(result_from_streamer)

    ov_pipe.finish_chat()