Skip to content

Commit

Permalink
refactor: Correct type hint (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaian10 authored Oct 11, 2024
1 parent 879dfcb commit 6fd9deb
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/agent/research_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def check_safety(state: AgentState) -> Literal["unsafe", "safe"]:
# After "model", if there are tool calls, run "tools". Otherwise END.
def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
last_message = state["messages"][-1]
if not isinstance(last_message, AIMessage):
raise TypeError(f"Expected AIMessage, got {type(last_message)}")
if last_message.tool_calls:
return "tools"
return "done"
Expand Down
6 changes: 2 additions & 4 deletions src/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], str]:
return kwargs, run_id


def _remove_tool_calls(
content: str | list[str | dict],
) -> str | list[str | dict]:
def _remove_tool_calls(content: str | list[str | dict]) -> str | list[str | dict]:
"""Remove tool calls from content."""
if isinstance(content, str):
return content
Expand Down Expand Up @@ -149,7 +147,7 @@ async def message_generator(user_input: StreamInput) -> AsyncGenerator[str, None
yield "data: [DONE]\n\n"


def _sse_response_example() -> dict[str, Any]:
def _sse_response_example() -> dict[int, Any]:
return {
status.HTTP_200_OK: {
"description": "Server Sent Event Response",
Expand Down
6 changes: 4 additions & 2 deletions src/service/test_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import AsyncMock, patch

import langsmith
from fastapi.testclient import TestClient
from langchain_core.messages import AIMessage
from langgraph.graph.state import CompiledStateGraph

from schema import ChatMessage
from service import app
Expand All @@ -10,7 +12,7 @@


@patch("service.service.research_assistant")
def test_invoke(mock_agent: TestClient) -> None:
def test_invoke(mock_agent: CompiledStateGraph) -> None:
QUESTION = "What is the weather in Tokyo?"
ANSWER = "The weather in Tokyo is 70 degrees."
agent_response = {"messages": [AIMessage(content=ANSWER)]}
Expand All @@ -30,7 +32,7 @@ def test_invoke(mock_agent: TestClient) -> None:


@patch("service.service.LangsmithClient")
def test_feedback(mock_client: TestClient) -> None:
def test_feedback(mock_client: langsmith.Client) -> None:
ls_instance = mock_client.return_value
ls_instance.create_feedback.return_value = None
body = {
Expand Down

0 comments on commit 6fd9deb

Please sign in to comment.