From 6fd9deb99ca25141129971fa9cd4654414575d6e Mon Sep 17 00:00:00 2001 From: gbaian10 <34255899+gbaian10@users.noreply.github.com> Date: Sat, 12 Oct 2024 05:50:37 +0800 Subject: [PATCH] refactor: Correct type hint (#56) --- src/agent/research_assistant.py | 2 ++ src/service/service.py | 6 ++---- src/service/test_service.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/agent/research_assistant.py b/src/agent/research_assistant.py index 2bc72d4..2254c4a 100644 --- a/src/agent/research_assistant.py +++ b/src/agent/research_assistant.py @@ -153,6 +153,8 @@ def check_safety(state: AgentState) -> Literal["unsafe", "safe"]: # After "model", if there are tool calls, run "tools". Otherwise END. def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]: last_message = state["messages"][-1] + if not isinstance(last_message, AIMessage): + raise TypeError(f"Expected AIMessage, got {type(last_message)}") if last_message.tool_calls: return "tools" return "done" diff --git a/src/service/service.py b/src/service/service.py index 03754c9..66af669 100644 --- a/src/service/service.py +++ b/src/service/service.py @@ -64,9 +64,7 @@ def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], str]: return kwargs, run_id -def _remove_tool_calls( - content: str | list[str | dict], -) -> str | list[str | dict]: +def _remove_tool_calls(content: str | list[str | dict]) -> str | list[str | dict]: """Remove tool calls from content.""" if isinstance(content, str): return content @@ -149,7 +147,7 @@ async def message_generator(user_input: StreamInput) -> AsyncGenerator[str, None yield "data: [DONE]\n\n" -def _sse_response_example() -> dict[str, Any]: +def _sse_response_example() -> dict[int, Any]: return { status.HTTP_200_OK: { "description": "Server Sent Event Response", diff --git a/src/service/test_service.py b/src/service/test_service.py index a45ca9f..41a4d43 100644 --- a/src/service/test_service.py +++ b/src/service/test_service.py @@ -1,7 +1,9 @@ from unittest.mock import AsyncMock, patch +import langsmith from fastapi.testclient import TestClient from langchain_core.messages import AIMessage +from langgraph.graph.state import CompiledStateGraph from schema import ChatMessage from service import app @@ -10,7 +12,7 @@ @patch("service.service.research_assistant") -def test_invoke(mock_agent: TestClient) -> None: +def test_invoke(mock_agent: CompiledStateGraph) -> None: QUESTION = "What is the weather in Tokyo?" ANSWER = "The weather in Tokyo is 70 degrees." agent_response = {"messages": [AIMessage(content=ANSWER)]} @@ -30,7 +32,7 @@ def test_invoke(mock_agent: TestClient) -> None: @patch("service.service.LangsmithClient") -def test_feedback(mock_client: TestClient) -> None: +def test_feedback(mock_client: langsmith.Client) -> None: ls_instance = mock_client.return_value ls_instance.create_feedback.return_value = None body = {