-
Notifications
You must be signed in to change notification settings - Fork 222
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a BG Task Agent to show CustomData usage (#81)
* Implement bg-task-agent * fix bg tasks display logic (#78) * fix bg tasks display logic * Bug fixes and cleanup * support parallel task runs * Clean up logic and docs for the TaskData handling --------- Co-authored-by: ANASOFT\keppert <[email protected]> Co-authored-by: Joshua Carroll <[email protected]> --------- Co-authored-by: peterkeppert <[email protected]> Co-authored-by: ANASOFT\keppert <[email protected]>
- Loading branch information
1 parent
acebd43
commit 8e5d219
Showing
5 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import asyncio | ||
|
||
from langchain_core.language_models.chat_models import BaseChatModel | ||
from langchain_core.messages import AIMessage | ||
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable | ||
from langgraph.checkpoint.memory import MemorySaver | ||
from langgraph.graph import END, MessagesState, StateGraph | ||
|
||
from agents.bg_task_agent.task import Task | ||
from agents.models import models | ||
|
||
|
||
class AgentState(MessagesState, total=False): | ||
"""`total=False` is PEP589 specs. | ||
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality | ||
""" | ||
|
||
|
||
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]: | ||
preprocessor = RunnableLambda( | ||
lambda state: state["messages"], | ||
name="StateModifier", | ||
) | ||
return preprocessor | model | ||
|
||
|
||
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: | ||
m = models[config["configurable"].get("model", "gpt-4o-mini")] | ||
model_runnable = wrap_model(m) | ||
response = await model_runnable.ainvoke(state, config) | ||
|
||
# We return a list, because this will get added to the existing list | ||
return {"messages": [response]} | ||
|
||
|
||
async def bg_task(state: AgentState, config: RunnableConfig) -> AgentState: | ||
task1 = Task("Simple task 1...") | ||
task2 = Task("Simple task 2...") | ||
|
||
await task1.start(config=config) | ||
await asyncio.sleep(2) | ||
await task2.start(config=config) | ||
await asyncio.sleep(2) | ||
await task1.write_data(config=config, data={"status": "Still running..."}) | ||
await asyncio.sleep(2) | ||
await task2.finish(result="error", config=config, data={"output": 42}) | ||
await asyncio.sleep(2) | ||
await task1.finish(result="success", config=config, data={"output": 42}) | ||
return {"messages": []} | ||
|
||
|
||
# Define the graph | ||
agent = StateGraph(AgentState) | ||
agent.add_node("model", acall_model) | ||
agent.add_node("bg_task", bg_task) | ||
agent.set_entry_point("bg_task") | ||
|
||
agent.add_edge("bg_task", "model") | ||
agent.add_edge("model", END) | ||
|
||
bg_task_agent = agent.compile( | ||
checkpointer=MemorySaver(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Literal | ||
from uuid import uuid4 | ||
|
||
from langchain_core.messages import BaseMessage | ||
from langchain_core.runnables import RunnableConfig | ||
|
||
from agents.utils import CustomData | ||
from schema.task_data import TaskData | ||
|
||
|
||
class Task: | ||
def __init__(self, task_name: str) -> None: | ||
self.name = task_name | ||
self.id = str(uuid4()) | ||
self.state: Literal["new", "running", "complete"] = "new" | ||
self.result: Literal["success", "error"] | None = None | ||
|
||
async def _generate_and_dispatch_message(self, config: RunnableConfig, data: dict): | ||
task_data = TaskData(name=self.name, run_id=self.id, state=self.state, data=data) | ||
if self.result: | ||
task_data.result = self.result | ||
task_custom_data = CustomData( | ||
type=self.name, | ||
data=task_data.model_dump(), | ||
) | ||
await task_custom_data.adispatch() | ||
return task_custom_data.to_langchain() | ||
|
||
async def start(self, config: RunnableConfig, data: dict = {}) -> BaseMessage: | ||
self.state = "new" | ||
task_message = await self._generate_and_dispatch_message(config, data) | ||
return task_message | ||
|
||
async def write_data(self, config: RunnableConfig, data: dict) -> BaseMessage: | ||
if self.state == "complete": | ||
raise ValueError("Only incomplete tasks can output data.") | ||
self.state = "running" | ||
task_message = await self._generate_and_dispatch_message(config, data) | ||
return task_message | ||
|
||
async def finish( | ||
self, result: Literal["success", "error"], config: RunnableConfig, data: dict = {} | ||
) -> BaseMessage: | ||
self.state = "complete" | ||
self.result = result | ||
task_message = await self._generate_and_dispatch_message(config, data) | ||
return task_message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Any, Literal | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class TaskData(BaseModel): | ||
name: str | None = Field( | ||
description="Name of the task.", default=None, examples=["Check input safety"] | ||
) | ||
run_id: str = Field( | ||
description="ID of the task run to pair state updates to.", | ||
default="", | ||
examples=["847c6285-8fc9-4560-a83f-4e6285809254"], | ||
) | ||
state: Literal["new", "running", "complete"] | None = Field( | ||
description="Current state of given task instance.", | ||
default=None, | ||
examples=["running"], | ||
) | ||
result: Literal["success", "error"] | None = Field( | ||
description="Result of given task instance.", | ||
default=None, | ||
examples=["running"], | ||
) | ||
data: dict[str, Any] = Field( | ||
description="Additional data generated by the task.", | ||
default={}, | ||
) | ||
|
||
def completed(self) -> bool: | ||
return self.state == "complete" | ||
|
||
def completed_with_error(self) -> bool: | ||
return self.state == "complete" and self.result == "error" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters