Add support for calling multiple agents on one service (#75)
* Support multiple agents (still only one in service)

* Get the routes working, update client + tests

* add TODO

* Clean up basic agent call, update README
JoshuaC215 authored Oct 31, 2024
1 parent b5eec80 commit cb450b5
Showing 18 changed files with 227 additions and 109 deletions.
18 changes: 11 additions & 7 deletions README.md
@@ -52,6 +52,7 @@ docker compose watch
1. **Advanced Streaming**: A novel approach to support both token-based and message-based streaming.
1. **Content Moderation**: Implements LlamaGuard for content moderation (requires Groq API key).
1. **Streamlit Interface**: Provides a user-friendly chat interface for interacting with the agent.
1. **Multiple Agent Support**: Run multiple agents in the service and call them by URL path.
1. **Asynchronous Design**: Utilizes async/await for efficient handling of concurrent requests.
1. **Feedback Mechanism**: Includes a star-based feedback system integrated with LangSmith.
1. **Docker Support**: Includes Dockerfiles and a docker compose file for easy development and deployment.
@@ -60,10 +61,12 @@ docker compose watch

The repository is structured as follows:

- `src/agent/research_assistant.py`: Defines the LangGraph agent
- `src/agent/llama_guard.py`: Defines the LlamaGuard content moderation
- `src/schema/schema.py`: Defines the service schema
- `src/service/service.py`: FastAPI service to serve the agent
- `src/agents/research_assistant.py`: Defines the main LangGraph agent
- `src/agents/llama_guard.py`: Defines the LlamaGuard content moderation
- `src/agents/models.py`: Configures available models based on ENV
- `src/agents/agents.py`: Mapping of all agents provided by the service
- `src/schema/schema.py`: Defines the protocol schema
- `src/service/service.py`: FastAPI service to serve the agents
- `src/client/client.py`: Client to interact with the agent service
- `src/streamlit_app.py`: Streamlit app providing a chat interface

@@ -208,8 +211,9 @@ Currently the tests need to be run using the local development without Docker se

To customize the agent for your own use case:

1. Modify the `src/agent/research_assistant.py` file to change the agent's behavior and tools. Or, build a new agent from scratch.
2. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities.
1. Add your new agent to the `src/agents` directory. You can copy `research_assistant.py` or `chatbot.py` and modify it to change the agent's behavior and tools.
1. Import and add your new agent to the `agents` dictionary in `src/agents/agents.py`. Your agent can then be called at `/<your_agent_name>/invoke` or `/<your_agent_name>/stream` (see the sketch after this list).
1. Adjust the Streamlit interface in `src/streamlit_app.py` to match your agent's capabilities.
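
For illustration, here is a minimal sketch of step 2, where `my_agent` is a hypothetical graph module created in step 1:

```python
# src/agents/agents.py -- sketch only; `my_agent` is a hypothetical example
from langgraph.graph.state import CompiledStateGraph

from agents.chatbot import chatbot
from agents.my_agent import my_agent  # hypothetical module from step 1
from agents.research_assistant import research_assistant

DEFAULT_AGENT = "research-assistant"

agents: dict[str, CompiledStateGraph] = {
    "chatbot": chatbot,
    "my-agent": my_agent,  # served at /my-agent/invoke and /my-agent/stream
    "research-assistant": research_assistant,
}
```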

## Building other apps on the AgentClient

@@ -239,7 +243,7 @@ Contributions are welcome! Please feel free to submit a Pull Request.
- [x] Get LlamaGuard working for content moderation (anyone know a reliable and fast hosted version?)
- [x] Add more sophisticated tools for the research assistant
- [x] Increase test coverage and add CI pipeline
- [ ] Add support for multiple agents running on the same service, including non-chat agent
- [x] Add support for multiple agents running on the same service, including non-chat agent
- [ ] Deployment instructions and configuration for cloud providers
- [ ] More ideas? File an issue or create a discussion!

4 changes: 2 additions & 2 deletions compose.yaml
@@ -9,9 +9,9 @@ services:
- .env
develop:
watch:
- path: src/agent/
- path: src/agents/
action: sync+restart
target: /app/agent/
target: /app/agents/
- path: src/schema/
action: sync+restart
target: /app/schema/
2 changes: 1 addition & 1 deletion docker/Dockerfile.service
@@ -10,7 +10,7 @@ COPY uv.lock .
RUN pip install --no-cache-dir uv
RUN uv sync --frozen --no-install-project --no-dev

COPY src/agent/ ./agent/
COPY src/agents/ ./agents/
COPY src/schema/ ./schema/
COPY src/service/ ./service/
COPY src/run_service.py .
2 changes: 1 addition & 1 deletion langgraph.json
@@ -2,7 +2,7 @@
"python_version": "3.12",
"dependencies": ["."],
"graphs": {
"research_assistant": "./src/agent/research_assistant.py:research_assistant"
"research_assistant": "./src/agents/research_assistant.py:research_assistant"
},
"env": "./.env"
}
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -53,5 +53,10 @@ target-version = "py310"
[tool.ruff.lint]
extend-select = ["I", "U"]

[tool.pytest.ini_options]
pythonpath = [
"src"
]

[tool.pytest_env]
OPENAI_API_KEY = "sk-fake-openai-key"
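
The new `pythonpath` setting lets pytest import the `agents`, `schema`, `service`, and `client` packages straight from `src` without installing the project, and the fake `OPENAI_API_KEY` ensures the model registry in `src/agents/models.py` is non-empty when tests import the service.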
3 changes: 0 additions & 3 deletions src/agent/__init__.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/agents/__init__.py
@@ -0,0 +1,3 @@
from agents.agents import DEFAULT_AGENT, agents

__all__ = ["agents", "DEFAULT_AGENT"]
12 changes: 12 additions & 0 deletions src/agents/agents.py
@@ -0,0 +1,12 @@
from langgraph.graph.state import CompiledStateGraph

from agents.chatbot import chatbot
from agents.research_assistant import research_assistant

DEFAULT_AGENT = "research-assistant"


agents: dict[str, CompiledStateGraph] = {
"chatbot": chatbot,
"research-assistant": research_assistant,
}
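
The updated `src/service/service.py` does not appear in this diff, so the following is only a hypothetical sketch of how the service could dispatch on the URL path using this mapping; the route shape is inferred from the `/{agent}/invoke` paths used by the client changes below:

```python
# Hypothetical dispatch sketch -- not the actual src/service/service.py
from uuid import uuid4

from fastapi import FastAPI, HTTPException

from agents import agents

app = FastAPI()


@app.post("/{agent_name}/invoke")
async def invoke(agent_name: str, user_input: dict) -> dict:
    agent = agents.get(agent_name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Unknown agent: {agent_name}")
    # Every agent is a CompiledStateGraph compiled with a MemorySaver checkpointer,
    # so a thread_id must be supplied in the config.
    result = await agent.ainvoke(
        {"messages": [("user", user_input["message"])]},
        config={"configurable": {"thread_id": str(uuid4())}},
    )
    return {"content": result["messages"][-1].content}
```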
44 changes: 44 additions & 0 deletions src/agents/chatbot.py
@@ -0,0 +1,44 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph

from agents.models import models


class AgentState(MessagesState, total=False):
"""`total=False` is PEP589 specs.
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
"""


def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
preprocessor = RunnableLambda(
lambda state: state["messages"],
name="StateModifier",
)
return preprocessor | model


async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
m = models[config["configurable"].get("model", "gpt-4o-mini")]
model_runnable = wrap_model(m)
response = await model_runnable.ainvoke(state, config)

# We return a list, because this will get added to the existing list
return {"messages": [response]}


# Define the graph
agent = StateGraph(AgentState)
agent.add_node("model", acall_model)
agent.set_entry_point("model")

# A single-node graph: this agent always ENDs after the model responds
agent.add_edge("model", END)

chatbot = agent.compile(
checkpointer=MemorySaver(),
)
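
A minimal usage sketch for the new chatbot graph, assuming `OPENAI_API_KEY` is set so the default `gpt-4o-mini` model is available:

```python
# Sketch: invoke the chatbot graph directly, mirroring src/run_agent.py below
import asyncio
from uuid import uuid4

from dotenv import load_dotenv

load_dotenv()

from agents.chatbot import chatbot  # noqa: E402


async def demo() -> None:
    result = await chatbot.ainvoke(
        {"messages": [("user", "Hi there!")]},
        config={"configurable": {"thread_id": str(uuid4()), "model": "gpt-4o-mini"}},
    )
    result["messages"][-1].pretty_print()


asyncio.run(demo())
```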
File renamed without changes.
29 changes: 29 additions & 0 deletions src/agents/models.py
@@ -0,0 +1,29 @@
import os

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI

# NOTE: models with streaming=True will send tokens as they are generated
# if the /stream endpoint is called with stream_tokens=True (the default)
models: dict[str, BaseChatModel] = {}
if os.getenv("OPENAI_API_KEY") is not None:
models["gpt-4o-mini"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True)
if os.getenv("GROQ_API_KEY") is not None:
models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5)
if os.getenv("GOOGLE_API_KEY") is not None:
models["gemini-1.5-flash"] = ChatGoogleGenerativeAI(
model="gemini-1.5-flash", temperature=0.5, streaming=True
)
if os.getenv("ANTHROPIC_API_KEY") is not None:
models["claude-3-haiku"] = ChatAnthropic(
model="claude-3-haiku-20240307", temperature=0.5, streaming=True
)

if not models:
print("No LLM available. Please set API keys to enable at least one LLM.")
if os.getenv("MODE") == "dev":
print("FastAPI initialized failed. Please use Ctrl + C to exit uvicorn.")
exit(1)
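
Because the registry is populated at import time from environment variables, the set of available models varies by deployment. A quick check, assuming at least one provider key is set in `.env`:

```python
# Sketch: list whichever models the environment enabled
from dotenv import load_dotenv

load_dotenv()

from agents.models import models  # noqa: E402

print(sorted(models))  # e.g. ['gpt-4o-mini', 'llama-3.1-70b'], depending on keys
```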
src/{agent → agents}/research_assistant.py
@@ -2,21 +2,18 @@
from datetime import datetime
from typing import Literal

from langchain_anthropic import ChatAnthropic
from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.managed import IsLastStep
from langgraph.prebuilt import ToolNode

from agent.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
from agent.tools import calculator
from agents.llama_guard import LlamaGuard, LlamaGuardOutput, SafetyAssessment
from agents.models import models
from agents.tools import calculator


class AgentState(MessagesState, total=False):
@@ -29,29 +26,6 @@ class AgentState(MessagesState, total=False):
is_last_step: IsLastStep


# NOTE: models with streaming=True will send tokens as they are generated
# if the /stream endpoint is called with stream_tokens=True (the default)
models: dict[str, BaseChatModel] = {}
if os.getenv("OPENAI_API_KEY") is not None:
models["gpt-4o-mini"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True)
if os.getenv("GROQ_API_KEY") is not None:
models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5)
if os.getenv("GOOGLE_API_KEY") is not None:
models["gemini-1.5-flash"] = ChatGoogleGenerativeAI(
model="gemini-1.5-flash", temperature=0.5, streaming=True
)
if os.getenv("ANTHROPIC_API_KEY") is not None:
models["claude-3-haiku"] = ChatAnthropic(
model="claude-3-haiku-20240307", temperature=0.5, streaming=True
)

if not models:
print("No LLM available. Please set API keys to enable at least one LLM.")
if os.getenv("MODE") == "dev":
print("FastAPI initialized failed. Please use Ctrl + C to exit uvicorn.")
exit(1)


web_search = DuckDuckGoSearchResults(name="WebSearch")
tools = [web_search, calculator]

@@ -171,31 +145,3 @@ def pending_tool_calls(state: AgentState) -> Literal["tools", "done"]:
research_assistant = agent.compile(
checkpointer=MemorySaver(),
)


if __name__ == "__main__":
import asyncio
from uuid import uuid4

from dotenv import load_dotenv

load_dotenv()

async def main() -> None:
inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]}
result = await research_assistant.ainvoke(
inputs,
config=RunnableConfig(configurable={"thread_id": uuid4()}),
)
result["messages"][-1].pretty_print()

# Draw the agent graph as png
# requires:
# brew install graphviz
# export CFLAGS="-I $(brew --prefix graphviz)/include"
# export LDFLAGS="-L $(brew --prefix graphviz)/lib"
# pip install pygraphviz
#
# research_assistant.get_graph().draw_png("agent_diagram.png")

asyncio.run(main())
File renamed without changes.
16 changes: 11 additions & 5 deletions src/client/client.py
@@ -11,14 +11,20 @@
class AgentClient:
"""Client for interacting with the agent service."""

def __init__(self, base_url: str = "http://localhost:80", timeout: float | None = None) -> None:
def __init__(
self,
base_url: str = "http://localhost:80",
agent: str = "research-assistant",
timeout: float | None = None,
) -> None:
"""
Initialize the client.
Args:
            base_url (str): The base URL of the agent service.
            agent (str): The agent to call; defaults to "research-assistant".
            timeout (float | None): Optional timeout for requests.
"""
self.base_url = base_url
self.agent = agent
self.auth_secret = os.getenv("AUTH_SECRET")
self.timeout = timeout

@@ -50,7 +56,7 @@ async def ainvoke(
request.model = model
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/invoke",
f"{self.base_url}/{self.agent}/invoke",
json=request.model_dump(),
headers=self._headers,
timeout=self.timeout,
@@ -79,7 +85,7 @@ def invoke(
if model:
request.model = model
response = httpx.post(
f"{self.base_url}/invoke",
f"{self.base_url}/{self.agent}/invoke",
json=request.model_dump(),
headers=self._headers,
timeout=self.timeout,
@@ -143,7 +149,7 @@ def stream(
request.model = model
with httpx.stream(
"POST",
f"{self.base_url}/stream",
f"{self.base_url}/{self.agent}/stream",
json=request.model_dump(),
headers=self._headers,
timeout=self.timeout,
@@ -189,7 +195,7 @@ async def astream(
async with httpx.AsyncClient() as client:
async with client.stream(
"POST",
f"{self.base_url}/stream",
f"{self.base_url}/{self.agent}/stream",
json=request.model_dump(),
headers=self._headers,
timeout=self.timeout,
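
With these changes, every request is routed to the agent chosen at construction time. A minimal usage sketch, assuming the client package re-exports `AgentClient` and that `invoke` still accepts a message string and returns a chat message as in the existing client:

```python
# Sketch: call a specific agent through the updated client
from client import AgentClient  # assumes the package re-exports AgentClient

chatbot_client = AgentClient(base_url="http://localhost:80", agent="chatbot")
response = chatbot_client.invoke("Tell me a joke")  # POSTs to /chatbot/invoke
print(response.content)
```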
32 changes: 32 additions & 0 deletions src/run_agent.py
@@ -0,0 +1,32 @@
import asyncio
from uuid import uuid4

from dotenv import load_dotenv
from langchain_core.runnables import RunnableConfig

load_dotenv()

from agents import DEFAULT_AGENT, agents # noqa: E402

agent = agents[DEFAULT_AGENT]


async def main() -> None:
inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]}
result = await agent.ainvoke(
inputs,
config=RunnableConfig(configurable={"thread_id": uuid4()}),
)
result["messages"][-1].pretty_print()

# Draw the agent graph as png
# requires:
# brew install graphviz
# export CFLAGS="-I $(brew --prefix graphviz)/include"
# export LDFLAGS="-L $(brew --prefix graphviz)/lib"
# pip install pygraphviz
#
# agent.get_graph().draw_png("agent_diagram.png")


asyncio.run(main())
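
Note the `# noqa: E402` on the `agents` import: `load_dotenv()` must run before `agents` is imported, because `agents/models.py` reads the provider API keys from the environment at import time and exits if none are set.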