From 670368a2138d4405291f18415005386f318c60e8 Mon Sep 17 00:00:00 2001 From: Jorge Date: Mon, 19 Feb 2024 20:02:55 +0100 Subject: [PATCH 1/4] Fixed function parsing without arguments. Added unit testing --- .../functions_utils.py | 2 +- .../tests/unit_tests/test_function_utils.py | 46 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 libs/vertexai/tests/unit_tests/test_function_utils.py diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index 1ae43dff..b5f56504 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -54,7 +54,7 @@ def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription: } for k, v in schema["properties"].items() }, - "required": schema["required"], + "required": schema.get("required", []), "type": schema["type"], }, } diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py new file mode 100644 index 00000000..1d004261 --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_function_utils.py @@ -0,0 +1,46 @@ +from langchain_core.tools import tool + +from langchain_google_vertexai.functions_utils import _format_tool_to_vertex_function + +def test_format_tool_to_vertex_function(): + + @tool + def get_datetime() -> str: + """ Gets the current datetime + """ + import datetime + return datetime.datetime.now().strftime("%Y-%m-%d") + + schema = _format_tool_to_vertex_function(get_datetime) + + assert schema["name"] == "get_datetime" + assert schema["description"] == "get_datetime() -> str - Gets the current datetime" + assert "parameters" in schema + assert len(schema["parameters"]["required"]) == 0 + + @tool + def sum_two_numbers(a: float, b: float) -> str: + """ Sum two numbers 'a' and 'b'. + + Returns: + a + b in string format + """ + return str(a + b) + + schema = _format_tool_to_vertex_function(sum_two_numbers) + + assert schema["name"] == "sum_two_numbers" + assert "parameters" in schema + assert len(schema["parameters"]["required"]) == 2 + + @tool + def do_something_optional(a: float, b: float = 0) -> str: + """ Some description + """ + return a + b + + schema = _format_tool_to_vertex_function(do_something_optional) + + assert schema["name"] == "do_something_optional" + assert "parameters" in schema + assert len(schema["parameters"]["required"]) == 1 From 553b366fe29525026286d6bd8af11ad3282dbd44 Mon Sep 17 00:00:00 2001 From: Jorge Date: Mon, 19 Feb 2024 20:06:26 +0100 Subject: [PATCH 2/4] Fromat --- libs/vertexai/tests/unit_tests/test_function_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py index 1d004261..70e323a8 100644 --- a/libs/vertexai/tests/unit_tests/test_function_utils.py +++ b/libs/vertexai/tests/unit_tests/test_function_utils.py @@ -2,6 +2,7 @@ from langchain_google_vertexai.functions_utils import _format_tool_to_vertex_function + def test_format_tool_to_vertex_function(): @tool From ffb545534c5c81d7640c8cae7005f78b26cb1544 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 20 Feb 2024 20:29:17 +0100 Subject: [PATCH 3/4] Format + wrap duplicate code in function --- .../functions_utils.py | 57 +++++++++++-------- .../tests/unit_tests/test_function_utils.py | 22 ++++--- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index b5f56504..fa2940c5 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Type, Union +from typing import Any, Dict, List, Type, Union from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser @@ -23,17 +23,7 @@ def _format_pydantic_to_vertex_function( return { "name": schema["title"], "description": schema.get("description", ""), - "parameters": { - "properties": { - k: { - "type": v["type"], - "description": v.get("description"), - } - for k, v in schema["properties"].items() - }, - "required": schema["required"], - "type": schema["type"], - }, + "parameters": _get_parameters_from_schema(schema=schema), } @@ -46,17 +36,7 @@ def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription: return { "name": tool.name or schema["title"], "description": tool.description or schema["description"], - "parameters": { - "properties": { - k: { - "type": v["type"], - "description": v.get("description"), - } - for k, v in schema["properties"].items() - }, - "required": schema.get("required", []), - "type": schema["type"], - }, + "parameters": _get_parameters_from_schema(schema=schema), } else: return { @@ -87,6 +67,37 @@ def _format_tools_to_vertex_tool( return [VertexTool(function_declarations=function_declarations)] +def _get_parameters_from_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """Given a schema, format the parameters key to match VertexAI + expected input. + + Args: + schema: Dictionary that must have the following keys. + + Returns: + Dictionary with the formatted parameters. + """ + + parameters = {} + + parameters["type"] = schema["type"] + + if "required" in schema: + parameters["required"] = schema["required"] + + schema_properties: Dict[str, Any] = schema.get("properties", {}) + + parameters["properties"] = { + parameter_name: { + "type": parameter_dict["type"], + "description": parameter_dict.get("description"), + } + for parameter_name, parameter_dict in schema_properties.items() + } + + return parameters + + class PydanticFunctionsOutputParser(BaseOutputParser): """Parse an output as a pydantic object. diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py index 70e323a8..46d3d367 100644 --- a/libs/vertexai/tests/unit_tests/test_function_utils.py +++ b/libs/vertexai/tests/unit_tests/test_function_utils.py @@ -4,30 +4,29 @@ def test_format_tool_to_vertex_function(): - @tool def get_datetime() -> str: - """ Gets the current datetime - """ + """Gets the current datetime""" import datetime + return datetime.datetime.now().strftime("%Y-%m-%d") - + schema = _format_tool_to_vertex_function(get_datetime) assert schema["name"] == "get_datetime" - assert schema["description"] == "get_datetime() -> str - Gets the current datetime" + assert schema["description"] == "get_datetime() -> str - Gets the current datetime" assert "parameters" in schema - assert len(schema["parameters"]["required"]) == 0 + assert "required" not in schema["parameters"] @tool def sum_two_numbers(a: float, b: float) -> str: - """ Sum two numbers 'a' and 'b'. + """Sum two numbers 'a' and 'b'. Returns: a + b in string format """ return str(a + b) - + schema = _format_tool_to_vertex_function(sum_two_numbers) assert schema["name"] == "sum_two_numbers" @@ -36,12 +35,11 @@ def sum_two_numbers(a: float, b: float) -> str: @tool def do_something_optional(a: float, b: float = 0) -> str: - """ Some description - """ + """Some description""" return a + b - + schema = _format_tool_to_vertex_function(do_something_optional) - + assert schema["name"] == "do_something_optional" assert "parameters" in schema assert len(schema["parameters"]["required"]) == 1 From 48ebb9820a6c191ed2878a6b06b36a3802fb9a35 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 20 Feb 2024 21:59:42 +0100 Subject: [PATCH 4/4] Fix mypy in unit_tests --- libs/vertexai/tests/unit_tests/test_function_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py index 46d3d367..c1248f2d 100644 --- a/libs/vertexai/tests/unit_tests/test_function_utils.py +++ b/libs/vertexai/tests/unit_tests/test_function_utils.py @@ -11,7 +11,7 @@ def get_datetime() -> str: return datetime.datetime.now().strftime("%Y-%m-%d") - schema = _format_tool_to_vertex_function(get_datetime) + schema = _format_tool_to_vertex_function(get_datetime) # type: ignore assert schema["name"] == "get_datetime" assert schema["description"] == "get_datetime() -> str - Gets the current datetime" @@ -27,7 +27,7 @@ def sum_two_numbers(a: float, b: float) -> str: """ return str(a + b) - schema = _format_tool_to_vertex_function(sum_two_numbers) + schema = _format_tool_to_vertex_function(sum_two_numbers) # type: ignore assert schema["name"] == "sum_two_numbers" assert "parameters" in schema @@ -36,9 +36,9 @@ def sum_two_numbers(a: float, b: float) -> str: @tool def do_something_optional(a: float, b: float = 0) -> str: """Some description""" - return a + b + return str(a + b) - schema = _format_tool_to_vertex_function(do_something_optional) + schema = _format_tool_to_vertex_function(do_something_optional) # type: ignore assert schema["name"] == "do_something_optional" assert "parameters" in schema