From 914edec01d1998d2dd58f1a9ad7ea0e85bf4ed55 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Sat, 4 Jan 2025 16:59:36 -0500 Subject: [PATCH 1/5] update --- .../langchain_openai/chat_models/base.py | 74 +++++++++++-------- 1 file changed, 43 insertions(+), 31 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 142e7eca1a84b..556da1af86e95 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -318,8 +318,10 @@ def _convert_delta_to_message_chunk( def _convert_chunk_to_generation_chunk( chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict] ) -> Optional[ChatGenerationChunk]: + if chunk.get("type") == "content.delta": + return None token_usage = chunk.get("usage") - choices = chunk.get("choices", []) + choices = chunk.get("choices", []) or chunk.get("snapshot", {}).get("choices", []) usage_metadata: Optional[UsageMetadata] = ( _create_usage_metadata(token_usage) if token_usage else None @@ -332,12 +334,20 @@ def _convert_chunk_to_generation_chunk( return generation_chunk choice = choices[0] - if choice["delta"] is None: - return None + if chunk.get("type") == "chunk": + refusal = choice.get("message", {}).get("refusal") + content = choice.get("message", {}).get("content") + message_chunk = AIMessageChunk( + content, + additional_kwargs={"refusal": refusal}, + ) + else: + if choice["delta"] is None: + return None - message_chunk = _convert_delta_to_message_chunk( - choice["delta"], default_chunk_class - ) + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) generation_info = {**base_generation_info} if base_generation_info else {} if finish_reason := choice.get("finish_reason"): @@ -660,13 +670,24 @@ def _stream( default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk base_generation_info = {} - if self.include_response_headers: - raw_response = self.client.with_raw_response.create(**payload) - response = raw_response.parse() - base_generation_info = {"headers": dict(raw_response.headers)} + if "response_format" in payload: + if self.include_response_headers: + warnings.warn( + "Cannot currently include response headers when response_format is " + "specified." 
+ ) + payload.pop("stream") + response_stream = self.root_client.beta.chat.completions.stream(**payload) + context_manager = response_stream else: - response = self.client.create(**payload) - with response: + if self.include_response_headers: + raw_response = self.client.with_raw_response.create(**payload) + response = raw_response.parse() + base_generation_info = {"headers": dict(raw_response.headers)} + else: + response = self.client.create(**payload) + context_manager = response + with context_manager as response: is_first_chunk = True for chunk in response: if not isinstance(chunk, dict): @@ -686,6 +707,16 @@ def _stream( ) is_first_chunk = False yield generation_chunk + if hasattr(response, "get_final_completion"): + final_completion = response.get_final_completion() + if isinstance(final_completion, openai.BaseModel): + message = AIMessageChunk( + "", + additional_kwargs={ + "parsed": final_completion.choices[0].message.parsed + }, + ) + yield ChatGenerationChunk(message=message) def _generate( self, @@ -1010,25 +1041,6 @@ def get_num_tokens_from_messages( num_tokens += 3 return num_tokens - def _should_stream( - self, - *, - async_api: bool, - run_manager: Optional[ - Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun] - ] = None, - response_format: Optional[Union[dict, type]] = None, - **kwargs: Any, - ) -> bool: - if isinstance(response_format, type) and is_basemodel_subclass(response_format): - # TODO: Add support for streaming with Pydantic response_format. - warnings.warn("Streaming with Pydantic response_format not yet supported.") - return False - - return super()._should_stream( - async_api=async_api, run_manager=run_manager, **kwargs - ) - @deprecated( since="0.2.1", alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools", From 09f35c820482de448d57d6a55eb0ca63dc3de798 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Sat, 4 Jan 2025 19:23:02 -0500 Subject: [PATCH 2/5] update --- .../langchain_openai/chat_models/base.py | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 556da1af86e95..68c0a1a3c92c6 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -318,10 +318,14 @@ def _convert_delta_to_message_chunk( def _convert_chunk_to_generation_chunk( chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict] ) -> Optional[ChatGenerationChunk]: - if chunk.get("type") == "content.delta": + if chunk.get("type") == "content.delta": # from beta.chat.completions.stream return None token_usage = chunk.get("usage") - choices = chunk.get("choices", []) or chunk.get("snapshot", {}).get("choices", []) + choices = ( + chunk.get("choices", []) + # from beta.chat.completions.stream + or chunk.get("chunk", {}).get("choices", []) + ) usage_metadata: Optional[UsageMetadata] = ( _create_usage_metadata(token_usage) if token_usage else None @@ -334,20 +338,12 @@ def _convert_chunk_to_generation_chunk( return generation_chunk choice = choices[0] - if chunk.get("type") == "chunk": - refusal = choice.get("message", {}).get("refusal") - content = choice.get("message", {}).get("content") - message_chunk = AIMessageChunk( - content, - additional_kwargs={"refusal": refusal}, - ) - else: - if choice["delta"] is None: - return None + if choice["delta"] is None: + return None - message_chunk = _convert_delta_to_message_chunk( 
- choice["delta"], default_chunk_class - ) + message_chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) generation_info = {**base_generation_info} if base_generation_info else {} if finish_reason := choice.get("finish_reason"): @@ -710,13 +706,20 @@ def _stream( if hasattr(response, "get_final_completion"): final_completion = response.get_final_completion() if isinstance(final_completion, openai.BaseModel): + chat_result = self._create_chat_result(final_completion) + chat_message = chat_result.generations[0].message + if isinstance(chat_message, AIMessage): + usage_metadata = chat_message.usage_metadata + else: + usage_metadata = None message = AIMessageChunk( - "", - additional_kwargs={ - "parsed": final_completion.choices[0].message.parsed - }, + content="", + additional_kwargs=chat_message.additional_kwargs, + usage_metadata=usage_metadata, + ) + yield ChatGenerationChunk( + message=message, generation_info=chat_result.llm_output ) - yield ChatGenerationChunk(message=message) def _generate( self, From 16859172f1435d169380d3f487f9d07766b1e1bd Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Sat, 4 Jan 2025 21:52:08 -0500 Subject: [PATCH 3/5] implement async --- .../langchain_openai/chat_models/base.py | 78 ++++++++++++++----- 1 file changed, 59 insertions(+), 19 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 68c0a1a3c92c6..e434974dd7582 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -706,20 +706,14 @@ def _stream( if hasattr(response, "get_final_completion"): final_completion = response.get_final_completion() if isinstance(final_completion, openai.BaseModel): - chat_result = self._create_chat_result(final_completion) - chat_message = chat_result.generations[0].message - if isinstance(chat_message, AIMessage): - usage_metadata = chat_message.usage_metadata - else: - usage_metadata = None - message = AIMessageChunk( - content="", - additional_kwargs=chat_message.additional_kwargs, - usage_metadata=usage_metadata, - ) - yield ChatGenerationChunk( - message=message, generation_info=chat_result.llm_output + generation_chunk = self._get_generation_chunk_from_completion( + final_completion ) + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk + ) + yield generation_chunk def _generate( self, @@ -828,13 +822,29 @@ async def _astream( payload = self._get_request_payload(messages, stop=stop, **kwargs) default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk base_generation_info = {} - if self.include_response_headers: - raw_response = await self.async_client.with_raw_response.create(**payload) - response = raw_response.parse() - base_generation_info = {"headers": dict(raw_response.headers)} + + if "response_format" in payload: + if self.include_response_headers: + warnings.warn( + "Cannot currently include response headers when response_format is " + "specified." 
+ ) + payload.pop("stream") + response_stream = self.root_async_client.beta.chat.completions.stream( + **payload + ) + context_manager = response_stream else: - response = await self.async_client.create(**payload) - async with response: + if self.include_response_headers: + raw_response = await self.async_client.with_raw_response.create( + **payload + ) + response = raw_response.parse() + base_generation_info = {"headers": dict(raw_response.headers)} + else: + response = await self.async_client.create(**payload) + context_manager = response + async with context_manager as response: is_first_chunk = True async for chunk in response: if not isinstance(chunk, dict): @@ -854,6 +864,17 @@ async def _astream( ) is_first_chunk = False yield generation_chunk + if hasattr(response, "get_final_completion"): + final_completion = await response.get_final_completion() + if isinstance(final_completion, openai.BaseModel): + generation_chunk = self._get_generation_chunk_from_completion( + final_completion + ) + if run_manager: + await run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk + ) + yield generation_chunk async def _agenerate( self, @@ -1546,6 +1567,25 @@ def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]: filtered[k] = v return filtered + def _get_generation_chunk_from_completion( + self, completion: openai.BaseModel + ) -> ChatGenerationChunk: + """Get chunk from completion (e.g., from final completion of a stream).""" + chat_result = self._create_chat_result(completion) + chat_message = chat_result.generations[0].message + if isinstance(chat_message, AIMessage): + usage_metadata = chat_message.usage_metadata + else: + usage_metadata = None + message = AIMessageChunk( + content="", + additional_kwargs=chat_message.additional_kwargs, + usage_metadata=usage_metadata, + ) + return ChatGenerationChunk( + message=message, generation_info=chat_result.llm_output + ) + class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] """OpenAI chat model integration. 
From cc12670ec57edb677264ff3c051939ee01c9e3b2 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 6 Jan 2025 11:43:43 -0500 Subject: [PATCH 4/5] update tests --- .../langchain_openai/chat_models/base.py | 34 +++++++++---------- .../chat_models/test_base.py | 29 ++++++++++++++-- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index e434974dd7582..f4e26253484e5 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -703,17 +703,16 @@ def _stream( ) is_first_chunk = False yield generation_chunk - if hasattr(response, "get_final_completion"): + if hasattr(response, "get_final_completion") and "response_format" in payload: final_completion = response.get_final_completion() - if isinstance(final_completion, openai.BaseModel): - generation_chunk = self._get_generation_chunk_from_completion( - final_completion + generation_chunk = self._get_generation_chunk_from_completion( + final_completion + ) + if run_manager: + run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk ) - if run_manager: - run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk - ) - yield generation_chunk + yield generation_chunk def _generate( self, @@ -864,17 +863,16 @@ async def _astream( ) is_first_chunk = False yield generation_chunk - if hasattr(response, "get_final_completion"): + if hasattr(response, "get_final_completion") and "response_format" in payload: final_completion = await response.get_final_completion() - if isinstance(final_completion, openai.BaseModel): - generation_chunk = self._get_generation_chunk_from_completion( - final_completion + generation_chunk = self._get_generation_chunk_from_completion( + final_completion + ) + if run_manager: + await run_manager.on_llm_new_token( + generation_chunk.text, chunk=generation_chunk ) - if run_manager: - await run_manager.on_llm_new_token( - generation_chunk.text, chunk=generation_chunk - ) - yield generation_chunk + yield generation_chunk async def _agenerate( self, diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 93c08ce214178..506799aef4b59 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1092,14 +1092,37 @@ class Foo(BaseModel): def test_stream_response_format() -> None: - list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo)) + full: Optional[BaseMessageChunk] = None + chunks = [] + for chunk in ChatOpenAI(model="gpt-4o-mini").stream( + "how are ya", response_format=Foo + ): + chunks.append(chunk) + full = chunk if full is None else full + chunk + assert len(chunks) > 1 + assert isinstance(full, AIMessageChunk) + parsed = full.additional_kwargs["parsed"] + assert isinstance(parsed, Foo) + assert isinstance(full.content, str) + parsed_content = json.loads(full.content) + assert parsed.response == parsed_content["response"] async def test_astream_response_format() -> None: - async for _ in ChatOpenAI(model="gpt-4o-mini").astream( + full: Optional[BaseMessageChunk] = None + chunks = [] + async for chunk in ChatOpenAI(model="gpt-4o-mini").astream( "how are ya", response_format=Foo ): - pass + chunks.append(chunk) + full = chunk if full is None else full + 
chunk + assert len(chunks) > 1 + assert isinstance(full, AIMessageChunk) + parsed = full.additional_kwargs["parsed"] + assert isinstance(parsed, Foo) + assert isinstance(full.content, str) + parsed_content = json.loads(full.content) + assert parsed.response == parsed_content["response"] @pytest.mark.parametrize("use_max_completion_tokens", [True, False]) From 3121fd7dc0106f6c3896769395b3e6a74c7e6979 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 6 Jan 2025 11:51:37 -0500 Subject: [PATCH 5/5] test azure --- .../chat_models/test_azure.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py index 4ed531ad119ff..e9228df0730c5 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py @@ -13,6 +13,7 @@ HumanMessage, ) from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult +from pydantic import BaseModel from langchain_openai import AzureChatOpenAI from tests.unit_tests.fake.callbacks import FakeCallbackHandler @@ -262,3 +263,37 @@ async def test_json_mode_async(llm: AzureChatOpenAI) -> None: assert isinstance(full, AIMessageChunk) assert isinstance(full.content, str) assert json.loads(full.content) == {"a": 1} + + +class Foo(BaseModel): + response: str + + +def test_stream_response_format(llm: AzureChatOpenAI) -> None: + full: Optional[BaseMessageChunk] = None + chunks = [] + for chunk in llm.stream("how are ya", response_format=Foo): + chunks.append(chunk) + full = chunk if full is None else full + chunk + assert len(chunks) > 1 + assert isinstance(full, AIMessageChunk) + parsed = full.additional_kwargs["parsed"] + assert isinstance(parsed, Foo) + assert isinstance(full.content, str) + parsed_content = json.loads(full.content) + assert parsed.response == parsed_content["response"] + + +async def test_astream_response_format(llm: AzureChatOpenAI) -> None: + full: Optional[BaseMessageChunk] = None + chunks = [] + async for chunk in llm.astream("how are ya", response_format=Foo): + chunks.append(chunk) + full = chunk if full is None else full + chunk + assert len(chunks) > 1 + assert isinstance(full, AIMessageChunk) + parsed = full.additional_kwargs["parsed"] + assert isinstance(parsed, Foo) + assert isinstance(full.content, str) + parsed_content = json.loads(full.content) + assert parsed.response == parsed_content["response"]
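
Editor's note: taken together, these patches route streaming requests that carry a `response_format` through the OpenAI SDK's `beta.chat.completions.stream` helper and surface the parsed object on the final chunk via `additional_kwargs["parsed"]`. A minimal usage sketch mirroring the integration tests above (the model name and prompt are illustrative, not prescriptive):

    from typing import Optional

    from langchain_core.messages import AIMessageChunk, BaseMessageChunk
    from langchain_openai import ChatOpenAI
    from pydantic import BaseModel


    class Foo(BaseModel):
        response: str


    llm = ChatOpenAI(model="gpt-4o-mini")

    full: Optional[BaseMessageChunk] = None
    for chunk in llm.stream("how are ya", response_format=Foo):
        # Content streams incrementally; per the patches above, the parsed
        # Pydantic object arrives on the final chunk emitted after
        # get_final_completion() resolves.
        full = chunk if full is None else full + chunk

    assert isinstance(full, AIMessageChunk)
    parsed = full.additional_kwargs["parsed"]  # a Foo instance, per the tests above
    print(parsed.response)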