-
Notifications
You must be signed in to change notification settings - Fork 222
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for calling multiple agents on one service (#75)
* Support multiple agents (still only one in service) * Get the routes working, update client + tests * add TODO * Clean up basic agent call, update README
- Loading branch information
1 parent
b5eec80
commit cb450b5
Showing
18 changed files
with
227 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from agents.agents import DEFAULT_AGENT, agents | ||
|
||
__all__ = ["agents", "DEFAULT_AGENT"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from langgraph.graph.state import CompiledStateGraph | ||
|
||
from agents.chatbot import chatbot | ||
from agents.research_assistant import research_assistant | ||
|
||
DEFAULT_AGENT = "research-assistant" | ||
|
||
|
||
agents: dict[str, CompiledStateGraph] = { | ||
"chatbot": chatbot, | ||
"research-assistant": research_assistant, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from langchain_core.language_models.chat_models import BaseChatModel | ||
from langchain_core.messages import AIMessage | ||
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable | ||
from langgraph.checkpoint.memory import MemorySaver | ||
from langgraph.graph import END, MessagesState, StateGraph | ||
|
||
from agents.models import models | ||
|
||
|
||
class AgentState(MessagesState, total=False): | ||
"""`total=False` is PEP589 specs. | ||
documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality | ||
""" | ||
|
||
|
||
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]: | ||
preprocessor = RunnableLambda( | ||
lambda state: state["messages"], | ||
name="StateModifier", | ||
) | ||
return preprocessor | model | ||
|
||
|
||
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState: | ||
m = models[config["configurable"].get("model", "gpt-4o-mini")] | ||
model_runnable = wrap_model(m) | ||
response = await model_runnable.ainvoke(state, config) | ||
|
||
# We return a list, because this will get added to the existing list | ||
return {"messages": [response]} | ||
|
||
|
||
# Define the graph | ||
agent = StateGraph(AgentState) | ||
agent.add_node("model", acall_model) | ||
agent.set_entry_point("model") | ||
|
||
# Always END after blocking unsafe content | ||
agent.add_edge("model", END) | ||
|
||
chatbot = agent.compile( | ||
checkpointer=MemorySaver(), | ||
) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
|
||
from langchain_anthropic import ChatAnthropic | ||
from langchain_core.language_models.chat_models import BaseChatModel | ||
from langchain_google_genai import ChatGoogleGenerativeAI | ||
from langchain_groq import ChatGroq | ||
from langchain_openai import ChatOpenAI | ||
|
||
# NOTE: models with streaming=True will send tokens as they are generated | ||
# if the /stream endpoint is called with stream_tokens=True (the default) | ||
models: dict[str, BaseChatModel] = {} | ||
if os.getenv("OPENAI_API_KEY") is not None: | ||
models["gpt-4o-mini"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, streaming=True) | ||
if os.getenv("GROQ_API_KEY") is not None: | ||
models["llama-3.1-70b"] = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.5) | ||
if os.getenv("GOOGLE_API_KEY") is not None: | ||
models["gemini-1.5-flash"] = ChatGoogleGenerativeAI( | ||
model="gemini-1.5-flash", temperature=0.5, streaming=True | ||
) | ||
if os.getenv("ANTHROPIC_API_KEY") is not None: | ||
models["claude-3-haiku"] = ChatAnthropic( | ||
model="claude-3-haiku-20240307", temperature=0.5, streaming=True | ||
) | ||
|
||
if not models: | ||
print("No LLM available. Please set API keys to enable at least one LLM.") | ||
if os.getenv("MODE") == "dev": | ||
print("FastAPI initialized failed. Please use Ctrl + C to exit uvicorn.") | ||
exit(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import asyncio | ||
from uuid import uuid4 | ||
|
||
from dotenv import load_dotenv | ||
from langchain_core.runnables import RunnableConfig | ||
|
||
load_dotenv() | ||
|
||
from agents import DEFAULT_AGENT, agents # noqa: E402 | ||
|
||
agent = agents[DEFAULT_AGENT] | ||
|
||
|
||
async def main() -> None: | ||
inputs = {"messages": [("user", "Find me a recipe for chocolate chip cookies")]} | ||
result = await agent.ainvoke( | ||
inputs, | ||
config=RunnableConfig(configurable={"thread_id": uuid4()}), | ||
) | ||
result["messages"][-1].pretty_print() | ||
|
||
# Draw the agent graph as png | ||
# requires: | ||
# brew install graphviz | ||
# export CFLAGS="-I $(brew --prefix graphviz)/include" | ||
# export LDFLAGS="-L $(brew --prefix graphviz)/lib" | ||
# pip install pygraphviz | ||
# | ||
# agent.get_graph().draw_png("agent_diagram.png") | ||
|
||
|
||
asyncio.run(main()) |
Oops, something went wrong.