diff --git a/guidance/models/_model.py b/guidance/models/_model.py
index da02a3ce7..c9fa030f0 100644
--- a/guidance/models/_model.py
+++ b/guidance/models/_model.py
@@ -543,13 +543,19 @@ def __call__(
                     delayed_engine_outputs = []
             elif not echo and engine_response.new_bytes:
                 # do not collect tokens-metrics if echo is disabled
-                engine_response.generated_bytes = engine_response.new_bytes
-                engine_response.generated_tokens.clear()
+                try:
+                    _ = (delayed_bytes + engine_response.new_bytes).decode("utf-8")
+                    engine_response.generated_bytes = delayed_bytes + engine_response.new_bytes
+                    delayed_bytes = b""
+                    engine_response.generated_tokens.clear()
+                except UnicodeDecodeError:
+                    delayed_bytes += engine_response.new_bytes
 
             # process engine_response
             # NOTE (loc): We should not yield the engine_response if new_bytes are invalid utf-8 bytes
             # delayed bytes should be handled here in the engine
-            yield engine_response
+            if delayed_bytes == b"":
+                yield engine_response
 
             if ll_response.stop:
                 assert mask is None
@@ -1404,52 +1410,36 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
 
         # single generation
        if n == 1:
-            generated_value = ""
             # logprobs_out = []
-
-            delayed_bytes = b""
-            # last_is_generated = False
-
             for chunk in gen_obj:
-
                 # we make everything full probability if we are not computing uncertainty
                 # if not self.engine.compute_log_probs:
                 #     chunk.new_bytes_prob = 1.0
 
                 # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
                 lm.token_count += chunk.new_token_count
-                chunk.new_bytes = delayed_bytes + chunk.new_bytes
-                try:
-                    new_text = chunk.new_bytes.decode("utf8")
-                except UnicodeDecodeError:
-                    delayed_bytes = chunk.new_bytes
-                    continue
-                delayed_bytes = b""
 
                 if chunk.backtrack:
                     lm.engine.metrics.engine_backtrack_tokens += chunk.backtrack
 
-                if len(chunk.new_bytes) > 0:
-                    generated_value += new_text
-
-                    # split chunk into generated and force_forwarded parts for better animated visualization
-                    if chunk.generated_bytes:
-                        lm += TextOutput(
-                            value=chunk.generated_bytes.decode("utf8"),
-                            is_generated=True,
-                            token_count=0,
-                            prob=0.0,
-                            tokens=chunk.generated_tokens,
-                        )
+                # split chunk into generated and force_forwarded parts for better animated visualization
+                if chunk.generated_bytes:
+                    lm += TextOutput(
+                        value=chunk.generated_bytes.decode("utf8"),
+                        is_generated=True,
+                        token_count=0,
+                        prob=0.0,
+                        tokens=chunk.generated_tokens,
+                    )
 
-                    if chunk.force_forwarded_bytes:
-                        lm += TextOutput(
-                            value=chunk.force_forwarded_bytes.decode("utf8"),
-                            is_force_forwarded=True,
-                            token_count=0,
-                            prob=0.0,
-                            tokens=chunk.force_forwarded_tokens,
-                        )
+                if chunk.force_forwarded_bytes:
+                    lm += TextOutput(
+                        value=chunk.force_forwarded_bytes.decode("utf8"),
+                        is_force_forwarded=True,
+                        token_count=0,
+                        prob=0.0,
+                        tokens=chunk.force_forwarded_tokens,
+                    )
 
                 if self.echo:
                     lm.vis_chunk = VisBytesChunk(
@@ -1462,7 +1452,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
                         engine_outputs=chunk.engine_outputs,
                     )
 
-                # last_is_generated = chunk.is_generated
                 if len(chunk.capture_groups) > 0:
                     for k in chunk.capture_groups:
                         v = chunk.capture_groups[k]
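
The first hunk moves the delayed-bytes buffering into the engine's `__call__` (accumulating raw bytes and only yielding once they decode as valid UTF-8), while the second hunk drops the now-redundant buffering from `_run_stateless`, per the NOTE that delayed bytes should be handled in the engine. Below is a minimal, standalone sketch of that buffering pattern for reference; the `iter_text` helper and its names are illustrative, not part of the patch.

```python
# Sketch of the delayed-bytes pattern used in the patch: multi-byte UTF-8
# characters can be split across streamed chunks, so bytes are buffered
# until they decode cleanly before any text is emitted.

def iter_text(byte_chunks):
    """Yield decoded text, holding back incomplete UTF-8 sequences."""
    delayed = b""
    for chunk in byte_chunks:
        pending = delayed + chunk
        try:
            text = pending.decode("utf-8")
        except UnicodeDecodeError:
            # Incomplete sequence: keep buffering and emit nothing this round.
            delayed = pending
            continue
        delayed = b""
        if text:
            yield text


# Example: "é" (0xC3 0xA9) arrives split across two chunks.
print(list(iter_text([b"caf", b"\xc3", b"\xa9!"])))  # ['caf', 'é!']
```

The second chunk (`b"\xc3"`) is held back rather than decoded or emitted, which mirrors how the engine now skips `yield engine_response` while `delayed_bytes` is non-empty.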