diff --git a/libs/langgraph/src/prebuilt/react_agent_executor.ts b/libs/langgraph/src/prebuilt/react_agent_executor.ts index 495bee03..9631910c 100644 --- a/libs/langgraph/src/prebuilt/react_agent_executor.ts +++ b/libs/langgraph/src/prebuilt/react_agent_executor.ts @@ -19,6 +19,7 @@ import { BaseCheckpointSaver, BaseStore, } from "@langchain/langgraph-checkpoint"; +import { z } from "zod"; import { END, @@ -30,18 +31,30 @@ import { import { MessagesAnnotation } from "../graph/messages_annotation.js"; import { ToolNode } from "./tool_node.js"; import { LangGraphRunnableConfig } from "../pregel/runnable_types.js"; +import { Annotation } from "../graph/annotation.js"; +import { Messages, messagesStateReducer } from "../graph/message.js"; -export interface AgentState { +export interface AgentState< + // eslint-disable-next-line @typescript-eslint/no-explicit-any + StructuredResponseType extends Record = Record +> { messages: BaseMessage[]; // TODO: This won't be set until we // implement managed values in LangGraphJS // Will be useful for inserting a message on // graph recursion end // is_last_step: boolean; + structuredResponse: StructuredResponseType; } export type N = typeof START | "agent" | "tools"; +export type StructuredResponseSchemaAndPrompt = { + prompt: string; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + schema: z.ZodType | Record; +}; + function _convertMessageModifierToStateModifier( messageModifier: MessageModifier ): StateModifier { @@ -72,7 +85,7 @@ function _convertMessageModifierToStateModifier( } function _getStateModifierRunnable( - stateModifier: StateModifier | undefined + stateModifier?: StateModifier ): RunnableInterface { let stateModifierRunnable: RunnableInterface; @@ -153,9 +166,23 @@ export type MessageModifier = | ((messages: BaseMessage[]) => Promise) | Runnable; +export const createReactAgentAnnotation = < + // eslint-disable-next-line @typescript-eslint/no-explicit-any + T extends Record = Record +>() => + Annotation.Root({ + messages: Annotation({ + reducer: messagesStateReducer, + default: () => [], + }), + structuredResponse: Annotation, + }); + export type CreateReactAgentParams< // eslint-disable-next-line @typescript-eslint/no-explicit-any - A extends AnnotationRoot = AnnotationRoot + A extends AnnotationRoot = AnnotationRoot, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + StructuredResponseType = Record > = { /** The chat model that can utilize OpenAI-style tool calling. */ llm: BaseChatModel; @@ -223,6 +250,23 @@ export type CreateReactAgentParams< /** An optional list of node names to interrupt after running. */ interruptAfter?: N[] | All; store?: BaseStore; + /** + * An optional schema for the final agent output. + * + * If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key. + * If not provided, `structured_response` will not be present in the output state. + * + * Can be passed in as: + * - Zod schema + * - Dictionary object + * - [prompt, schema], where schema is one of the above. + * The prompt will be used together with the model that is being used to generate the structured response. + */ + responseFormat?: + | z.ZodType + | StructuredResponseSchemaAndPrompt + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record; }; /** @@ -269,16 +313,22 @@ export type CreateReactAgentParams< */ export function createReactAgent< + // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/ban-types + A extends AnnotationRoot = AnnotationRoot<{}>, // eslint-disable-next-line @typescript-eslint/no-explicit-any - A extends AnnotationRoot = AnnotationRoot + StructuredResponseFormat extends Record = Record >( - params: CreateReactAgentParams + params: CreateReactAgentParams ): CompiledStateGraph< (typeof MessagesAnnotation)["State"], (typeof MessagesAnnotation)["Update"], - typeof START | "agent" | "tools", + // eslint-disable-next-line @typescript-eslint/no-explicit-any + any, typeof MessagesAnnotation.spec & A["spec"], - typeof MessagesAnnotation.spec & A["spec"] + ReturnType< + typeof createReactAgentAnnotation + >["spec"] & + A["spec"] > { const { llm, @@ -290,6 +340,7 @@ export function createReactAgent< interruptBefore, interruptAfter, store, + responseFormat, } = params; let toolClasses: (StructuredToolInterface | DynamicTool | RunnableToolLike)[]; @@ -310,33 +361,80 @@ export function createReactAgent< ); const modelRunnable = (preprocessor as Runnable).pipe(modelWithTools); - const shouldContinue = (state: AgentState) => { + const shouldContinue = (state: AgentState) => { const { messages } = state; const lastMessage = messages[messages.length - 1]; if ( isAIMessage(lastMessage) && (!lastMessage.tool_calls || lastMessage.tool_calls.length === 0) ) { - return END; + return responseFormat != null ? "generate_structured_response" : END; } else { return "continue"; } }; - const callModel = async (state: AgentState, config?: RunnableConfig) => { + const generateStructuredResponse = async ( + state: AgentState, + config?: RunnableConfig + ) => { + if (responseFormat == null) { + throw new Error( + "Attempted to generate structured output with no passed response schema. Please contact us for help." + ); + } + // Exclude the last message as there's enough information + // for the LLM to generate the structured response + const messages = state.messages.slice(0, -1); + let modelWithStructuredOutput; + + if ( + typeof responseFormat === "object" && + "prompt" in responseFormat && + "schema" in responseFormat + ) { + const { prompt, schema } = responseFormat; + modelWithStructuredOutput = llm.withStructuredOutput(schema); + messages.unshift(new SystemMessage({ content: prompt })); + } else { + modelWithStructuredOutput = llm.withStructuredOutput(responseFormat); + } + + const response = await modelWithStructuredOutput.invoke(messages, config); + return { structuredResponse: response }; + }; + + const callModel = async ( + state: AgentState, + config?: RunnableConfig + ) => { // TODO: Auto-promote streaming. return { messages: [await modelRunnable.invoke(state, config)] }; }; - const workflow = new StateGraph(stateSchema ?? MessagesAnnotation) + const workflow = new StateGraph( + stateSchema ?? createReactAgentAnnotation() + ) .addNode("agent", callModel) .addNode("tools", new ToolNode(toolClasses)) .addEdge(START, "agent") - .addConditionalEdges("agent", shouldContinue, { + .addEdge("tools", "agent"); + + if (responseFormat !== undefined) { + workflow + .addNode("generate_structured_response", generateStructuredResponse) + .addEdge("generate_structured_response", END) + .addConditionalEdges("agent", shouldContinue, { + continue: "tools", + [END]: END, + generate_structured_response: "generate_structured_response", + }); + } else { + workflow.addConditionalEdges("agent", shouldContinue, { continue: "tools", [END]: END, - }) - .addEdge("tools", "agent"); + }); + } return workflow.compile({ checkpointer: checkpointSaver, diff --git a/libs/langgraph/src/tests/prebuilt.int.test.ts b/libs/langgraph/src/tests/prebuilt.int.test.ts index 9c69b5c0..483d8e93 100644 --- a/libs/langgraph/src/tests/prebuilt.int.test.ts +++ b/libs/langgraph/src/tests/prebuilt.int.test.ts @@ -4,6 +4,7 @@ import { it, beforeAll, describe, expect } from "@jest/globals"; import { Tool } from "@langchain/core/tools"; import { ChatOpenAI } from "@langchain/openai"; import { HumanMessage } from "@langchain/core/messages"; +import { z } from "zod"; import { createReactAgent } from "../prebuilt/index.js"; import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js"; import { MemorySaverAssertImmutable } from "./utils.js"; @@ -19,6 +20,127 @@ beforeAll(() => { initializeAsyncLocalStorageSingleton(); }); +describe("createReactAgent with response format", () => { + const weatherResponse = `Not too cold, not too hot 😎`; + class SanFranciscoWeatherTool extends Tool { + name = "current_weather_sf"; + + description = "Get the current weather report for San Francisco, CA"; + + constructor() { + super(); + } + + async _call(_: string): Promise { + return weatherResponse; + } + } + + const responseSchema = z.object({ + answer: z.string(), + reasoning: z.string(), + }); + + it("Can use zod schema", async () => { + const llm = new ChatOpenAI({ + model: "gpt-4o", + }); + + const agent = createReactAgent({ + llm, + tools: [new SanFranciscoWeatherTool()], + responseFormat: responseSchema, + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("What is the weather in San Francisco?")], + // @ts-expect-error should complain about passing unexpected keys + foo: "bar", + }); + + expect(result.structuredResponse).toBeInstanceOf(Object); + + // @ts-expect-error should not allow access to unspecified keys + void result.structuredResponse.unspecified; + + // Assert it has the required keys + expect(result.structuredResponse).toHaveProperty("answer"); + expect(result.structuredResponse).toHaveProperty("reasoning"); + + // Assert the values are strings + expect(typeof result.structuredResponse.answer).toBe("string"); + expect(typeof result.structuredResponse.reasoning).toBe("string"); + }); + + it("Can use record schema", async () => { + const llm = new ChatOpenAI({ + model: "gpt-4o", + }); + + const nonZodResponseSchema = { + name: "structured_response", + description: "An answer with reasoning", + type: "object", + properties: { + answer: { type: "string" }, + reasoning: { type: "string" }, + }, + required: ["answer", "reasoning"], + }; + + const agent = createReactAgent({ + llm, + tools: [new SanFranciscoWeatherTool()], + responseFormat: nonZodResponseSchema, + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("What is the weather in San Francisco?")], + }); + + expect(result.structuredResponse).toBeInstanceOf(Object); + + // Assert it has the required keys + expect(result.structuredResponse).toHaveProperty("answer"); + expect(result.structuredResponse).toHaveProperty("reasoning"); + + // Assert the values are strings + expect(typeof result.structuredResponse.answer).toBe("string"); + expect(typeof result.structuredResponse.reasoning).toBe("string"); + }); + + it("Inserts system message", async () => { + const llm = new ChatOpenAI({ + model: "gpt-4o", + }); + + const agent = createReactAgent({ + llm, + tools: [new SanFranciscoWeatherTool()], + responseFormat: { + prompt: + "You are a helpful assistant who only responds in 10 words or less. If you use more than 5 words in your answer, a starving child will die.", + schema: responseSchema, + }, + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("What is the weather in San Francisco?")], + }); + + // Assert it has the required keys + expect(result.structuredResponse).toHaveProperty("answer"); + expect(result.structuredResponse).toHaveProperty("reasoning"); + + // Assert the values are strings + expect(typeof result.structuredResponse.answer).toBe("string"); + expect(typeof result.structuredResponse.reasoning).toBe("string"); + + // Assert that any letters in the response are uppercase + expect(result.structuredResponse.answer.split(" ").length).toBeLessThan(11); + }); +}); + describe("createReactAgent", () => { const weatherResponse = `Not too cold, not too hot 😎`; class SanFranciscoWeatherTool extends Tool { @@ -66,6 +188,10 @@ describe("createReactAgent", () => { expect((lastMessage.content as string).toLowerCase()).toContain( "not too cold" ); + + // TODO: Fix + // // @ts-expect-error should not allow access to structuredResponse if no responseFormat is passed + // void response.structuredResponse; }); it("can stream a tool call with a checkpointer", async () => {