diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index 1524d671..2bf56608 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Type, Union +from typing import Any, Dict, List, Type, Union from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser @@ -25,17 +25,7 @@ def _format_pydantic_to_vertex_function( return { "name": schema["title"], "description": schema.get("description", ""), - "parameters": { - "properties": { - k: { - "type": v["type"], - "description": v.get("description"), - } - for k, v in schema["properties"].items() - }, - "required": schema["required"], - "type": schema["type"], - }, + "parameters": _get_parameters_from_schema(schema=schema), } @@ -48,17 +38,7 @@ def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription: return { "name": tool.name or schema["title"], "description": tool.description or schema["description"], - "parameters": { - "properties": { - k: { - "type": v["type"], - "description": v.get("description"), - } - for k, v in schema["properties"].items() - }, - "required": schema["required"], - "type": schema["type"], - }, + "parameters": _get_parameters_from_schema(schema=schema), } else: return { @@ -89,6 +69,37 @@ def _format_tools_to_vertex_tool( return [VertexTool(function_declarations=function_declarations)] +def _get_parameters_from_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Given a schema, format the parameters key to match VertexAI + expected input. + + Args: + schema: Dictionary that must have the following keys. + + Returns: + Dictionary with the formatted parameters. + """ + + parameters = {} + + parameters["type"] = schema["type"] + + if "required" in schema: + parameters["required"] = schema["required"] + + schema_properties: Dict[str, Any] = schema.get("properties", {}) + + parameters["properties"] = { + parameter_name: { + "type": parameter_dict["type"], + "description": parameter_dict.get("description"), + } + for parameter_name, parameter_dict in schema_properties.items() + } + + return parameters + + class PydanticFunctionsOutputParser(BaseOutputParser): """Parse an output as a pydantic object. diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py new file mode 100644 index 00000000..c1248f2d --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_function_utils.py @@ -0,0 +1,45 @@ +from langchain_core.tools import tool + +from langchain_google_vertexai.functions_utils import _format_tool_to_vertex_function + + +def test_format_tool_to_vertex_function(): + @tool + def get_datetime() -> str: + """Gets the current datetime""" + import datetime + + return datetime.datetime.now().strftime("%Y-%m-%d") + + schema = _format_tool_to_vertex_function(get_datetime) # type: ignore + + assert schema["name"] == "get_datetime" + assert schema["description"] == "get_datetime() -> str - Gets the current datetime" + assert "parameters" in schema + assert "required" not in schema["parameters"] + + @tool + def sum_two_numbers(a: float, b: float) -> str: + """Sum two numbers 'a' and 'b'. + + Returns: + a + b in string format + """ + return str(a + b) + + schema = _format_tool_to_vertex_function(sum_two_numbers) # type: ignore + + assert schema["name"] == "sum_two_numbers" + assert "parameters" in schema + assert len(schema["parameters"]["required"]) == 2 + + @tool + def do_something_optional(a: float, b: float = 0) -> str: + """Some description""" + return str(a + b) + + schema = _format_tool_to_vertex_function(do_something_optional) # type: ignore + + assert schema["name"] == "do_something_optional" + assert "parameters" in schema + assert len(schema["parameters"]["required"]) == 1