Skip to content

Commit

Permalink
Add a BG Task Agent to show CustomData usage (#81)
Browse files Browse the repository at this point in the history
* Implement bg-task-agent

* fix bg tasks display logic (#78)

* fix bg tasks display logic

* Bug fixes and cleanup

* support parallel task runs

* Clean up logic and docs for the TaskData handling

---------

Co-authored-by: ANASOFT\keppert <[email protected]>
Co-authored-by: Joshua Carroll <[email protected]>

---------

Co-authored-by: peterkeppert <[email protected]>
Co-authored-by: ANASOFT\keppert <[email protected]>
  • Loading branch information
3 people authored Nov 4, 2024
1 parent acebd43 commit 8e5d219
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/agents/agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from langgraph.graph.state import CompiledStateGraph

from agents.bg_task_agent.bg_task_agent import bg_task_agent
from agents.chatbot import chatbot
from agents.research_assistant import research_assistant

Expand All @@ -9,4 +10,5 @@
agents: dict[str, CompiledStateGraph] = {
"chatbot": chatbot,
"research-assistant": research_assistant,
"bg-task-agent": bg_task_agent,
}
64 changes: 64 additions & 0 deletions src/agents/bg_task_agent/bg_task_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import asyncio

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph

from agents.bg_task_agent.task import Task
from agents.models import models


class AgentState(MessagesState, total=False):
"""`total=False` is PEP589 specs.
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
"""


def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
preprocessor = RunnableLambda(
lambda state: state["messages"],
name="StateModifier",
)
return preprocessor | model


async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
m = models[config["configurable"].get("model", "gpt-4o-mini")]
model_runnable = wrap_model(m)
response = await model_runnable.ainvoke(state, config)

# We return a list, because this will get added to the existing list
return {"messages": [response]}


async def bg_task(state: AgentState, config: RunnableConfig) -> AgentState:
task1 = Task("Simple task 1...")
task2 = Task("Simple task 2...")

await task1.start(config=config)
await asyncio.sleep(2)
await task2.start(config=config)
await asyncio.sleep(2)
await task1.write_data(config=config, data={"status": "Still running..."})
await asyncio.sleep(2)
await task2.finish(result="error", config=config, data={"output": 42})
await asyncio.sleep(2)
await task1.finish(result="success", config=config, data={"output": 42})
return {"messages": []}


# Define the graph
agent = StateGraph(AgentState)
agent.add_node("model", acall_model)
agent.add_node("bg_task", bg_task)
agent.set_entry_point("bg_task")

agent.add_edge("bg_task", "model")
agent.add_edge("model", END)

bg_task_agent = agent.compile(
checkpointer=MemorySaver(),
)
47 changes: 47 additions & 0 deletions src/agents/bg_task_agent/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Literal
from uuid import uuid4

from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig

from agents.utils import CustomData
from schema.task_data import TaskData


class Task:
def __init__(self, task_name: str) -> None:
self.name = task_name
self.id = str(uuid4())
self.state: Literal["new", "running", "complete"] = "new"
self.result: Literal["success", "error"] | None = None

async def _generate_and_dispatch_message(self, config: RunnableConfig, data: dict):
task_data = TaskData(name=self.name, run_id=self.id, state=self.state, data=data)
if self.result:
task_data.result = self.result
task_custom_data = CustomData(
type=self.name,
data=task_data.model_dump(),
)
await task_custom_data.adispatch()
return task_custom_data.to_langchain()

async def start(self, config: RunnableConfig, data: dict = {}) -> BaseMessage:
self.state = "new"
task_message = await self._generate_and_dispatch_message(config, data)
return task_message

async def write_data(self, config: RunnableConfig, data: dict) -> BaseMessage:
if self.state == "complete":
raise ValueError("Only incomplete tasks can output data.")
self.state = "running"
task_message = await self._generate_and_dispatch_message(config, data)
return task_message

async def finish(
self, result: Literal["success", "error"], config: RunnableConfig, data: dict = {}
) -> BaseMessage:
self.state = "complete"
self.result = result
task_message = await self._generate_and_dispatch_message(config, data)
return task_message
34 changes: 34 additions & 0 deletions src/schema/task_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any, Literal

from pydantic import BaseModel, Field


class TaskData(BaseModel):
name: str | None = Field(
description="Name of the task.", default=None, examples=["Check input safety"]
)
run_id: str = Field(
description="ID of the task run to pair state updates to.",
default="",
examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
)
state: Literal["new", "running", "complete"] | None = Field(
description="Current state of given task instance.",
default=None,
examples=["running"],
)
result: Literal["success", "error"] | None = Field(
description="Result of given task instance.",
default=None,
examples=["running"],
)
data: dict[str, Any] = Field(
description="Additional data generated by the task.",
default={},
)

def completed(self) -> bool:
return self.state == "complete"

def completed_with_error(self) -> bool:
return self.state == "complete" and self.result == "error"
57 changes: 57 additions & 0 deletions src/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from client import AgentClient
from schema import ChatHistory, ChatMessage
from schema.task_data import TaskData

# A Streamlit app for interacting with the langgraph agent via a simple chat interface.
# The app has three main functions which are all run async:
Expand Down Expand Up @@ -82,6 +83,7 @@ async def main() -> None:
options=[
"research-assistant",
"chatbot",
"bg-task-agent",
],
)
use_streaming = st.toggle("Stream results", value=True)
Expand Down Expand Up @@ -266,6 +268,61 @@ async def draw_messages(
status.write(tool_result.content)
status.update(state="complete")

case "custom":
# This is an implementation of the TaskData example for CustomData.
# An agent can write a CustomData object to the message stream, and
# it's passed to the client for rendering. To see this in practice,
# run the app with the `bg-task-agent` agent.

# This is provided as an example, you may want to write your own
# CustomData types and handlers. This section will be skipped for
# any other agents that don't send CustomData.
task_data = TaskData.model_validate(msg.custom_data)

# If we're rendering new messages, store the message in session state
if is_new:
st.session_state.messages.append(msg)

# If the last message type was not Task, create a new chat message
# and container for task messages
if last_message_type != "task":
last_message_type = "task"
st.session_state.last_message = st.chat_message(
name="task", avatar=":material/manufacturing:"
)
with st.session_state.last_message:
status = st.status("")
current_task_data: dict[str, TaskData] = {}

status_str = f"Task **{task_data.name}** "
match task_data.state:
case "new":
status_str += "has :blue[started]. Input:"
case "running":
status_str += "wrote:"
case "complete":
if task_data.result == "success":
status_str += ":green[completed successfully]. Output:"
else:
status_str += ":red[ended with error]. Output:"
status.write(status_str)
status.write(task_data.data)
status.write("---")
if task_data.run_id not in current_task_data:
# Status label always shows the last newly started task
status.update(label=f"""Task: {task_data.name}""")
current_task_data[task_data.run_id] = task_data
# Status is "running" until all tasks have completed
if not any(entry.completed() for entry in current_task_data.values()):
state = "running"
# Status is "error" if any task has errored
elif any(entry.completed_with_error() for entry in current_task_data.values()):
state = "error"
# Status is "complete" if all tasks have completed successfully
else:
state = "complete"
status.update(state=state)

# In case of an unexpected message type, log an error and stop
case _:
st.error(f"Unexpected ChatMessage type: {msg.type}")
Expand Down

0 comments on commit 8e5d219

Please sign in to comment.