Skip to content

Commit

Permalink
openai: disable streaming for o1 by default (#29147)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored Jan 11, 2025
1 parent 62074ba commit bbc3e3b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
9 changes: 9 additions & 0 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,15 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
values["temperature"] = 1
return values

@model_validator(mode="before")
@classmethod
def validate_disable_streaming(cls, values: Dict[str, Any]) -> Any:
"""Disable streaming if n > 1."""
model = values.get("model_name") or values.get("model") or ""
if model == "o1" and values.get("disable_streaming") is None:
values["disable_streaming"] = True
return values

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1192,3 +1192,19 @@ def test_o1(use_max_completion_tokens: bool) -> None:
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.content.upper() == response.content


@pytest.mark.scheduled
def test_o1_doesnt_stream() -> None:
"""
When this starts failing, remove the `disable_streaming` validator in
`BaseChatOpenAI`
"""
with pytest.raises(openai.BadRequestError):
list(ChatOpenAI(model="o1", disable_streaming=False).stream("how are you"))


@pytest.mark.scheduled
def test_o1_stream_default_works() -> None:
result = list(ChatOpenAI(model="o1").stream("say 'hi'"))
assert len(result) > 0

0 comments on commit bbc3e3b

Please sign in to comment.