Skip to content

Commit

Permalink
[VertexAI] Take vertexai init location as default location for BaseVertexAI (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
jzaldi authored Jan 13, 2025
1 parent 49d3844 commit fda0eaf
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
4 changes: 3 additions & 1 deletion libs/vertexai/langchain_google_vertexai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ class _VertexAIBase(BaseModel):
def validate_params_base(cls, values: dict) -> Any:
if "model" in values and "model_name" not in values:
values["model_name"] = values.pop("model")
if values.get("api_transport") is None:
if "api_transport" not in values:
values["api_transport"] = initializer.global_config._api_transport
if "location" not in values:
values["location"] = initializer.global_config.location
if values.get("api_endpoint"):
api_endpoint = values["api_endpoint"]
else:
Expand Down
23 changes: 23 additions & 0 deletions libs/vertexai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pytest
import requests
import vertexai # type: ignore[import-untyped, unused-ignore]
from google.cloud import storage
from google.cloud.aiplatform_v1beta1.types import Blob, Content, Part
from google.oauth2 import service_account
Expand Down Expand Up @@ -772,6 +773,8 @@ def test_prediction_client_transport():
assert model.prediction_client.transport.kind == "rest"
assert model.async_prediction_client.transport.kind == "rest"

vertexai.init(api_transport="grpc") # Reset global config to "grpc"


@pytest.mark.extended
def test_structured_output_schema_json():
Expand Down Expand Up @@ -1201,3 +1204,23 @@ def test_logprobs() -> None:
llm3 = ChatVertexAI(model="gemini-1.5-flash", logprobs=False)
msg3 = llm3.invoke("howdy")
assert msg3.response_metadata.get("logprobs_result") is None


def test_location_init() -> None:
    """Verify how ChatVertexAI resolves its ``location``.

    Resolution priority: an explicit ``location`` kwarg wins; otherwise the
    region set via ``vertexai.init(location=...)`` is used; with neither, the
    SDK default of ``us-central1`` applies.
    """
    # Without a prior vertexai.init(), the location defaults to us-central1.
    llm = ChatVertexAI(model="gemini-1.5-flash", logprobs=2)
    assert llm.location == "us-central1"

    try:
        # After vertexai.init() with another region, new models inherit it.
        vertexai.init(location="europe-west1")
        llm = ChatVertexAI(model="gemini-1.5-flash", logprobs=2)
        assert llm.location == "europe-west1"

        # An explicit location kwarg overrides the global config.
        llm = ChatVertexAI(
            model="gemini-1.5-flash", logprobs=2, location="europe-west2"
        )
        assert llm.location == "europe-west2"
    finally:
        # Restore the default even if an assertion above fails, so the
        # mutated global config does not leak into other tests.
        vertexai.init(location="us-central1")

    # With the default restored, new models revert to us-central1.
    llm = ChatVertexAI(model="gemini-1.5-flash", logprobs=2)
    assert llm.location == "us-central1"
3 changes: 3 additions & 0 deletions libs/vertexai/tests/integration_tests/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def chat_model_params(self) -> dict:
"model_name": "gemini-2.0-flash-exp",
"rate_limiter": rate_limiter,
"temperature": 0,
"api_transport": None,
}

@property
Expand Down Expand Up @@ -52,6 +53,7 @@ def chat_model_params(self) -> dict:
"model_name": "gemini-1.0-pro-001",
"rate_limiter": rate_limiter,
"temperature": 0,
"api_transport": None,
}

@pytest.mark.xfail(reason="Gemini 1.0 doesn't support tool_choice='any'")
Expand All @@ -72,6 +74,7 @@ def chat_model_params(self) -> dict:
"model_name": "gemini-1.5-pro-001",
"rate_limiter": rate_limiter,
"temperature": 0,
"api_transport": None,
}

@property
Expand Down

0 comments on commit fda0eaf

Please sign in to comment.