Skip to content

Commit

Permalink
Refactor TaskData draw (#93)
Browse files Browse the repository at this point in the history
* First attempt to extract TaskData logic

* Simplify / cleanup

* Clean up
  • Loading branch information
JoshuaC215 authored Nov 17, 2024
1 parent 66c1575 commit 5f27dc5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 44 deletions.
39 changes: 39 additions & 0 deletions src/schema/task_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,42 @@ def completed(self) -> bool:

def completed_with_error(self) -> bool:
return self.state == "complete" and self.result == "error"


class TaskDataStatus:
def __init__(self) -> None:
import streamlit as st

self.status = st.status("")
self.current_task_data: dict[str, TaskData] = {}

def add_and_draw_task_data(self, task_data: TaskData) -> None:
status = self.status
status_str = f"Task **{task_data.name}** "
match task_data.state:
case "new":
status_str += "has :blue[started]. Input:"
case "running":
status_str += "wrote:"
case "complete":
if task_data.result == "success":
status_str += ":green[completed successfully]. Output:"
else:
status_str += ":red[ended with error]. Output:"
status.write(status_str)
status.write(task_data.data)
status.write("---")
if task_data.run_id not in self.current_task_data:
# Status label always shows the last newly started task
status.update(label=f"""Task: {task_data.name}""")
self.current_task_data[task_data.run_id] = task_data
# Status is "running" until all tasks have completed
if not any(entry.completed() for entry in self.current_task_data.values()):
state = "running"
# Status is "error" if any task has errored
elif any(entry.completed_with_error() for entry in self.current_task_data.values()):
state = "error"
# Status is "complete" if all tasks have completed successfully
else:
state = "complete"
status.update(state=state)
59 changes: 15 additions & 44 deletions src/streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from collections.abc import AsyncGenerator

import streamlit as st
from pydantic import ValidationError
from streamlit.runtime.scriptrunner import get_script_run_ctx

from client import AgentClient
from schema import ChatHistory, ChatMessage
from schema.task_data import TaskData
from schema.task_data import TaskData, TaskDataStatus

# A Streamlit app for interacting with the langgraph agent via a simple chat interface.
# The app has three main functions which are all run async:
Expand Down Expand Up @@ -270,59 +271,29 @@ async def draw_messages(
status.update(state="complete")

case "custom":
# This is an implementation of the TaskData example for CustomData.
# An agent can write a CustomData object to the message stream, and
# it's passed to the client for rendering. To see this in practice,
# run the app with the `bg-task-agent` agent.
# CustomData example used by the bg-task-agent
# See:
# - src/agents/utils.py CustomData
# - src/agents/bg_task_agent/task.py
try:
task_data: TaskData = TaskData.model_validate(msg.custom_data)
except ValidationError:
st.error("Unexpected CustomData message received from agent")
st.write(msg.custom_data)
st.stop()

# This is provided as an example, you may want to write your own
# CustomData types and handlers. This section will be skipped for
# any other agents that don't send CustomData.
task_data = TaskData.model_validate(msg.custom_data)

# If we're rendering new messages, store the message in session state
if is_new:
st.session_state.messages.append(msg)

# If the last message type was not Task, create a new chat message
# and container for task messages
if last_message_type != "task":
last_message_type = "task"
st.session_state.last_message = st.chat_message(
name="task", avatar=":material/manufacturing:"
)
with st.session_state.last_message:
status = st.status("")
current_task_data: dict[str, TaskData] = {}

status_str = f"Task **{task_data.name}** "
match task_data.state:
case "new":
status_str += "has :blue[started]. Input:"
case "running":
status_str += "wrote:"
case "complete":
if task_data.result == "success":
status_str += ":green[completed successfully]. Output:"
else:
status_str += ":red[ended with error]. Output:"
status.write(status_str)
status.write(task_data.data)
status.write("---")
if task_data.run_id not in current_task_data:
# Status label always shows the last newly started task
status.update(label=f"""Task: {task_data.name}""")
current_task_data[task_data.run_id] = task_data
# Status is "running" until all tasks have completed
if not any(entry.completed() for entry in current_task_data.values()):
state = "running"
# Status is "error" if any task has errored
elif any(entry.completed_with_error() for entry in current_task_data.values()):
state = "error"
# Status is "complete" if all tasks have completed successfully
else:
state = "complete"
status.update(state=state)
status = TaskDataStatus()

status.add_and_draw_task_data(task_data)

# In case of an unexpected message type, log an error and stop
case _:
Expand Down

0 comments on commit 5f27dc5

Please sign in to comment.