diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py
index 32d895a4..e7ffc646 100644
--- a/libs/vertexai/langchain_google_vertexai/_base.py
+++ b/libs/vertexai/langchain_google_vertexai/_base.py
@@ -105,8 +105,10 @@ class _VertexAIBase(BaseModel):
     def validate_params_base(cls, values: dict) -> Any:
         if "model" in values and "model_name" not in values:
             values["model_name"] = values.pop("model")
-        if values.get("api_transport") is None:
+        if "api_transport" not in values:
             values["api_transport"] = initializer.global_config._api_transport
+        if "location" not in values:
+            values["location"] = initializer.global_config.location
         if values.get("api_endpoint"):
             api_endpoint = values["api_endpoint"]
         else:
diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py
index b943f6b1..a3eb26db 100644
--- a/libs/vertexai/tests/integration_tests/test_chat_models.py
+++ b/libs/vertexai/tests/integration_tests/test_chat_models.py
@@ -8,6 +8,7 @@
 import pytest
 import requests
+import vertexai  # type: ignore[import-untyped, unused-ignore]
 from google.cloud import storage
 from google.cloud.aiplatform_v1beta1.types import Blob, Content, Part
 from google.oauth2 import service_account
@@ -772,6 +773,8 @@ def test_prediction_client_transport():
     assert model.prediction_client.transport.kind == "rest"
     assert model.async_prediction_client.transport.kind == "rest"

+    vertexai.init(api_transport="grpc")  # Reset global config to "grpc"
+

 @pytest.mark.extended
 def test_structured_output_schema_json():
@@ -1201,3 +1204,23 @@ def test_logprobs() -> None:
     llm3 = ChatVertexAI(model="gemini-1.5-flash", logprobs=False)
     msg3 = llm3.invoke("howdy")
     assert msg3.response_metadata.get("logprobs_result") is None
+
+
+def test_location_init() -> None:
+    # Without a prior vertexai.init(), the location defaults to us-central1
+    llm = ChatVertexAI(model="gemini-1.5-flash", logprobs=2)
+    assert llm.location == "us-central1"
+
+    # After vertexai.init() with another region, the model picks up that region
+    vertexai.init(location="europe-west1")
+    llm = ChatVertexAI(model="gemini-1.5-flash", logprobs=2)
+    assert llm.location == "europe-west1"
+
+    # An explicitly passed location takes precedence over the global config
+    llm = ChatVertexAI(model="gemini-1.5-flash", logprobs=2, location="europe-west2")
+    assert llm.location == "europe-west2"
+
+    # Resetting the global config restores the default
+    vertexai.init(location="us-central1")
+    llm = ChatVertexAI(model="gemini-1.5-flash", logprobs=2)
+    assert llm.location == "us-central1"
diff --git a/libs/vertexai/tests/integration_tests/test_standard.py b/libs/vertexai/tests/integration_tests/test_standard.py
index f811c57e..c0f52461 100644
--- a/libs/vertexai/tests/integration_tests/test_standard.py
+++ b/libs/vertexai/tests/integration_tests/test_standard.py
@@ -25,6 +25,7 @@ def chat_model_params(self) -> dict:
             "model_name": "gemini-2.0-flash-exp",
             "rate_limiter": rate_limiter,
             "temperature": 0,
+            "api_transport": None,
         }

     @property
@@ -52,6 +53,7 @@ def chat_model_params(self) -> dict:
             "model_name": "gemini-1.0-pro-001",
             "rate_limiter": rate_limiter,
             "temperature": 0,
+            "api_transport": None,
         }

     @pytest.mark.xfail(reason="Gemini 1.0 doesn't support tool_choice='any'")
@@ -72,6 +74,7 @@ def chat_model_params(self) -> dict:
             "model_name": "gemini-1.5-pro-001",
             "rate_limiter": rate_limiter,
             "temperature": 0,
+            "api_transport": None,
         }

     @property
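
For reviewers, here is a minimal sketch of the fallback order this patch gives `validate_params_base`. The `GlobalConfig` stub and `resolve_params` helper are hypothetical stand-ins for vertexai's `initializer.global_config` and the real validator (which also handles `model_name` and `api_endpoint`); they exist only to illustrate the semantics of the change:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GlobalConfig:
    # Stand-in for the two global settings the validator now reads.
    # Assumed defaults for illustration only.
    location: str = "us-central1"
    _api_transport: Optional[str] = "grpc"


global_config = GlobalConfig()


def resolve_params(values: dict) -> dict:
    """Apply the same fallbacks as the patched validate_params_base."""
    # Only keys that are *absent* fall back to the global config. The
    # switch from `values.get("api_transport") is None` to
    # `"api_transport" not in values` means an explicitly passed None
    # is now preserved instead of being overwritten.
    if "api_transport" not in values:
        values["api_transport"] = global_config._api_transport
    if "location" not in values:
        values["location"] = global_config.location
    return values


# An explicit kwarg wins; a missing key inherits the global config.
assert resolve_params({"location": "europe-west2"})["location"] == "europe-west2"
assert resolve_params({})["location"] == "us-central1"
assert resolve_params({})["api_transport"] == "grpc"
# Under the old `is None` check, an explicit None would have been
# replaced with the global transport; now it sticks.
assert resolve_params({"api_transport": None})["api_transport"] is None
```

This also explains the `"api_transport": None` entries added to the standard-test params: with the new `not in` check, passing `None` explicitly pins the transport to the client default instead of inheriting whatever a previous test left in the global config.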