Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(langgraph): add structured response format to prebuilt react agent #788

Merged
merged 6 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 69 additions & 5 deletions libs/langgraph/src/prebuilt/react_agent_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
BaseCheckpointSaver,
BaseStore,
} from "@langchain/langgraph-checkpoint";
import { z } from "zod";

import {
END,
Expand All @@ -30,14 +31,20 @@ import {
import { MessagesAnnotation } from "../graph/messages_annotation.js";
import { ToolNode } from "./tool_node.js";
import { LangGraphRunnableConfig } from "../pregel/runnable_types.js";
import { Annotation } from "../graph/annotation.js";
import { Messages, messagesStateReducer } from "../graph/message.js";

type StructuredResponse =
| z.ZodType<Record<string, unknown>>
jacoblee93 marked this conversation as resolved.
Show resolved Hide resolved
| Record<string, unknown>;
export interface AgentState {
messages: BaseMessage[];
// TODO: This won't be set until we
// implement managed values in LangGraphJS
// Will be useful for inserting a message on
// graph recursion end
// is_last_step: boolean;
structuredResponse: StructuredResponse;
}

export type N = typeof START | "agent" | "tools";
Expand Down Expand Up @@ -153,6 +160,14 @@ export type MessageModifier =
| ((messages: BaseMessage[]) => Promise<BaseMessage[]>)
| Runnable;

const ReactAgentAnnotation = Annotation.Root({
messages: Annotation<BaseMessage[], Messages>({
reducer: messagesStateReducer,
default: () => [],
}),
structuredResponse: Annotation<StructuredResponse>,
});

export type CreateReactAgentParams<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
A extends AnnotationRoot<any> = AnnotationRoot<any>
Expand Down Expand Up @@ -223,6 +238,19 @@ export type CreateReactAgentParams<
/** An optional list of node names to interrupt after running. */
interruptAfter?: N[] | All;
store?: BaseStore;
/**
* An optional schema for the final agent output.
*
* If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key.
* If not provided, `structured_response` will not be present in the output state.
*
* Can be passed in as:
* - Zod schema
* - Dictionary object
* - [prompt, schema], where schema is one of the above.
* The prompt will be used together with the model that is being used to generate the structured response.
*/
responseFormat?: StructuredResponse | [string, StructuredResponse];
jacoblee93 marked this conversation as resolved.
Show resolved Hide resolved
};

/**
Expand Down Expand Up @@ -290,6 +318,7 @@ export function createReactAgent<
interruptBefore,
interruptAfter,
store,
responseFormat,
} = params;

let toolClasses: (StructuredToolInterface | DynamicTool | RunnableToolLike)[];
Expand Down Expand Up @@ -317,26 +346,61 @@ export function createReactAgent<
isAIMessage(lastMessage) &&
(!lastMessage.tool_calls || lastMessage.tool_calls.length === 0)
) {
return END;
return responseFormat ? "generate_structured_response" : END;
} else {
return "continue";
}
};

const generateStructuredResponse = async (
state: AgentState,
config?: RunnableConfig
) => {
// Exclude the last message as there's enough information
// for the LLM to generate the structured response
const messages = state.messages.slice(0, -1);
let structuredResponseSchema = responseFormat;

if (Array.isArray(responseFormat)) {
const [systemPrompt, schema] = responseFormat;
structuredResponseSchema = schema;
messages.unshift(new SystemMessage({ content: systemPrompt }));
}

const modelWithStructuredOutput = llm.withStructuredOutput(
structuredResponseSchema as StructuredResponse
);

const response = await modelWithStructuredOutput.invoke(messages, config);
return { structuredResponse: response };
};

const callModel = async (state: AgentState, config?: RunnableConfig) => {
// TODO: Auto-promote streaming.
return { messages: [await modelRunnable.invoke(state, config)] };
};

const workflow = new StateGraph(stateSchema ?? MessagesAnnotation)
const workflow = new StateGraph(stateSchema ?? ReactAgentAnnotation)
.addNode("agent", callModel)
.addNode("tools", new ToolNode(toolClasses))
.addEdge(START, "agent")
.addConditionalEdges("agent", shouldContinue, {
.addEdge("tools", "agent");

if (responseFormat) {
workflow
.addNode("generate_structured_response", generateStructuredResponse)
.addEdge("generate_structured_response", END)
.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
[END]: END,
generate_structured_response: "generate_structured_response",
});
} else {
workflow.addConditionalEdges("agent", shouldContinue, {
continue: "tools",
[END]: END,
})
.addEdge("tools", "agent");
});
}

return workflow.compile({
checkpointer: checkpointSaver,
Expand Down
123 changes: 123 additions & 0 deletions libs/langgraph/src/tests/prebuilt.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { it, beforeAll, describe, expect } from "@jest/globals";
import { Tool } from "@langchain/core/tools";
import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";
import { z } from "zod";
import { createReactAgent } from "../prebuilt/index.js";
import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js";
import { MemorySaverAssertImmutable } from "./utils.js";
Expand All @@ -19,6 +20,128 @@ beforeAll(() => {
initializeAsyncLocalStorageSingleton();
});

describe("createReactAgent with response format", () => {
const weatherResponse = `Not too cold, not too hot 😎`;
class SanFranciscoWeatherTool extends Tool {
name = "current_weather_sf";

description = "Get the current weather report for San Francisco, CA";

constructor() {
super();
}

async _call(_: string): Promise<string> {
return weatherResponse;
}
}

it("Can use zod schema", async () => {
const llm = new ChatOpenAI({
model: "gpt-4o",
});

const agent = createReactAgent({
llm,
tools: [new SanFranciscoWeatherTool()],
responseFormat: z.object({
answer: z.string(),
reasoning: z.string(),
}),
});

const result = await agent.invoke({
messages: [new HumanMessage("What is the weather in San Francisco?")],
});

expect(result.structuredResponse).toBeInstanceOf(Object);

// Assert it has the required keys
expect(result.structuredResponse).toHaveProperty("answer");
expect(result.structuredResponse).toHaveProperty("reasoning");

// Assert the values are strings
expect(typeof result.structuredResponse.answer).toBe("string");
expect(typeof result.structuredResponse.reasoning).toBe("string");
});

it("Can use record schema", async () => {
const llm = new ChatOpenAI({
model: "gpt-4o",
});

const agent = createReactAgent({
llm,
tools: [new SanFranciscoWeatherTool()],
responseFormat: {
name: "structured_response",
description: "An answer with reasoning",
type: "object",
properties: {
answer: { type: "string" },
reasoning: { type: "string" },
},
required: ["answer", "reasoning"],
},
});

const result = await agent.invoke({
messages: [new HumanMessage("What is the weather in San Francisco?")],
});

expect(result.structuredResponse).toBeInstanceOf(Object);

// Assert it has the required keys
expect(result.structuredResponse).toHaveProperty("answer");
expect(result.structuredResponse).toHaveProperty("reasoning");

// Assert the values are strings
expect(typeof result.structuredResponse.answer).toBe("string");
expect(typeof result.structuredResponse.reasoning).toBe("string");
});

it("Inserts system message", async () => {
const llm = new ChatOpenAI({
model: "gpt-4o",
});

const agent = createReactAgent({
llm,
tools: [new SanFranciscoWeatherTool()],
responseFormat: [
"You are a helpful assistant who only responds in 10 words or less. If you use more than 5 words in your answer, a starving child will die.",
{
name: "structured_response",
description: "An answer with reasoning",
type: "object",
properties: {
answer: { type: "string" },
reasoning: { type: "string" },
},
required: ["answer", "reasoning"],
},
],
});

const result = await agent.invoke({
messages: [new HumanMessage("What is the weather in San Francisco?")],
});

expect(result.structuredResponse).toBeInstanceOf(Object);

// Assert it has the required keys
expect(result.structuredResponse).toHaveProperty("answer");
expect(result.structuredResponse).toHaveProperty("reasoning");

// Assert the values are strings
expect(typeof result.structuredResponse.answer).toBe("string");
expect(typeof result.structuredResponse.reasoning).toBe("string");

// Assert that any letters in the response are uppercase
expect(result.structuredResponse.answer.split(" ").length).toBeLessThan(11);
});
});

describe("createReactAgent", () => {
const weatherResponse = `Not too cold, not too hot 😎`;
class SanFranciscoWeatherTool extends Tool {
Expand Down
Loading