Skip to content

Commit

Permalink
mistralai[patch]: Surface http errors (#20555)
Browse files Browse the repository at this point in the history
Do not swallow errors when streaming with httpx.

Update affected code if this PR gets merged to httpx:
https://github.com/florimondmanca/httpx-sse/pull/25/files
  • Loading branch information
eyurtsev authored Apr 17, 2024
1 parent 3f156e0 commit f257909
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,38 @@ def _convert_mistral_chat_message_to_message(
)


def _raise_on_error(response: httpx.Response) -> None:
"""Raise an error if the response is an error."""
if httpx.codes.is_error(response.status_code):
error_message = response.read().decode("utf-8")
raise httpx.HTTPStatusError(
f"Error response {response.status_code} "
f"while fetching {response.url}: {error_message}",
request=response.request,
response=response,
)


async def _araise_on_error(response: httpx.Response) -> None:
"""Raise an error if the response is an error."""
if httpx.codes.is_error(response.status_code):
error_message = (await response.aread()).decode("utf-8")
raise httpx.HTTPStatusError(
f"Error response {response.status_code} "
f"while fetching {response.url}: {error_message}",
request=response.request,
response=response,
)


async def _aiter_sse(
event_source_mgr: AsyncContextManager[EventSource],
) -> AsyncIterator[Dict]:
"""Iterate over the server-sent events."""
async with event_source_mgr as event_source:
# TODO(Team): Remove after this is fixed in httpx dependency
# https://github.com/florimondmanca/httpx-sse/pull/25/files
await _araise_on_error(event_source._response)
async for event in event_source.aiter_sse():
if event.data == "[DONE]":
return
Expand All @@ -144,10 +171,10 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
event_source = aconnect_sse(
llm.async_client, "POST", "/chat/completions", json=kwargs
)

return _aiter_sse(event_source)
else:
response = await llm.async_client.post(url="/chat/completions", json=kwargs)
await _araise_on_error(response)
return response.json()

return await _completion_with_retry(**kwargs)
Expand Down Expand Up @@ -298,14 +325,19 @@ def iter_sse() -> Iterator[Dict]:
with connect_sse(
self.client, "POST", "/chat/completions", json=kwargs
) as event_source:
# TODO(Team): Remove after this is fixed in httpx dependency
# https://github.com/florimondmanca/httpx-sse/pull/25/files
_raise_on_error(event_source._response)
for event in event_source.iter_sse():
if event.data == "[DONE]":
return
yield event.json()

return iter_sse()
else:
return self.client.post(url="/chat/completions", json=kwargs).json()
response = self.client.post(url="/chat/completions", json=kwargs)
_raise_on_error(response)
return response.json()

rtn = _completion_with_retry(**kwargs)
return rtn
Expand Down

0 comments on commit f257909

Please sign in to comment.