Implement CustomData for agents; select agent in app (#77)
JoshuaC215 authored Oct 31, 2024
1 parent cb450b5 commit acebd43
Showing 4 changed files with 55 additions and 20 deletions.
30 changes: 30 additions & 0 deletions src/agents/utils.py
@@ -0,0 +1,30 @@
from typing import Any

from langchain_core.callbacks import adispatch_custom_event
from langchain_core.messages import ChatMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import merge_configs
from pydantic import BaseModel, Field


class CustomData(BaseModel):
"Custom data being sent by an agent"

type: str = Field(
description="The type of custom data, used in dispatch events",
default="custom_data",
)
data: dict[str, Any] = Field(description="The custom data")

def to_langchain(self) -> ChatMessage:
return ChatMessage(content=[self.data], role="custom")

async def adispatch(self, config: RunnableConfig | None = None) -> None:
dispatch_config = RunnableConfig(
tags=["custom_data_dispatch"],
)
await adispatch_custom_event(
name=self.type,
data=self.to_langchain(),
config=merge_configs(config, dispatch_config),
)
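
For orientation, here is a hedged sketch of how an agent node might call this helper. Only `CustomData`, its `adispatch()` signature, and the `agents.utils` import path come from this commit; the node name, state shape, and the "progress" type/payload are illustrative assumptions.

```python
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig

from agents.utils import CustomData


async def research_node(state: dict, config: RunnableConfig) -> dict:
    # Hypothetical node: surface an intermediate status update to the client
    # before the slow work runs. The dispatched event carries the
    # "custom_data_dispatch" tag, which is how the service recognizes it.
    await CustomData(
        type="progress",  # illustrative type; the field defaults to "custom_data"
        data={"status": "searching", "step": 1},
    ).adispatch(config)

    # ... real agent work would happen here ...
    return {"messages": [AIMessage(content="Here is what I found.")]}
```

Events dispatched this way are picked up by the new `on_custom_event` branch in `src/service/service.py` below and forwarded to the client as `custom`-role chat messages.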
8 changes: 0 additions & 8 deletions src/schema/schema.py
@@ -78,14 +78,6 @@ class ChatMessage(BaseModel):
        default={},
    )

    @classmethod
    def from_custom_data(cls, data: dict[str, Any]) -> "ChatMessage":
        return cls(
            type="custom",
            content="",
            custom_data=data,
        )

    def pretty_repr(self) -> str:
        """Get a pretty representation of the message."""
        base_title = self.type.title() + " Message"
30 changes: 18 additions & 12 deletions src/service/service.py
@@ -128,6 +128,7 @@ async def message_generator(
        if not event:
            continue

        new_messages = []
        # Yield messages written to the graph state after node execution finishes.
        if (
            event["event"] == "on_chain_end"
@@ -137,18 +137,23 @@
and "messages" in event["data"]["output"]
):
new_messages = event["data"]["output"]["messages"]
for message in new_messages:
try:
chat_message = langchain_to_chat_message(message)
chat_message.run_id = str(run_id)
except Exception as e:
logger.error(f"Error parsing message: {e}")
yield f"data: {json.dumps({'type': 'error', 'content': 'Unexpected error'})}\n\n"
continue
# LangGraph re-sends the input message, which feels weird, so drop it
if chat_message.type == "human" and chat_message.content == user_input.message:
continue
yield f"data: {json.dumps({'type': 'message', 'content': chat_message.model_dump()})}\n\n"

# Also yield intermediate messages from agents.utils.CustomData.adispatch().
if event["event"] == "on_custom_event" and "custom_data_dispatch" in event.get("tags", []):
new_messages = [event["data"]]

for message in new_messages:
try:
chat_message = langchain_to_chat_message(message)
chat_message.run_id = str(run_id)
except Exception as e:
logger.error(f"Error parsing message: {e}")
yield f"data: {json.dumps({'type': 'error', 'content': 'Unexpected error'})}\n\n"
continue
# LangGraph re-sends the input message, which feels weird, so drop it
if chat_message.type == "human" and chat_message.content == user_input.message:
continue
yield f"data: {json.dumps({'type': 'message', 'content': chat_message.model_dump()})}\n\n"

# Yield tokens streamed from LLMs.
if (
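
To see the effect end to end, a minimal consumer sketch follows. The `data: {'type': ..., 'content': ...}` envelope and the `custom` message type come from the handler above; the service URL, route, and request body are assumptions about the deployment, not part of this commit.

```python
import asyncio
import json

import httpx


async def consume_stream(message: str) -> None:
    # Assumed base URL and route; adjust to the actual service configuration.
    url = "http://localhost:8080/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json={"message": message}) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                try:
                    payload = json.loads(line[len("data: "):])
                except json.JSONDecodeError:
                    continue  # skip any non-JSON terminator lines
                if payload["type"] == "message":
                    msg = payload["content"]
                    # Messages produced via CustomData.adispatch() arrive with
                    # type "custom"; normal agent replies arrive as "ai".
                    print(f"[{msg['type']}] {msg['content']}")
                elif payload["type"] == "error":
                    print("error:", payload["content"])


if __name__ == "__main__":
    asyncio.run(consume_stream("What's the weather today?"))
```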
7 changes: 7 additions & 0 deletions src/streamlit_app.py
@@ -77,6 +77,13 @@ async def main() -> None:
    with st.popover(":material/settings: Settings", use_container_width=True):
        m = st.radio("LLM to use", options=models.keys())
        model = models[m]
        agent_client.agent = st.selectbox(
            "Agent to use",
            options=[
                "research-assistant",
                "chatbot",
            ],
        )
        use_streaming = st.toggle("Stream results", value=True)

    @st.dialog("Architecture")
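
The selectbox sets `agent_client.agent`, letting the app target a specific agent per request. The client class itself is not part of this diff, so the sketch below is only a rough stand-in: the base URL, route layout, and request body are assumptions.

```python
from dataclasses import dataclass

import httpx


@dataclass
class AgentClientSketch:
    """Stand-in for the app's agent_client; not the real implementation."""

    base_url: str = "http://localhost:8080"
    agent: str = "research-assistant"  # overwritten by the Streamlit selectbox

    def invoke(self, message: str) -> dict:
        # Hypothetical per-agent route; the real service API may differ.
        response = httpx.post(f"{self.base_url}/{self.agent}/invoke", json={"message": message})
        response.raise_for_status()
        return response.json()


client = AgentClientSketch()
client.agent = "chatbot"  # mirrors agent_client.agent = st.selectbox(...) above
result = client.invoke("Hello!")
```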
