From f2579096993ae460516a0aae1d3e09f3eb5c1772 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 17 Apr 2024 10:47:56 -0400 Subject: [PATCH] mistralai[patch]: Surface http errors (#20555) Do not swallow errors when streaming with httpx. Update affected code if this PR gets merged to httpx: https://github.com/florimondmanca/httpx-sse/pull/25/files --- .../langchain_mistralai/chat_models.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 8b248ebb98e8e..3b41e6cf3ee2a 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -116,11 +116,38 @@ def _convert_mistral_chat_message_to_message( ) +def _raise_on_error(response: httpx.Response) -> None: + """Raise an error if the response is an error.""" + if httpx.codes.is_error(response.status_code): + error_message = response.read().decode("utf-8") + raise httpx.HTTPStatusError( + f"Error response {response.status_code} " + f"while fetching {response.url}: {error_message}", + request=response.request, + response=response, + ) + + +async def _araise_on_error(response: httpx.Response) -> None: + """Raise an error if the response is an error.""" + if httpx.codes.is_error(response.status_code): + error_message = (await response.aread()).decode("utf-8") + raise httpx.HTTPStatusError( + f"Error response {response.status_code} " + f"while fetching {response.url}: {error_message}", + request=response.request, + response=response, + ) + + async def _aiter_sse( event_source_mgr: AsyncContextManager[EventSource], ) -> AsyncIterator[Dict]: """Iterate over the server-sent events.""" async with event_source_mgr as event_source: + # TODO(Team): Remove after this is fixed in httpx dependency + # https://github.com/florimondmanca/httpx-sse/pull/25/files + await _araise_on_error(event_source._response) async for event in event_source.aiter_sse(): if event.data == "[DONE]": return @@ -144,10 +171,10 @@ async def _completion_with_retry(**kwargs: Any) -> Any: event_source = aconnect_sse( llm.async_client, "POST", "/chat/completions", json=kwargs ) - return _aiter_sse(event_source) else: response = await llm.async_client.post(url="/chat/completions", json=kwargs) + await _araise_on_error(response) return response.json() return await _completion_with_retry(**kwargs) @@ -298,6 +325,9 @@ def iter_sse() -> Iterator[Dict]: with connect_sse( self.client, "POST", "/chat/completions", json=kwargs ) as event_source: + # TODO(Team): Remove after this is fixed in httpx dependency + # https://github.com/florimondmanca/httpx-sse/pull/25/files + _raise_on_error(event_source._response) for event in event_source.iter_sse(): if event.data == "[DONE]": return @@ -305,7 +335,9 @@ def iter_sse() -> Iterator[Dict]: return iter_sse() else: - return self.client.post(url="/chat/completions", json=kwargs).json() + response = self.client.post(url="/chat/completions", json=kwargs) + _raise_on_error(response) + return response.json() rtn = _completion_with_retry(**kwargs) return rtn