
Human-in-the-loop #117

Open
wants to merge 10 commits into base: main
2 changes: 2 additions & 0 deletions src/agents/agents.py
@@ -5,6 +5,7 @@
from agents.bg_task_agent.bg_task_agent import bg_task_agent
from agents.chatbot import chatbot
from agents.interrupted_assistant import interrupted_assistant
from agents.research_assistant import research_assistant
from schema import AgentInfo

DEFAULT_AGENT = "research-assistant"
@@ -22,6 +23,7 @@ class Agent:
description="A research assistant with web search and calculator.", graph=research_assistant
),
"bg-task-agent": Agent(description="A background task agent.", graph=bg_task_agent),
"interrupted_assistant": Agent(description="Exampple with human-in-the-loop", graph=interrupted_assistant),
}


95 changes: 95 additions & 0 deletions src/agents/interrupted_assistant.py
@@ -0,0 +1,95 @@
from datetime import datetime
from typing import Literal

from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import RemainingSteps
from langgraph.prebuilt import ToolNode

from agents.tools import calculator
from core import get_model, settings


class AgentState(MessagesState, total=False):
"""`total=False` is PEP589 specs.

documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
"""
remaining_steps: RemainingSteps


web_search = DuckDuckGoSearchResults(name="WebSearch")
tools = [web_search, calculator]

current_date = datetime.now().strftime("%B %d, %Y")
instructions = f"""
You are a helpful research assistant with the ability to search the web and use other tools.
Today's date is {current_date}.

NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.

A few things to remember:
- Please include markdown-formatted links to any citations used in your response. Only include one
or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
- Use the calculator tool (it relies on numexpr) to answer math questions. The user does not
understand numexpr, so for the final response use a human-readable format - e.g. "300 * 200",
not "(300 \\times 200)".
- If an API call is denied by the user, don't answer the question; just tell them you are
canceling the operation as they asked.
"""

def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
model = model.bind_tools(tools)
preprocessor = RunnableLambda(
lambda state: [SystemMessage(content=instructions)] + state["messages"],
name="StateModifier",
)
return preprocessor | model


async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
model_runnable = wrap_model(m)
response = await model_runnable.ainvoke(state, config)

if state["remaining_steps"] < 2 and response.tool_calls:
return {
"messages": [
AIMessage(
id=response.id,
content="Sorry, need more steps to process this request.",
)
]
}
# We return a list, because this will get added to the existing list
return {"messages": [response]}


# Define the graph
agent = StateGraph(AgentState)
agent.add_node("model", acall_model)
agent.add_node("tools", ToolNode(tools))
agent.set_entry_point("model")

# Always run "model" after "tools"
agent.add_edge("tools", "model")

# After "model", if there are tool calls, run "tools". Otherwise END.
def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
last_message = state["messages"][-1]
if not isinstance(last_message, AIMessage):
raise TypeError(f"Expected AIMessage, got {type(last_message)}")
if last_message.tool_calls:
return "tools"
return "done"


agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END})

interrupted_assistant = agent.compile(
checkpointer=MemorySaver(),
interrupt_before=["tools"]
)
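For context, the runtime behavior this compile call sets up: the graph halts before the "tools" node, the MemorySaver checkpoint records the paused state, and invoking the graph again with `None` as the input resumes from that checkpoint. A minimal sketch of the flow (standard LangGraph API; the thread ID and prompt are illustrative):

```python
import asyncio

from langchain_core.messages import HumanMessage

from agents.interrupted_assistant import interrupted_assistant


async def demo() -> None:
    config = {"configurable": {"thread_id": "demo-thread"}}  # illustrative thread ID

    # First call: the graph pauses before executing any tool call.
    await interrupted_assistant.ainvoke(
        {"messages": [HumanMessage(content="What is 300 * 200?")]}, config
    )

    # While paused, the state snapshot reports the pending node(s).
    state = await interrupted_assistant.aget_state(config)
    print(state.next)  # ("tools",) when a tool call was interrupted

    # Invoking with None as input resumes execution from the saved checkpoint.
    result = await interrupted_assistant.ainvoke(None, config)
    print(result["messages"][-1].content)


asyncio.run(demo())
```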
25 changes: 22 additions & 3 deletions src/client/client.py
@@ -9,6 +9,7 @@
ChatHistory,
ChatHistoryInput,
ChatMessage,
    Feedback,
    InterruptMessage,
ServiceMetadata,
StreamInput,
@@ -148,7 +149,7 @@ def invoke(

return ChatMessage.model_validate(response.json())

def _parse_stream_line(self, line: str) -> ChatMessage | str | None:
def _parse_stream_line(self, line: str) -> ChatMessage | InterruptMessage | str | None:
line = line.strip()
if line.startswith("data: "):
data = line[6:]
@@ -165,6 +166,11 @@ def _parse_stream_line(self, line: str) -> ChatMessage | str | None:
return ChatMessage.model_validate(parsed["content"])
except Exception as e:
raise Exception(f"Server returned invalid message: {e}")
case "interrupt":
try:
return InterruptMessage.model_validate(parsed["content"])
except Exception as e:
raise Exception(f"Server returned invalid interrupt message: {e}")
case "token":
# Yield the str token directly
return parsed["content"]
@@ -226,8 +232,11 @@ async def astream(
message: str,
model: str | None = None,
thread_id: str | None = None,
run_id: str | None = None,
tool_call_id: str | None = None,
interruption: bool = False,
stream_tokens: bool = True,
) -> AsyncGenerator[ChatMessage | str, None]:
) -> AsyncGenerator[ChatMessage | InterruptMessage | str, None]:
"""
Stream the agent's response asynchronously.

@@ -241,17 +250,27 @@
thread_id (str, optional): Thread ID for continuing a conversation
stream_tokens (bool, optional): Stream tokens as they are generated
Default: True
            run_id (str, optional): Run ID of the interrupted run; required when resuming after an interruption
            tool_call_id (str, optional): ID of the interrupted tool call; required when resuming after an interruption
            interruption (bool, optional): Set True to resume a run after an interruption
                Default: False

Returns:
AsyncGenerator[ChatMessage | str, None]: The response from the agent
AsyncGenerator[ChatMessage | InterruptMessage | str, None]: The response from the agent
"""
if not self.agent:
raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
request = StreamInput(message=message, stream_tokens=stream_tokens)
if interruption:
request.type = "interrupt"
if thread_id:
request.thread_id = thread_id
if model:
request.model = model
if run_id:
request.run_id = run_id
if tool_call_id:
request.tool_call_id = tool_call_id

async with httpx.AsyncClient() as client:
try:
async with client.stream(
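Taken together, the new parameters make resuming a two-phase call: stream until an `InterruptMessage` arrives, then call `astream` again with `interruption=True`, echoing back the run and tool-call IDs. A hedged usage sketch (assumes a locally running service with a default agent; the `"aproved_tool"` sentinel matches the service's `_parse_input` further down):

```python
import asyncio

from client import AgentClient
from schema import InterruptMessage


async def main() -> None:
    client = AgentClient()  # assumes the service is running locally
    thread_id = "demo-thread"  # illustrative

    interrupt: InterruptMessage | None = None
    async for msg in client.astream("What is 300 * 200?", thread_id=thread_id):
        if isinstance(msg, InterruptMessage):
            interrupt = msg  # the graph paused before running a tool
        else:
            print(msg)

    if interrupt is not None:
        # Approve the pending tool call and resume the same run.
        async for msg in client.astream(
            "aproved_tool",  # sentinel recognized by the service's _parse_input
            thread_id=thread_id,
            run_id=interrupt.run_id,
            tool_call_id=interrupt.tool_id,
            interruption=True,
        ):
            print(msg)


asyncio.run(main())
```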
2 changes: 2 additions & 0 deletions src/schema/__init__.py
@@ -4,6 +4,7 @@
ChatHistory,
ChatHistoryInput,
ChatMessage,
    Feedback,
    FeedbackResponse,
    InterruptMessage,
ServiceMetadata,
@@ -16,6 +17,7 @@
"AllModelEnum",
"UserInput",
"ChatMessage",
"InterruptMessage",
"ServiceMetadata",
"StreamInput",
"Feedback",
59 changes: 59 additions & 0 deletions src/schema/schema.py
@@ -40,6 +40,12 @@ class ServiceMetadata(BaseModel):
class UserInput(BaseModel):
"""Basic user input for the agent."""

type: Literal["user", "interrupt"] = Field(
description="Source of the message.",
default="user",
examples=["user", "interrupt"],
)

message: str = Field(
description="User input to the agent.",
examples=["What is the weather in Tokyo?"],
@@ -55,6 +61,16 @@ class UserInput(BaseModel):
default=None,
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
)
tool_call_id: str | None = Field(
description="Tool call that this message is responding to after an interruption.",
default=None,
examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
)
run_id: str | None = Field(
description="Run ID of the message to continue after interruption.",
default=None,
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
)
Owner commented on lines +64 to +73:

Why does the client need to store and send back the tool_call_id(s) and run_id ? Couldn't it just be tracked in the server (or maybe the checkpointer already does it for us?) for the relevant thread_id instead of returning it, to keep the interface simpler?
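For example, an untested sketch of recovering it server-side from the checkpointed thread state:

```python
# Untested sketch: the checkpointer already stores the paused state per thread_id,
# so the service could look up the pending tool call itself.
snapshot = await agent.aget_state(config)  # config carries the thread_id
if snapshot.next:  # non-empty while the graph is paused at a breakpoint
    last_ai_message = snapshot.values["messages"][-1]
    pending_tool_call_id = last_ai_message.tool_calls[0]["id"]
```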



class StreamInput(UserInput):
@@ -125,6 +141,49 @@ def pretty_repr(self) -> str:
def pretty_print(self) -> None:
print(self.pretty_repr()) # noqa: T201

class InterruptMessage(BaseModel):
Owner commented:

Does this actually need to be a whole new class? Why not just make a new interrupt type in ChatMessage? It already has a lot of the relevant fields. And could use custom_data attribute for any non-general stuff.

msg_name: str = Field(
description="Name of the message. If created from call_tool it will be equal to the template name.",
examples=["book_table", "ask_permission", "default"],
)

    args: dict[str, Any] = Field(
        description="Data/information to be added to the message.",
        examples=[{"num_personas": 4, "reservation_datetime": "2024-10-24 19:00"}],
        default={},
    )
    confirmation_msg: str | None = Field(
        description="Message to display.",
        default=None,  # TODO: before displaying, if this is None, invoke select_confirmation with the default
        examples=["Needs confirmation.", "{num_personas} people at {reservation_datetime}. Is that correct?"],
    )
    user_input: str | None = Field(
        description="Data added by the user after an interruption. None means the user approved the confirmation.",
        default="",  # TODO: the default is not None, because None means OK, continue execution
    )
    tool_id: str | None = Field(
        description="ID of the interrupted tool call.",
        default=None,
        examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
    )
run_id: str | None = Field(
description="Run ID of the message.",
default=None,
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
)
original: dict[str, Any] = Field(
description="Original LangChain message triggering the interruption in serialized form.",
default={},
)


class Feedback(BaseModel):
"""Feedback for a run, to record to LangSmith."""
51 changes: 44 additions & 7 deletions src/service/service.py
@@ -10,7 +10,7 @@
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from langchain_core._api import LangChainBetaWarning
from langchain_core.messages import AnyMessage, HumanMessage
from langchain_core.messages import AnyMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.graph.state import CompiledStateGraph
Expand All @@ -22,6 +22,7 @@
ChatHistory,
ChatHistoryInput,
ChatMessage,
    Feedback,
    FeedbackResponse,
    InterruptMessage,
ServiceMetadata,
@@ -32,6 +33,7 @@
convert_message_content_to_string,
    interrupt_from_call_tool,
    langchain_to_chat_message,
    remove_tool_calls,
)

warnings.filterwarnings("ignore", category=LangChainBetaWarning)
@@ -81,14 +83,33 @@ async def info() -> ServiceMetadata:


def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], UUID]:
run_id = uuid4()
    run_id = UUID(user_input.run_id) if user_input.run_id else uuid4()
thread_id = user_input.thread_id or str(uuid4())
if user_input.type == "user":
inputU = {"messages": [HumanMessage(content=user_input.message)]}
else:
if user_input.message == "aproved_tool":
inputU = None # This means to continue graph execution after interruption
elif user_input.message == "canceled_tool":
inputU = {"messages": [ToolMessage(
tool_call_id=user_input.tool_call_id,
content=f"API call denied by user. Ask them about de reason or what else can you help them with.",
)]}
else:
inputU = {"messages": [
ToolMessage(
tool_call_id=user_input.tool_call_id,
content=f"API call denied by user.",
),
HumanMessage(content=user_input.message)
]}

kwargs = {
"input": {"messages": [HumanMessage(content=user_input.message)]},
"config": RunnableConfig(
configurable={"thread_id": thread_id, "model": user_input.model}, run_id=run_id
),
}
"input": inputU,
"config": RunnableConfig(
configurable={"thread_id": thread_id, "model": user_input.model}, run_id=run_id
),
}
Owner commented:

You'll want to set up the pre-commit hooks and run the linter and tests locally. Otherwise the CI will keep failing and won't be able to merge this. See the contributing instructions in the README.

return kwargs, run_id


@@ -171,6 +192,22 @@ async def message_generator(
# So we only print non-empty content.
yield f"data: {json.dumps({'type': 'token', 'content': convert_message_content_to_string(content)})}\n\n"
continue

    # Handle interruptions: https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/breakpoints/
    # Breakpoints (declared with interrupt_before when compiling the agent) pause graph
    # execution at specific steps. If one is detected, an interrupt message is sent to
    # ask for human approval. Checkpoints ensure the graph can resume from the same
    # state after human input.
snapshot = await agent.aget_state(kwargs["config"])
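    # `snapshot.next` lists the nodes still pending; it is non-empty only when the
    # graph stopped at a breakpoint instead of running to completion.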
if snapshot.next:
Owner commented:

Can you add a comment explaining what this check does and why?

try:
ai_message = langchain_to_chat_message(snapshot.values["messages"][-1])
ichat_message = interrupt_from_call_tool(call_tool=ai_message.tool_calls[0])
Owner commented:

An AIMessage can include multiple tool calls. This would drop the additional ones. That seems bad.

ichat_message.run_id = str(run_id)
        except Exception as e:
            yield f"data: {json.dumps({'type': 'error', 'content': f'Error parsing interrupt message: {e}'})}\n\n"
        else:
            # Only emit the interrupt event if parsing succeeded.
            yield f"data: {json.dumps({'type': 'interrupt', 'content': ichat_message.model_dump()})}\n\n"
Owner commented:

This only handles the /stream endpoint. What happens if I call /invoke on an interrupting agent?


yield "data: [DONE]\n\n"

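For reference, the interrupt event emitted above looks roughly like this on the wire (a single SSE line; all field values are illustrative, following the InterruptMessage schema):

```
data: {"type": "interrupt", "content": {"msg_name": "WebSearch", "args": {"query": "weather in Tokyo"}, "confirmation_msg": "Do you approve the above actions?", "user_input": "", "tool_id": "call_Jja7J89XsjrOLA5r", "run_id": "847c6285-8fc9-4560-a83f-4e6285809254", "original": {"name": "WebSearch", "args": {"query": "weather in Tokyo"}, "id": "call_Jja7J89XsjrOLA5r"}}}
```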
17 changes: 16 additions & 1 deletion src/service/utils.py
@@ -8,7 +8,7 @@
ChatMessage as LangchainChatMessage,
)

from schema import ChatMessage
from schema import ChatMessage, InterruptMessage


def convert_message_content_to_string(content: str | list[str | dict]) -> str:
@@ -74,3 +74,18 @@ def remove_tool_calls(content: str | list[str | dict]) -> str | list[str | dict]
for content_item in content
if isinstance(content_item, str) or content_item["type"] != "tool_use"
]


def interrupt_from_call_tool(call_tool: dict) -> InterruptMessage:
    # Create an InterruptMessage instance from the call_tool dict
Owner commented:

Would you mind putting code comments in English to be consistent with the rest of the repo?

    return InterruptMessage(
        msg_name=call_tool["name"],
        args=call_tool["args"],
        tool_id=call_tool["id"],
        original=call_tool,
        confirmation_msg="Do you approve the above actions?",
    )
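A quick illustration of this helper with a hand-built tool-call dict (the shape LangChain uses for entries in `AIMessage.tool_calls`; all values are made up):

```python
sample_call = {
    "name": "WebSearch",
    "args": {"query": "weather in Tokyo"},
    "id": "call_Jja7J89XsjrOLA5r",  # illustrative tool call ID
}

msg = interrupt_from_call_tool(call_tool=sample_call)
assert msg.msg_name == "WebSearch"
assert msg.tool_id == sample_call["id"]
assert msg.original == sample_call
print(msg.confirmation_msg)  # -> "Do you approve the above actions?"
```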