Skip to content

Commit

Permalink
openai[patch]: remove optional defaults (#29097)
Browse files — browse the repository at this point in the history
Merging into v0.3 branch
  • Loading branch information
ccurme authored Jan 8, 2025
1 parent 2b09f79 commit 8c3cdc6
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 21 deletions.
10 changes: 6 additions & 4 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: int
max_retries: Optional[int]
Max number of retries.
organization: Optional[str]
OpenAI organization ID. If not passed in will be read from env
Expand Down Expand Up @@ -586,9 +586,9 @@ def is_lc_serializable(cls) -> bool:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
if self.n > 1 and self.streaming:
elif self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

if self.disabled_params is None:
Expand Down Expand Up @@ -641,10 +641,12 @@ def validate_environment(self) -> Self:
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if not self.client:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
Expand Down
20 changes: 11 additions & 9 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class BaseChatOpenAI(BaseChatModel):
root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: float = 0.7
temperature: Optional[float] = None
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
Expand All @@ -430,7 +430,7 @@ class BaseChatOpenAI(BaseChatModel):
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = 2
max_retries: Optional[int] = None
"""Maximum number of retries to make when generating."""
presence_penalty: Optional[float] = None
"""Penalizes repeated tokens."""
Expand All @@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
"""Modify the likelihood of specified tokens appearing in the completion."""
streaming: bool = False
"""Whether to stream the results or not."""
n: int = 1
n: Optional[int] = None
"""Number of chat completions to generate for each prompt."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
Expand Down Expand Up @@ -532,9 +532,9 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
if self.n > 1 and self.streaming:
elif self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

# Check OPENAI_ORGANIZATION for backwards compatibility.
Expand All @@ -551,10 +551,12 @@ def validate_environment(self) -> Self:
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if self.openai_proxy and (self.http_client or self.http_async_client):
openai_proxy = self.openai_proxy
http_client = self.http_client
Expand Down Expand Up @@ -609,14 +611,14 @@ def _default_params(self) -> Dict[str, Any]:
"stop": self.stop or None, # also exclude empty list for this
"max_tokens": self.max_tokens,
"extra_body": self.extra_body,
"n": self.n,
"temperature": self.temperature,
"reasoning_effort": self.reasoning_effort,
}

params = {
"model": self.model_name,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**{k: v for k, v in exclude_if_none.items() if v is not None},
**self.model_kwargs,
}
Expand Down Expand Up @@ -1565,7 +1567,7 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: int
max_retries: Optional[int]
Max number of retries.
api_key: Optional[str]
OpenAI API key. If not passed in will be read from env var OPENAI_API_KEY.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
}),
'max_retries': 2,
'max_tokens': 100,
'n': 1,
'openai_api_key': dict({
'id': list([
'AZURE_OPENAI_API_KEY',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'gpt-3.5-turbo',
'n': 1,
'openai_api_key': dict({
'id': list([
'OPENAI_API_KEY',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,6 @@ def test__get_request_payload() -> None:
],
"model": "gpt-4o-2024-08-06",
"stream": False,
"n": 1,
"temperature": 0.7,
}
payload = llm._get_request_payload(messages)
assert payload == expected
Expand Down
3 changes: 3 additions & 0 deletions libs/partners/xai/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ integration_test integration_tests: TEST_FILE=tests/integration_tests/
test tests:
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)

test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)

integration_test integration_tests:
poetry run pytest $(TEST_FILE)

Expand Down
7 changes: 4 additions & 3 deletions libs/partners/xai/langchain_xai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def _get_ls_params(
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
if self.n is not None and self.n < 1:
raise ValueError("n must be at least 1.")
if self.n > 1 and self.streaming:
if self.n is not None and self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

client_params: dict = {
Expand All @@ -331,10 +331,11 @@ def validate_environment(self) -> Self:
),
"base_url": self.xai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries

if client_params["api_key"] is None:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
'max_retries': 2,
'max_tokens': 100,
'model_name': 'grok-beta',
'n': 1,
'request_timeout': 60.0,
'stop': list([
]),
Expand Down

0 comments on commit 8c3cdc6

Please sign in to comment.