Human-in-the-loop #117
base: main
```diff
@@ -0,0 +1,95 @@
+from datetime import datetime
+from typing import Literal
+
+from langchain_community.tools import DuckDuckGoSearchResults
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.messages import AIMessage, SystemMessage
+from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, MessagesState, StateGraph
+from langgraph.managed import RemainingSteps
+from langgraph.prebuilt import ToolNode
+
+from agents.tools import calculator
+from core import get_model, settings
+
+
+class AgentState(MessagesState, total=False):
+    """`total=False` is PEP 589 totality: all keys are optional.
+
+    Documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
+    """
+
+    remaining_steps: RemainingSteps
+
+
+web_search = DuckDuckGoSearchResults(name="WebSearch")
+tools = [web_search, calculator]
+
+current_date = datetime.now().strftime("%B %d, %Y")
+instructions = f"""
+You are a helpful research assistant with the ability to search the web and use other tools.
+Today's date is {current_date}.
+
+NOTE: THE USER CAN'T SEE THE TOOL RESPONSE.
+
+A few things to remember:
+- Please include markdown-formatted links to any citations used in your response. Only include one
+  or two citations per response unless more are needed. ONLY USE LINKS RETURNED BY THE TOOLS.
+- Use the calculator tool with numexpr to answer math questions. The user does not understand numexpr,
+  so for the final response, use a human-readable format - e.g. "300 * 200", not "(300 \\times 200)".
+- If an API call is denied by the user, don't answer the question; just tell them you are canceling
+  the operation as they asked.
+"""
+
+
+def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
+    model = model.bind_tools(tools)
+    preprocessor = RunnableLambda(
+        lambda state: [SystemMessage(content=instructions)] + state["messages"],
+        name="StateModifier",
+    )
+    return preprocessor | model
+
+
+async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
+    m = get_model(config["configurable"].get("model", settings.DEFAULT_MODEL))
+    model_runnable = wrap_model(m)
+    response = await model_runnable.ainvoke(state, config)
+
+    if state["remaining_steps"] < 2 and response.tool_calls:
+        return {
+            "messages": [
+                AIMessage(
+                    id=response.id,
+                    content="Sorry, need more steps to process this request.",
+                )
+            ]
+        }
+    # We return a list, because this will get added to the existing list
+    return {"messages": [response]}
+
+
+# Define the graph
+agent = StateGraph(AgentState)
+agent.add_node("model", acall_model)
+agent.add_node("tools", ToolNode(tools))
+agent.set_entry_point("model")
+
+# Always run "model" after "tools"
+agent.add_edge("tools", "model")
+
+
+# After "model", if there are tool calls, run "tools". Otherwise END.
+def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
+    last_message = state["messages"][-1]
+    if not isinstance(last_message, AIMessage):
+        raise TypeError(f"Expected AIMessage, got {type(last_message)}")
+    if last_message.tool_calls:
+        return "tools"
+    return "done"
+
+
+agent.add_conditional_edges("model", pending_tool_calls, {"tools": "tools", "done": END})
+
+# Pause before each "tools" step so a human can approve or deny the pending tool calls.
+interrupted_assistant = agent.compile(
+    checkpointer=MemorySaver(),
+    interrupt_before=["tools"],
+)
```
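To see the human-in-the-loop cycle end to end, here is a minimal sketch of driving the compiled graph directly (the thread ID and question are illustrative; passing `None` as the input is LangGraph's standard way to resume past a breakpoint):

```python
import asyncio

from langchain_core.messages import HumanMessage


async def demo() -> None:
    config = {"configurable": {"thread_id": "demo-thread"}}  # illustrative thread id

    # First pass: the model requests a tool call, then the graph pauses
    # because of interrupt_before=["tools"].
    await interrupted_assistant.ainvoke(
        {"messages": [HumanMessage(content="What is 300 * 200?")]}, config
    )
    state = await interrupted_assistant.aget_state(config)
    print("paused before:", state.next)  # ('tools',) while awaiting approval

    # After human approval: resume from the checkpoint by passing None as input.
    result = await interrupted_assistant.ainvoke(None, config)
    print(result["messages"][-1].content)


asyncio.run(demo())
```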
```diff
@@ -40,6 +40,12 @@ class ServiceMetadata(BaseModel):
 class UserInput(BaseModel):
     """Basic user input for the agent."""

+    type: Literal["user", "interrupt"] = Field(
+        description="Source of the message.",
+        default="user",
+        examples=["user", "interrupt"],
+    )
+
     message: str = Field(
         description="User input to the agent.",
         examples=["What is the weather in Tokyo?"],
@@ -55,6 +61,16 @@ class UserInput(BaseModel):
         default=None,
         examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
     )
+    tool_call_id: str | None = Field(
+        description="Tool call that this message is responding to after an interruption.",
+        default=None,
+        examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
+    )
+    run_id: str | None = Field(
+        description="Run ID of the message to continue after interruption.",
+        default=None,
+        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
+    )


 class StreamInput(UserInput):
@@ -125,6 +141,49 @@ def pretty_repr(self) -> str:
     def pretty_print(self) -> None:
         print(self.pretty_repr())  # noqa: T201

+
+class InterruptMessage(BaseModel):
```

> **Review comment:** Does this actually need to be a whole new class? Why not just make a new […]

```diff
+    msg_name: str = Field(
+        description="Name of the message. If created from call_tool it will be equal to the template name.",
+        examples=["book_table", "ask_permission", "default"],
+    )
+
+    args: dict[str, Any] = Field(
+        description="Data/information to be added to the message.",
+        examples=[{"num_personas": 4, "reservation_datetime": "2024-10-24 19:00"}],
+        default={},
+    )
+    confirmation_msg: str | None = Field(
+        description="Message to display.",
+        default=None,  # TODO: before displaying, if this is None we need to invoke select_confirmation with the default
+        examples=["Needs confirmation.", "{num_personas} people at {reservation_datetime}. Is that correct?"],
+    )
+    user_input: str | None = Field(
+        description="Data added by the user after interruption. None means the user approved the confirmation.",
+        default="",  # TODO: not using None as the default, because None means OK, continue execution
+    )
+    tool_id: str | None = Field(
+        description="ID of the interrupted tool.",
+        default=None,
+        examples=["call_Jja7J89XsjrOLA5r!MEOW!SL"],
+    )
+    run_id: str | None = Field(
+        description="Run ID of the message.",
+        default=None,
+        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
+    )
+    original: dict[str, Any] = Field(
+        description="Original LangChain message triggering the interruption in serialized form.",
+        default={},
+    )
+    # model: str | None = Field(
+    #     description="Model to run.",
+    #     default=None,
+    # )
+    # thread_id: str | None = Field(
+    #     description="Thread id.",
+    #     default=None,
+    # )


 class Feedback(BaseModel):
     """Feedback for a run, to record to LangSmith."""
```
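As a quick illustration of the schema, constructing and serializing one of these messages might look like this (field values are hypothetical):

```python
from schema import InterruptMessage

# Hypothetical interrupt built from a calculator tool call (values are illustrative).
interrupt = InterruptMessage(
    msg_name="calculator",
    args={"expression": "300 * 200"},
    tool_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
    confirmation_msg="Do you approve the above actions?",
)
print(interrupt.model_dump())
```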
```diff
@@ -10,7 +10,7 @@
 from fastapi.responses import StreamingResponse
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from langchain_core._api import LangChainBetaWarning
-from langchain_core.messages import AnyMessage, HumanMessage
+from langchain_core.messages import AnyMessage, HumanMessage, ToolMessage
 from langchain_core.runnables import RunnableConfig
 from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
 from langgraph.graph.state import CompiledStateGraph
@@ -22,6 +22,7 @@
     ChatHistory,
     ChatHistoryInput,
     ChatMessage,
+    InterruptMessage,
     Feedback,
     FeedbackResponse,
     ServiceMetadata,
@@ -32,6 +33,7 @@
     convert_message_content_to_string,
     langchain_to_chat_message,
     remove_tool_calls,
+    interrupt_from_call_tool,
 )

 warnings.filterwarnings("ignore", category=LangChainBetaWarning)
@@ -81,14 +83,33 @@ async def info() -> ServiceMetadata:


 def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], UUID]:
-    run_id = uuid4()
+    run_id = user_input.run_id or uuid4()
     thread_id = user_input.thread_id or str(uuid4())
+    if user_input.type == "user":
+        inputU = {"messages": [HumanMessage(content=user_input.message)]}
+    else:
+        if user_input.message == "approved_tool":
+            inputU = None  # This means to continue graph execution after interruption
+        elif user_input.message == "canceled_tool":
+            inputU = {"messages": [ToolMessage(
+                tool_call_id=user_input.tool_call_id,
+                content="API call denied by user. Ask them about the reason or what else you can help them with.",
+            )]}
+        else:
+            inputU = {"messages": [
+                ToolMessage(
+                    tool_call_id=user_input.tool_call_id,
+                    content="API call denied by user.",
+                ),
+                HumanMessage(content=user_input.message),
+            ]}

-    kwargs = {
-        "input": {"messages": [HumanMessage(content=user_input.message)]},
-        "config": RunnableConfig(
-            configurable={"thread_id": thread_id, "model": user_input.model}, run_id=run_id
-        ),
-    }
+    kwargs = {
+        "input": inputU,
+        "config": RunnableConfig(
+            configurable={"thread_id": thread_id, "model": user_input.model}, run_id=run_id
+        ),
+    }
```

> **Review comment:** You'll want to set up the pre-commit hooks and run the linter and tests locally. Otherwise the CI will keep failing and won't be able to merge this. See the contributing instructions in the README.

```diff

     return kwargs, run_id
```
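To make the branches above concrete, these are the payload shapes a client would send back after an interrupt, as implied by `_parse_input` (IDs are illustrative, and the field names come from the `UserInput` schema):

```python
# Hypothetical resume payloads a client might POST after an interrupt.
approve = {
    "type": "interrupt",
    "message": "approved_tool",  # resume with input=None: continue the paused run
    "thread_id": "847c6285-8fc9-4560-a83f-4e6285809254",
    "run_id": "847c6285-8fc9-4560-a83f-4e6285809254",
}
cancel = {
    "type": "interrupt",
    "message": "canceled_tool",  # inject a ToolMessage denying the tool call
    "tool_call_id": "call_Jja7J89XsjrOLA5r!MEOW!SL",
    "thread_id": "847c6285-8fc9-4560-a83f-4e6285809254",
}
```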
```diff
@@ -171,6 +192,22 @@ async def message_generator(
             # So we only print non-empty content.
             yield f"data: {json.dumps({'type': 'token', 'content': convert_message_content_to_string(content)})}\n\n"
             continue

+    # Handle interruptions: https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/breakpoints/
+    # Breakpoints pause graph execution at specific steps. If one is detected, an
+    # interrupt message is sent to ask for human approval. Checkpoints ensure the
+    # graph can resume from the same state after human input. To specify breakpoints,
+    # use interrupt_before in the agent.
+    snapshot = await agent.aget_state(kwargs["config"])
+    if snapshot.next:
```

> **Review comment:** Can you add a comment explaining what this check does and why?

```diff
+        try:
+            ai_message = langchain_to_chat_message(snapshot.values["messages"][-1])
+            ichat_message = interrupt_from_call_tool(call_tool=ai_message.tool_calls[0])
```

> **Review comment:** An AIMessage can include multiple tool calls. This would drop the additional ones. That seems bad.

```diff
+            ichat_message.run_id = str(run_id)
+        except Exception as e:
+            yield f"data: {json.dumps({'type': 'error', 'content': f'Error parsing interrupt message: {e}'})}\n\n"
+
+        yield f"data: {json.dumps({'type': 'interrupt', 'content': ichat_message.model_dump()})}\n\n"
```

> **Review comment:** This only handles the […]

```diff

     yield "data: [DONE]\n\n"
```
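For orientation, here is one way a client might consume this SSE stream and surface the new `interrupt` event (a sketch; the service URL and the use of httpx are assumptions, not part of this PR):

```python
import json

import httpx  # assumed HTTP client; any SSE-capable client works


async def stream_until_interrupt(payload: dict) -> dict | None:
    """Return the interrupt payload if the run pauses, else None (run completed)."""
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", "http://localhost/stream", json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: ") or line == "data: [DONE]":
                    continue
                event = json.loads(line.removeprefix("data: "))
                if event["type"] == "interrupt":
                    return event["content"]  # ask the human, then POST an "interrupt" message
    return None
```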
```diff
@@ -8,7 +8,7 @@
     ChatMessage as LangchainChatMessage,
 )

-from schema import ChatMessage
+from schema import ChatMessage, InterruptMessage


 def convert_message_content_to_string(content: str | list[str | dict]) -> str:
@@ -74,3 +74,18 @@ def remove_tool_calls(content: str | list[str | dict]) -> str | list[str | dict]:
         for content_item in content
         if isinstance(content_item, str) or content_item["type"] != "tool_use"
     ]
+
+
+def interrupt_from_call_tool(call_tool: dict) -> InterruptMessage:
+    # Create an InterruptMessage instance from the call_tool dict
```

> **Review comment:** Would you mind putting code comments in english to be consistent with the rest of the repo?

```diff
+    instance = InterruptMessage(
+        msg_name=call_tool["name"],
+        args=call_tool["args"],
+        tool_id=call_tool["id"],
+    )
+
+    instance.original = call_tool
+
+    instance.confirmation_msg = "Do you approve the above actions?"
+
+    return instance
```
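A usage sketch, assuming the tool-call dict shape of LangChain's `AIMessage.tool_calls` entries (values are illustrative):

```python
call_tool = {
    "name": "calculator",
    "args": {"expression": "300 * 200"},
    "id": "call_Jja7J89XsjrOLA5r!MEOW!SL",
}
msg = interrupt_from_call_tool(call_tool=call_tool)
print(msg.msg_name, msg.confirmation_msg)  # calculator Do you approve the above actions?
```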
> **Review comment:** Why does the client need to store and send back the tool_call_id(s) and run_id? Couldn't it just be tracked in the server (or maybe the checkpointer already does it for us?) for the relevant thread_id instead of returning it, to keep the interface simpler?
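For reference, a minimal sketch of the server-side alternative this comment describes, assuming the checkpointer snapshot already holds the paused state as used in `message_generator` above (hypothetical helper, not part of this PR):

```python
async def pending_tool_call_ids(thread_id: str) -> list[str]:
    """Derive pending tool-call IDs from the checkpoint instead of trusting the client."""
    config = {"configurable": {"thread_id": thread_id}}
    snapshot = await agent.aget_state(config)  # `agent` is the compiled graph with a checkpointer
    if not snapshot.next:  # no paused step, so nothing is awaiting approval
        return []
    last_message = snapshot.values["messages"][-1]
    return [call["id"] for call in getattr(last_message, "tool_calls", [])]
```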