From acebd43d77d149396cbb7a2bd2e0eb03ebdd31fc Mon Sep 17 00:00:00 2001
From: Joshua Carroll
Date: Thu, 31 Oct 2024 08:36:18 -0700
Subject: [PATCH] Implement CustomData for agents; select agent in app (#77)

---
 src/agents/utils.py    | 30 ++++++++++++++++++++++++++++++
 src/schema/schema.py   |  8 --------
 src/service/service.py | 30 ++++++++++++++++++------------
 src/streamlit_app.py   |  7 +++++++
 4 files changed, 55 insertions(+), 20 deletions(-)
 create mode 100644 src/agents/utils.py

diff --git a/src/agents/utils.py b/src/agents/utils.py
new file mode 100644
index 0000000..2d0bd30
--- /dev/null
+++ b/src/agents/utils.py
@@ -0,0 +1,30 @@
+from typing import Any
+
+from langchain_core.callbacks import adispatch_custom_event
+from langchain_core.messages import ChatMessage
+from langchain_core.runnables import RunnableConfig
+from langchain_core.runnables.config import merge_configs
+from pydantic import BaseModel, Field
+
+
+class CustomData(BaseModel):
+    """Custom data being sent by an agent."""
+
+    type: str = Field(
+        description="The type of custom data, used in dispatch events",
+        default="custom_data",
+    )
+    data: dict[str, Any] = Field(description="The custom data")
+
+    def to_langchain(self) -> ChatMessage:
+        return ChatMessage(content=[self.data], role="custom")
+
+    async def adispatch(self, config: RunnableConfig | None = None) -> None:
+        dispatch_config = RunnableConfig(
+            tags=["custom_data_dispatch"],
+        )
+        await adispatch_custom_event(
+            name=self.type,
+            data=self.to_langchain(),
+            config=merge_configs(config, dispatch_config),
+        )
diff --git a/src/schema/schema.py b/src/schema/schema.py
index b088a90..8c27f3a 100644
--- a/src/schema/schema.py
+++ b/src/schema/schema.py
@@ -78,14 +78,6 @@ class ChatMessage(BaseModel):
         default={},
     )
 
-    @classmethod
-    def from_custom_data(cls, data: dict[str, Any]) -> "ChatMessage":
-        return cls(
-            type="custom",
-            content="",
-            custom_data=data,
-        )
-
     def pretty_repr(self) -> str:
         """Get a pretty representation of the message."""
         base_title = self.type.title() + " Message"
diff --git a/src/service/service.py b/src/service/service.py
index 06cfea1..29fe1b9 100644
--- a/src/service/service.py
+++ b/src/service/service.py
@@ -128,6 +128,7 @@ async def message_generator(
         if not event:
             continue
 
+        new_messages = []
         # Yield messages written to the graph state after node execution finishes.
         if (
             event["event"] == "on_chain_end"
@@ -137,18 +138,23 @@
             and "messages" in event["data"]["output"]
         ):
             new_messages = event["data"]["output"]["messages"]
-            for message in new_messages:
-                try:
-                    chat_message = langchain_to_chat_message(message)
-                    chat_message.run_id = str(run_id)
-                except Exception as e:
-                    logger.error(f"Error parsing message: {e}")
-                    yield f"data: {json.dumps({'type': 'error', 'content': 'Unexpected error'})}\n\n"
-                    continue
-                # LangGraph re-sends the input message, which feels weird, so drop it
-                if chat_message.type == "human" and chat_message.content == user_input.message:
-                    continue
-                yield f"data: {json.dumps({'type': 'message', 'content': chat_message.model_dump()})}\n\n"
+
+        # Also yield intermediate messages from agents.utils.CustomData.adispatch().
+        if event["event"] == "on_custom_event" and "custom_data_dispatch" in event.get("tags", []):
+            new_messages = [event["data"]]
+
+        for message in new_messages:
+            try:
+                chat_message = langchain_to_chat_message(message)
+                chat_message.run_id = str(run_id)
+            except Exception as e:
+                logger.error(f"Error parsing message: {e}")
+                yield f"data: {json.dumps({'type': 'error', 'content': 'Unexpected error'})}\n\n"
+                continue
+            # LangGraph re-sends the input message, which feels weird, so drop it
+            if chat_message.type == "human" and chat_message.content == user_input.message:
+                continue
+            yield f"data: {json.dumps({'type': 'message', 'content': chat_message.model_dump()})}\n\n"
 
         # Yield tokens streamed from LLMs.
         if (
diff --git a/src/streamlit_app.py b/src/streamlit_app.py
index 1ab1d05..2532875 100644
--- a/src/streamlit_app.py
+++ b/src/streamlit_app.py
@@ -77,6 +77,13 @@ async def main() -> None:
     with st.popover(":material/settings: Settings", use_container_width=True):
         m = st.radio("LLM to use", options=models.keys())
         model = models[m]
+        agent_client.agent = st.selectbox(
+            "Agent to use",
+            options=[
+                "research-assistant",
+                "chatbot",
+            ],
+        )
         use_streaming = st.toggle("Stream results", value=True)
 
     @st.dialog("Architecture")
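
Usage sketch (illustrative, not part of the patch): an agent node can stream an
intermediate payload to clients through the new helper. The node name
notify_progress, the state shape, and the payload keys below are hypothetical;
only CustomData and adispatch() come from the new src/agents/utils.py, and the
agents.utils import path follows the comment added in src/service/service.py.

    from langchain_core.runnables import RunnableConfig

    from agents.utils import CustomData

    async def notify_progress(state: dict, config: RunnableConfig) -> dict:
        # Dispatch an intermediate payload mid-run. message_generator() in
        # src/service/service.py receives it as an "on_custom_event" event
        # tagged "custom_data_dispatch" and streams it to the client like
        # any other message.
        progress = CustomData(data={"step": "search", "status": "running"})
        await progress.adispatch(config)
        return state

Passing the node's RunnableConfig through to adispatch() keeps the dispatched
event attached to the current run; on Python 3.10 and earlier this is required,
since the config is not propagated to adispatch_custom_event automatically.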