
Commit

update tests
ccurme committed Jan 6, 2025
1 parent 1685917 commit cc12670
Showing 2 changed files with 42 additions and 21 deletions.
34 changes: 16 additions & 18 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -703,17 +703,16 @@ def _stream(
                     )
                 is_first_chunk = False
                 yield generation_chunk
-        if hasattr(response, "get_final_completion"):
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
             final_completion = response.get_final_completion()
-            if isinstance(final_completion, openai.BaseModel):
-                generation_chunk = self._get_generation_chunk_from_completion(
-                    final_completion
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
                 )
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
+            yield generation_chunk
 
     def _generate(
         self,
@@ -864,17 +863,16 @@ async def _astream(
                     )
                 is_first_chunk = False
                 yield generation_chunk
-        if hasattr(response, "get_final_completion"):
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
             final_completion = await response.get_final_completion()
-            if isinstance(final_completion, openai.BaseModel):
-                generation_chunk = self._get_generation_chunk_from_completion(
-                    final_completion
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
                 )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    generation_chunk.text, chunk=generation_chunk
-                )
-            yield generation_chunk
+            yield generation_chunk
 
     async def _agenerate(
         self,
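Net effect of the two hunks above: the extra final chunk built from get_final_completion() is now emitted whenever response_format was included in the request payload, rather than only when the SDK returned a parsed openai.BaseModel. A minimal caller-side sketch of the behavior the updated tests below exercise (illustrative only, not part of this commit; assumes a configured OpenAI API key, and the Foo schema with a single response field is inferred from the tests):

from typing import Optional

from langchain_core.messages import BaseMessageChunk
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Foo(BaseModel):
    response: str


full: Optional[BaseMessageChunk] = None
for chunk in ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo):
    # Token chunks stream first; a final extra chunk carries the parsed result.
    full = chunk if full is None else full + chunk

assert full is not None
print(full.additional_kwargs["parsed"])  # e.g. Foo(response="...")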
@@ -1092,14 +1092,37 @@ class Foo(BaseModel):
 
 
 def test_stream_response_format() -> None:
-    list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo))
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    for chunk in ChatOpenAI(model="gpt-4o-mini").stream(
+        "how are ya", response_format=Foo
+    ):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 async def test_astream_response_format() -> None:
-    async for _ in ChatOpenAI(model="gpt-4o-mini").astream(
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    async for chunk in ChatOpenAI(model="gpt-4o-mini").astream(
         "how are ya", response_format=Foo
     ):
-        pass
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 @pytest.mark.parametrize("use_max_completion_tokens", [True, False])
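The async test mirrors the sync one: after aggregating the streamed chunks, the parsed object is read from the final chunk's additional_kwargs. A hedged standalone sketch of the same path driven through astream (again assuming an API key and the inferred Foo schema; illustrative only):

import asyncio
from typing import Optional

from langchain_core.messages import BaseMessageChunk
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Foo(BaseModel):
    response: str


async def main() -> None:
    full: Optional[BaseMessageChunk] = None
    async for chunk in ChatOpenAI(model="gpt-4o-mini").astream(
        "how are ya", response_format=Foo
    ):
        full = chunk if full is None else full + chunk
    assert full is not None
    # The parsed Pydantic object arrives on the final chunk's additional_kwargs.
    print(full.additional_kwargs["parsed"])


asyncio.run(main())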
