Handle invalid utf-8 bytes in engine class instead of model class #1094

Open · wants to merge 5 commits into base: main
guidance/models/_model.py: 63 changes (26 additions, 37 deletions)
@@ -543,13 +543,19 @@ def __call__(
                 delayed_engine_outputs = []
             elif not echo and engine_response.new_bytes:
Collaborator:

Why is it that we only delay bytes if not echo?

Collaborator (Author):

There is separate logic above that already processes delayed_bytes when echo is True.
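
For context, a minimal standalone sketch (illustrative only, not code from this PR) of why the bytes must be delayed at all: a multi-byte UTF-8 character can be split across token chunks, so decoding each chunk eagerly would raise UnicodeDecodeError. The chunk values below are hypothetical.

# "é" is the two bytes 0xC3 0xA9 in UTF-8; a tokenizer may emit them in
# separate chunks, so a single chunk is not necessarily valid UTF-8 on its own.
chunks = [b"caf", b"\xc3", b"\xa9"]  # hypothetical stream for "café"

delayed = b""
text = ""
for chunk in chunks:
    buffered = delayed + chunk
    try:
        text += buffered.decode("utf-8")
        delayed = b""
    except UnicodeDecodeError:
        # incomplete multi-byte sequence: hold the bytes until more arrive
        delayed = buffered

print(text)  # café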

                 # do not collect tokens-metrics if echo is disabled
-                engine_response.generated_bytes = engine_response.new_bytes
-                engine_response.generated_tokens.clear()
+                try:
+                    _ = (delayed_bytes + engine_response.new_bytes).decode("utf-8")
+                    engine_response.generated_bytes = delayed_bytes + engine_response.new_bytes
+                    delayed_bytes = b""
+                    engine_response.generated_tokens.clear()
+                except UnicodeDecodeError:
+                    delayed_bytes += engine_response.new_bytes
 
             # process engine_response
+            # NOTE (loc): We should not yield the engine_response if new_bytes are invalid utf-8 bytes
+            # delayed bytes should be handled here in the engine
-            yield engine_response
+            if delayed_bytes == b"":
+                yield engine_response
 
             if ll_response.stop:
                 assert mask is None
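
Isolating the pattern in the hunk above as a hedged sketch: EngineResponse and stream_valid_utf8 below are simplified stand-ins, not the real classes in _model.py. The engine buffers bytes until they decode cleanly and suppresses the yield while a partial character is pending.

from dataclasses import dataclass, field

@dataclass
class EngineResponse:  # simplified stand-in for the real response object
    new_bytes: bytes
    generated_bytes: bytes = b""
    generated_tokens: list = field(default_factory=list)

def stream_valid_utf8(raw_responses):
    """Yield a response only once the accumulated bytes form valid UTF-8."""
    delayed_bytes = b""
    for response in raw_responses:
        try:
            (delayed_bytes + response.new_bytes).decode("utf-8")
            response.generated_bytes = delayed_bytes + response.new_bytes
            delayed_bytes = b""
        except UnicodeDecodeError:
            delayed_bytes += response.new_bytes
        if delayed_bytes == b"":
            yield response

# The partial b"\xc3" chunk is held back, then flushed together with b"\xa9".
raw = [EngineResponse(b"caf"), EngineResponse(b"\xc3"), EngineResponse(b"\xa9")]
print([r.generated_bytes for r in stream_valid_utf8(raw)])
# [b'caf', b'\xc3\xa9']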
@@ -1404,52 +1410,36 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):

         # single generation
         if n == 1:
-            generated_value = ""
-            # logprobs_out = []
-
-            delayed_bytes = b""
-            # last_is_generated = False
-
             for chunk in gen_obj:
-
-                # we make everything full probability if we are not computing uncertainty
-                # if not self.engine.compute_log_probs:
-                #     chunk.new_bytes_prob = 1.0
-
-                # convert the bytes to a string (delaying if we don't yet have a valid unicode string)
                 lm.token_count += chunk.new_token_count
-                chunk.new_bytes = delayed_bytes + chunk.new_bytes
-                try:
-                    new_text = chunk.new_bytes.decode("utf8")
-                except UnicodeDecodeError:
-                    delayed_bytes = chunk.new_bytes
-                    continue
-                delayed_bytes = b""
 
                 if chunk.backtrack:
                     lm.engine.metrics.engine_backtrack_tokens += chunk.backtrack
 
-                if len(chunk.new_bytes) > 0:
-                    generated_value += new_text
-
-                    # split chunk into generated and force_forwarded parts for better animated visualization
-                    if chunk.generated_bytes:
-                        lm += TextOutput(
-                            value=chunk.generated_bytes.decode("utf8"),
-                            is_generated=True,
-                            token_count=0,
-                            prob=0.0,
-                            tokens=chunk.generated_tokens,
-                        )
+                # split chunk into generated and force_forwarded parts for better animated visualization
+                if chunk.generated_bytes:
+                    lm += TextOutput(
+                        value=chunk.generated_bytes.decode("utf8"),
+                        is_generated=True,
+                        token_count=0,
+                        prob=0.0,
+                        tokens=chunk.generated_tokens,
+                    )
 
-                    if chunk.force_forwarded_bytes:
-                        lm += TextOutput(
-                            value=chunk.force_forwarded_bytes.decode("utf8"),
-                            is_force_forwarded=True,
-                            token_count=0,
-                            prob=0.0,
-                            tokens=chunk.force_forwarded_tokens,
-                        )
+                if chunk.force_forwarded_bytes:
+                    lm += TextOutput(
+                        value=chunk.force_forwarded_bytes.decode("utf8"),
+                        is_force_forwarded=True,
+                        token_count=0,
+                        prob=0.0,
+                        tokens=chunk.force_forwarded_tokens,
+                    )
 
                 if self.echo:
                     lm.vis_chunk = VisBytesChunk(
@@ -1462,7 +1452,6 @@ def _run_stateless(self, stateless_function, temperature=0.0, top_p=1.0, n=1):
                         engine_outputs=chunk.engine_outputs,
                     )
 
-                # last_is_generated = chunk.is_generated
                 if len(chunk.capture_groups) > 0:
                     for k in chunk.capture_groups:
                         v = chunk.capture_groups[k]
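
Net effect of the two hunks, sketched under the same assumptions as above (hypothetical chunk values): because the engine now withholds responses whose bytes do not yet decode, the model-side loop can call .decode("utf8") unconditionally instead of keeping its own delayed_bytes bookkeeping.

# Chunks as they would arrive from the engine after this PR: the engine has
# already merged the partial b"\xc3" into the next chunk, so every chunk
# decodes cleanly on its own.
validated_chunks = [b"caf", b"\xc3\xa9"]
text = "".join(chunk.decode("utf8") for chunk in validated_chunks)
assert text == "café"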