
Commit

implement async
ccurme committed Jan 5, 2025
1 parent 09f35c8 commit 1685917
Showing 1 changed file with 59 additions and 19 deletions.
78 changes: 59 additions & 19 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -706,20 +706,14 @@ def _stream(
         if hasattr(response, "get_final_completion"):
             final_completion = response.get_final_completion()
             if isinstance(final_completion, openai.BaseModel):
-                chat_result = self._create_chat_result(final_completion)
-                chat_message = chat_result.generations[0].message
-                if isinstance(chat_message, AIMessage):
-                    usage_metadata = chat_message.usage_metadata
-                else:
-                    usage_metadata = None
-                message = AIMessageChunk(
-                    content="",
-                    additional_kwargs=chat_message.additional_kwargs,
-                    usage_metadata=usage_metadata,
-                )
-                yield ChatGenerationChunk(
-                    message=message, generation_info=chat_result.llm_output
+                generation_chunk = self._get_generation_chunk_from_completion(
+                    final_completion
                 )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk

     def _generate(
         self,
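For orientation (not part of the diff): a minimal sync usage sketch of the behavior this hunk touches. It assumes that binding a Pydantic class as response_format routes streaming through the parse-aware client and that the trailing chunk carries usage metadata; the model name and the Answer schema are illustrative assumptions.

from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Answer(BaseModel):
    value: int


# Hypothetical sketch: stream with a structured response_format and merge chunks.
# The trailing chunk built from the final completion carries usage_metadata.
llm = ChatOpenAI(model="gpt-4o-mini")  # model name is an assumption
structured_llm = llm.bind(response_format=Answer)

full = None
for chunk in structured_llm.stream("What is 2 + 2?"):
    full = chunk if full is None else full + chunk

print(full.usage_metadata)  # populated by the trailing completion chunk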
@@ -828,13 +822,29 @@ async def _astream(
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if self.include_response_headers:
-            raw_response = await self.async_client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_async_client.beta.chat.completions.stream(
+                **payload
+            )
+            context_manager = response_stream
         else:
-            response = await self.async_client.create(**payload)
-        async with response:
+            if self.include_response_headers:
+                raw_response = await self.async_client.with_raw_response.create(
+                    **payload
+                )
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = await self.async_client.create(**payload)
+            context_manager = response
+        async with context_manager as response:
             is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
@@ -854,6 +864,17 @@
                     )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion"):
+            final_completion = await response.get_final_completion()
+            if isinstance(final_completion, openai.BaseModel):
+                generation_chunk = self._get_generation_chunk_from_completion(
+                    final_completion
+                )
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk

     async def _agenerate(
         self,
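An async counterpart of the sketch above (again illustrative, not from the commit): the new code path sends response_format requests through root_async_client.beta.chat.completions.stream and yields a final chunk built from the parsed completion. The model name and Answer schema remain assumptions.

import asyncio

from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Answer(BaseModel):
    value: int


async def main() -> None:
    # Hypothetical async usage mirroring the sync sketch earlier.
    llm = ChatOpenAI(model="gpt-4o-mini")  # model name is an assumption
    structured_llm = llm.bind(response_format=Answer)

    full = None
    async for chunk in structured_llm.astream("What is 2 + 2?"):
        full = chunk if full is None else full + chunk

    print(full.usage_metadata)  # contributed by the final completion chunk


asyncio.run(main())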
@@ -1546,6 +1567,25 @@ def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
                 filtered[k] = v
         return filtered

+    def _get_generation_chunk_from_completion(
+        self, completion: openai.BaseModel
+    ) -> ChatGenerationChunk:
+        """Get chunk from completion (e.g., from final completion of a stream)."""
+        chat_result = self._create_chat_result(completion)
+        chat_message = chat_result.generations[0].message
+        if isinstance(chat_message, AIMessage):
+            usage_metadata = chat_message.usage_metadata
+        else:
+            usage_metadata = None
+        message = AIMessageChunk(
+            content="",
+            additional_kwargs=chat_message.additional_kwargs,
+            usage_metadata=usage_metadata,
+        )
+        return ChatGenerationChunk(
+            message=message, generation_info=chat_result.llm_output
+        )
+

 class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
     """OpenAI chat model integration.
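Why the new helper yields an empty-content chunk: chunk addition concatenates content but merges usage_metadata and additional_kwargs, so the trailing chunk enriches an accumulated message without duplicating text. A self-contained sketch using only core message types (the token counts are made up):

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.ai import UsageMetadata

# Chunk accumulated while streaming the answer text.
streamed = AIMessageChunk(content="4")

# A trailing chunk like the one _get_generation_chunk_from_completion builds:
# empty content, but usage metadata taken from the final parsed completion.
final = AIMessageChunk(
    content="",
    usage_metadata=UsageMetadata(input_tokens=12, output_tokens=1, total_tokens=13),
)

merged = streamed + final
print(merged.content)         # "4" -- text is unchanged by the empty-content chunk
print(merged.usage_metadata)  # usage carried over from the trailing chunk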
