openai[patch]: support streaming with json_schema response format #29044

Merged (5 commits) on Jan 9, 2025
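
For context, the streaming behavior this PR enables looks roughly like the following. This is an illustrative sketch that mirrors the integration tests added below (same model name, prompt, and Foo schema); it is not part of the diff itself.

from typing import Optional

from langchain_core.messages import BaseMessageChunk
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Foo(BaseModel):
    response: str


llm = ChatOpenAI(model="gpt-4o-mini")

# With this PR, chunks stream even when response_format is a Pydantic model;
# previously _should_stream returned False and the call fell back to a single response.
full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("how are ya", response_format=Foo):
    full = chunk if full is None else full + chunk
assert full is not None

# The aggregated message carries both the raw JSON text and the parsed object.
print(full.content)                      # JSON string conforming to the Foo schema
print(full.additional_kwargs["parsed"])  # a Foo instance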
117 changes: 85 additions & 32 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -318,8 +318,14 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
     chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
 ) -> Optional[ChatGenerationChunk]:
+    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+        return None
     token_usage = chunk.get("usage")
-    choices = chunk.get("choices", [])
+    choices = (
+        chunk.get("choices", [])
+        # from beta.chat.completions.stream
+        or chunk.get("chunk", {}).get("choices", [])
+    )
 
     usage_metadata: Optional[UsageMetadata] = (
         _create_usage_metadata(token_usage) if token_usage else None
@@ -660,13 +666,24 @@ def _stream(
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
 
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_client.beta.chat.completions.stream(**payload)
+            context_manager = response_stream
         else:
-            response = self.client.create(**payload)
-        with response:
+            if self.include_response_headers:
+                raw_response = self.client.with_raw_response.create(**payload)
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = self.client.create(**payload)
+            context_manager = response
+        with context_manager as response:
             is_first_chunk = True
             for chunk in response:
                 if not isinstance(chunk, dict):
@@ -686,6 +703,16 @@
                     )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     def _generate(
         self,
@@ -794,13 +821,29 @@ async def _astream(
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if self.include_response_headers:
-            raw_response = await self.async_client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_async_client.beta.chat.completions.stream(
+                **payload
+            )
+            context_manager = response_stream
         else:
-            response = await self.async_client.create(**payload)
-        async with response:
+            if self.include_response_headers:
+                raw_response = await self.async_client.with_raw_response.create(
+                    **payload
+                )
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = await self.async_client.create(**payload)
+            context_manager = response
+        async with context_manager as response:
             is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
@@ -820,6 +863,16 @@
                     )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = await response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     async def _agenerate(
         self,
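
For reference, both response_format branches above delegate to the OpenAI SDK's beta streaming helper. The sketch below shows roughly how that helper behaves on its own, based only on the calls visible in this diff (beta.chat.completions.stream, the "content.delta" and "chunk" event types, and get_final_completion); the model, prompt, and Foo schema are placeholders, so treat it as an assumption rather than a spec.

from openai import OpenAI
from pydantic import BaseModel


class Foo(BaseModel):  # placeholder schema
    response: str


client = OpenAI()

# _stream/_astream pass the request payload (minus "stream") to this helper and iterate
# the resulting event stream; _convert_chunk_to_generation_chunk skips "content.delta"
# events and reads choices from "chunk" events.
with client.beta.chat.completions.stream(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "how are ya"}],
    response_format=Foo,
) as stream:
    for event in stream:
        if event.type == "content.delta":
            pass  # incremental text; the wrapper ignores these events
    # Once the stream is exhausted, the fully parsed completion is available; this is
    # what _get_generation_chunk_from_completion (added below) turns into a final chunk.
    final_completion = stream.get_final_completion()

print(final_completion.choices[0].message.parsed)  # a Foo instance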
@@ -1010,25 +1063,6 @@ def get_num_tokens_from_messages(
             num_tokens += 3
         return num_tokens
 
-    def _should_stream(
-        self,
-        *,
-        async_api: bool,
-        run_manager: Optional[
-            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
-        ] = None,
-        response_format: Optional[Union[dict, type]] = None,
-        **kwargs: Any,
-    ) -> bool:
-        if isinstance(response_format, type) and is_basemodel_subclass(response_format):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            return False
-
-        return super()._should_stream(
-            async_api=async_api, run_manager=run_manager, **kwargs
-        )
-
     @deprecated(
         since="0.2.1",
         alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools",
@@ -1531,6 +1565,25 @@ def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
                 filtered[k] = v
         return filtered
 
+    def _get_generation_chunk_from_completion(
+        self, completion: openai.BaseModel
+    ) -> ChatGenerationChunk:
+        """Get chunk from completion (e.g., from final completion of a stream)."""
+        chat_result = self._create_chat_result(completion)
+        chat_message = chat_result.generations[0].message
+        if isinstance(chat_message, AIMessage):
+            usage_metadata = chat_message.usage_metadata
+        else:
+            usage_metadata = None
+        message = AIMessageChunk(

Review thread on this line:

Collaborator: I think we have a message-to-message-chunk util.

Author: Thanks, I poked around and found:

  • message_chunk_to_message (I want to go the other way)
  • A private function _msg_to_chunk -- this function actually loses the parsed representation because it casts to/from dict. It's used in merge_message_runs, which isn't used in the monorepo; I see it's used in langchain-aws.

Let me know if you think it's important to add a util or whether I've missed it. Note that I actually overwrite content to be "" so it's not a perfect copy.

content="",
additional_kwargs=chat_message.additional_kwargs,
usage_metadata=usage_metadata,
)
return ChatGenerationChunk(
message=message, generation_info=chat_result.llm_output
)


class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"""OpenAI chat model integration.
Expand Down
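
Following up on the review thread above: the sketch below illustrates the conversion that _get_generation_chunk_from_completion performs, using a placeholder message and schema. The point under discussion is that copying additional_kwargs directly preserves the parsed Pydantic instance, whereas the to-dict/from-dict round trip in the private _msg_to_chunk helper would drop it.

from langchain_core.messages import AIMessage, AIMessageChunk
from pydantic import BaseModel


class Foo(BaseModel):  # placeholder schema
    response: str


# Stand-in for the AIMessage that _create_chat_result builds from the final completion.
chat_message = AIMessage(
    content='{"response": "hi"}',
    additional_kwargs={"parsed": Foo(response="hi")},
)

# Mirrors _get_generation_chunk_from_completion: content is reset to "" because the text
# was already streamed, while additional_kwargs (including the parsed object) and usage
# metadata are carried over unchanged.
chunk = AIMessageChunk(
    content="",
    additional_kwargs=chat_message.additional_kwargs,
    usage_metadata=chat_message.usage_metadata,
)
assert isinstance(chunk.additional_kwargs["parsed"], Foo)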
@@ -13,6 +13,7 @@
     HumanMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
+from pydantic import BaseModel
 
 from langchain_openai import AzureChatOpenAI
 from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -262,3 +263,37 @@ async def test_json_mode_async(llm: AzureChatOpenAI) -> None:
     assert isinstance(full, AIMessageChunk)
     assert isinstance(full.content, str)
     assert json.loads(full.content) == {"a": 1}
+
+
+class Foo(BaseModel):
+    response: str
+
+
+def test_stream_response_format(llm: AzureChatOpenAI) -> None:
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    for chunk in llm.stream("how are ya", response_format=Foo):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
+
+
+async def test_astream_response_format(llm: AzureChatOpenAI) -> None:
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    async for chunk in llm.astream("how are ya", response_format=Foo):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
@@ -1092,14 +1092,37 @@ class Foo(BaseModel):
 
 
 def test_stream_response_format() -> None:
-    list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo))
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    for chunk in ChatOpenAI(model="gpt-4o-mini").stream(
+        "how are ya", response_format=Foo
+    ):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 async def test_astream_response_format() -> None:
-    async for _ in ChatOpenAI(model="gpt-4o-mini").astream(
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    async for chunk in ChatOpenAI(model="gpt-4o-mini").astream(
         "how are ya", response_format=Foo
     ):
-        pass
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 @pytest.mark.parametrize("use_max_completion_tokens", [True, False])