diff --git a/langchain-core/src/messages/ai.ts b/langchain-core/src/messages/ai.ts index b1b4c5261378..001413dad45b 100644 --- a/langchain-core/src/messages/ai.ts +++ b/langchain-core/src/messages/ai.ts @@ -7,13 +7,10 @@ import { type MessageType, BaseMessageFields, _mergeLists, -} from "./base.js"; -import { - InvalidToolCall, ToolCall, ToolCallChunk, - defaultToolCallParser, -} from "./tool.js"; +} from "./base.js"; +import { InvalidToolCall, defaultToolCallParser } from "./tool.js"; export type AIMessageFields = BaseMessageFields & { tool_calls?: ToolCall[]; diff --git a/langchain-core/src/messages/base.ts b/langchain-core/src/messages/base.ts index 521f5b77935b..c5a60e81e56b 100644 --- a/langchain-core/src/messages/base.ts +++ b/langchain-core/src/messages/base.ts @@ -172,6 +172,23 @@ function stringifyWithDepthLimit(obj: any, depthLimit: number): string { return JSON.stringify(helper(obj, 0), null, 2); } +/** + * A call to a tool. + * @property {string} name - The name of the tool to be called + * @property {Record} args - The arguments to the tool call + * @property {string} [id] - If provided, an identifier associated with the tool call + */ +export type ToolCall = { + name: string; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + args: Record; + + id?: string; + + type?: "tool_call"; +}; + /** * Base class for all types of messages in a conversation. It includes * properties like `content`, `name`, and `additional_kwargs`. It also @@ -219,6 +236,8 @@ export abstract class BaseMessage */ id?: string; + tool_calls?: never[] | ToolCall[]; + /** * @deprecated Use .getType() instead or import the proper typeguard. * For example: @@ -457,6 +476,76 @@ export function _mergeObj( } } +/** + * A chunk of a tool call (e.g., as part of a stream). + * When merging ToolCallChunks (e.g., via AIMessageChunk.__add__), + * all string attributes are concatenated. Chunks are only merged if their + * values of `index` are equal and not None. + * + * @example + * ```ts + * const leftChunks = [ + * { + * name: "foo", + * args: '{"a":', + * index: 0 + * } + * ]; + * + * const leftAIMessageChunk = new AIMessageChunk({ + * content: "", + * tool_call_chunks: leftChunks + * }); + * + * const rightChunks = [ + * { + * name: undefined, + * args: '1}', + * index: 0 + * } + * ]; + * + * const rightAIMessageChunk = new AIMessageChunk({ + * content: "", + * tool_call_chunks: rightChunks + * }); + * + * const result = leftAIMessageChunk.concat(rightAIMessageChunk); + * // result.tool_call_chunks is equal to: + * // [ + * // { + * // name: "foo", + * // args: '{"a":1}' + * // index: 0 + * // } + * // ] + * ``` + * + * @property {string} [name] - If provided, a substring of the name of the tool to be called + * @property {string} [args] - If provided, a JSON substring of the arguments to the tool call + * @property {string} [id] - If provided, a substring of an identifier for the tool call + * @property {number} [index] - If provided, the index of the tool call in a sequence + */ +export type ToolCallChunk = { + name?: string; + + args?: string; + + id?: string; + + index?: number; + + type?: "tool_call_chunk"; +}; + +export type InvalidToolCall = { + name?: string; + args?: string; + id?: string; + error?: string; + type?: "invalid_tool_call"; +}; + /** * Represents a chunk of a message, which can be concatenated with other * message chunks. It includes a method `_merge_kwargs_dict()` for merging @@ -465,6 +554,10 @@ export function _mergeObj( * of `BaseMessageChunk` instances. */ export abstract class BaseMessageChunk extends BaseMessage { + tool_call_chunks?: never[] | ToolCallChunk[]; + + invalid_tool_calls?: never[] | InvalidToolCall[]; + abstract concat(chunk: BaseMessageChunk): BaseMessageChunk; } diff --git a/langchain-core/src/messages/chat.ts b/langchain-core/src/messages/chat.ts index 376c05cceb84..a04081e75b13 100644 --- a/langchain-core/src/messages/chat.ts +++ b/langchain-core/src/messages/chat.ts @@ -18,6 +18,8 @@ export class ChatMessage extends BaseMessage implements ChatMessageFieldsWithRole { + declare tool_calls?: never[]; + static lc_name() { return "ChatMessage"; } @@ -62,6 +64,12 @@ export class ChatMessage * other chat message chunks. */ export class ChatMessageChunk extends BaseMessageChunk { + declare tool_calls?: never[]; + + declare tool_call_chunks?: never[]; + + declare invalid_tool_calls?: never[]; + static lc_name() { return "ChatMessageChunk"; } diff --git a/langchain-core/src/messages/function.ts b/langchain-core/src/messages/function.ts index 7e3455b8ca51..7cb4c832f496 100644 --- a/langchain-core/src/messages/function.ts +++ b/langchain-core/src/messages/function.ts @@ -15,6 +15,8 @@ export interface FunctionMessageFieldsWithName extends BaseMessageFields { * Represents a function message in a conversation. */ export class FunctionMessage extends BaseMessage { + declare tool_calls?: never[]; + static lc_name() { return "FunctionMessage"; } @@ -49,6 +51,12 @@ export class FunctionMessage extends BaseMessage { * with other function message chunks. */ export class FunctionMessageChunk extends BaseMessageChunk { + declare tool_calls?: never[]; + + declare tool_call_chunks?: never[]; + + declare invalid_tool_calls?: never[]; + static lc_name() { return "FunctionMessageChunk"; } diff --git a/langchain-core/src/messages/human.ts b/langchain-core/src/messages/human.ts index 30b5354ee509..ba34c5b8f12a 100644 --- a/langchain-core/src/messages/human.ts +++ b/langchain-core/src/messages/human.ts @@ -10,6 +10,8 @@ import { * Represents a human message in a conversation. */ export class HumanMessage extends BaseMessage { + declare tool_calls?: never[]; + static lc_name() { return "HumanMessage"; } @@ -24,6 +26,12 @@ export class HumanMessage extends BaseMessage { * other human message chunks. */ export class HumanMessageChunk extends BaseMessageChunk { + declare tool_calls?: never[]; + + declare tool_call_chunks?: never[]; + + declare invalid_tool_calls?: never[]; + static lc_name() { return "HumanMessageChunk"; } diff --git a/langchain-core/src/messages/modifier.ts b/langchain-core/src/messages/modifier.ts index aab91151a8fb..9457adffbf63 100644 --- a/langchain-core/src/messages/modifier.ts +++ b/langchain-core/src/messages/modifier.ts @@ -12,6 +12,8 @@ export interface RemoveMessageFields * Message responsible for deleting other messages. */ export class RemoveMessage extends BaseMessage { + declare tool_calls?: never[]; + /** * The ID of the message to remove. */ diff --git a/langchain-core/src/messages/system.ts b/langchain-core/src/messages/system.ts index ae91a240e83f..c258dff10a19 100644 --- a/langchain-core/src/messages/system.ts +++ b/langchain-core/src/messages/system.ts @@ -10,6 +10,8 @@ import { * Represents a system message in a conversation. */ export class SystemMessage extends BaseMessage { + declare tool_calls?: never[]; + static lc_name() { return "SystemMessage"; } @@ -24,6 +26,12 @@ export class SystemMessage extends BaseMessage { * other system message chunks. */ export class SystemMessageChunk extends BaseMessageChunk { + declare tool_calls?: never[]; + + declare tool_call_chunks?: never[]; + + declare invalid_tool_calls?: never[]; + static lc_name() { return "SystemMessageChunk"; } diff --git a/langchain-core/src/messages/tests/message_utils.test.ts b/langchain-core/src/messages/tests/message_utils.test.ts index 222d9b42372a..9f137ba9a8b6 100644 --- a/langchain-core/src/messages/tests/message_utils.test.ts +++ b/langchain-core/src/messages/tests/message_utils.test.ts @@ -4,16 +4,22 @@ import { mergeMessageRuns, trimMessages, } from "../transformers.js"; -import { AIMessage } from "../ai.js"; -import { ChatMessage } from "../chat.js"; -import { HumanMessage } from "../human.js"; -import { SystemMessage } from "../system.js"; -import { BaseMessage } from "../base.js"; +import { AIMessage, isAIMessage, isAIMessageChunk } from "../ai.js"; +import { ChatMessage, isChatMessage, isChatMessageChunk } from "../chat.js"; +import { HumanMessage, isHumanMessage, isHumanMessageChunk } from "../human.js"; +import { + isSystemMessage, + isSystemMessageChunk, + SystemMessage, + SystemMessageChunk, +} from "../system.js"; +import { BaseMessage, BaseMessageChunk } from "../base.js"; import { getBufferString, mapChatMessagesToStoredMessages, mapStoredMessagesToChatMessages, } from "../utils.js"; +import { isToolMessage, isToolMessageChunk } from "../tool.js"; describe("filterMessage", () => { const getMessages = () => [ @@ -520,3 +526,91 @@ describe("chat message conversions", () => { expect(convertedBackMessages).toEqual(originalMessages); }); }); + +it("Should narrow tool call typing when accessing a base message array", async () => { + const messages: BaseMessage[] = [new SystemMessage("test")]; + if (messages[0].tool_calls?.[0] !== undefined) { + // Allow checking existence on BaseMessage with no errors + void messages[0].tool_calls[0].args; + } + + const msg = messages[0]; + if (isAIMessage(msg)) { + // Should allow access from AI messages + void msg.tool_calls?.[0].args; + } + if (isHumanMessage(msg)) { + // @ts-expect-error Typing should not allow access from human messages + void msg.tool_calls?.[0].args; + } + if (isSystemMessage(msg)) { + // @ts-expect-error Typing should not allow access from system messages + void msg.tool_calls?.[0].args; + } + if (isToolMessage(msg)) { + // @ts-expect-error Typing should not allow access from tool messages + void msg.tool_calls?.[0].args; + } + if (isChatMessage(msg)) { + // @ts-expect-error Typing should not allow access from chat messages + void msg.tool_calls?.[0].args; + } + + const messageChunks: BaseMessageChunk[] = [new SystemMessageChunk("test")]; + if (messageChunks[0].tool_calls?.[0] !== undefined) { + // Allow checking existence on BaseMessage with no errors + void messageChunks[0].tool_calls[0].args; + } + + if (messageChunks[0].tool_call_chunks?.[0] !== undefined) { + // Allow checking existence on BaseMessage with no errors + void messageChunks[0].tool_call_chunks[0].args; + } + + if (messageChunks[0].invalid_tool_calls?.[0] !== undefined) { + // Allow checking existence on BaseMessage with no errors + void messageChunks[0].invalid_tool_calls[0].args; + } + + const msgChunk = messageChunks[0]; + if (isAIMessageChunk(msgChunk)) { + // Typing should allow access from AI message chunks + void msgChunk.tool_calls?.[0].args; + // Typing should allow access from AI message chunks + void msgChunk.tool_call_chunks?.[0].args; + // Typing should allow access from AI message chunks + void msgChunk.invalid_tool_calls?.[0].args; + } + if (isHumanMessageChunk(msgChunk)) { + // @ts-expect-error Typing should not allow access from human message chunks + void msgChunk.tool_calls?.[0].args; + // @ts-expect-error Typing should not allow access from human message chunks + void msgChunk.tool_call_chunks?.[0].args; + // @ts-expect-error Typing should not allow access from human message chunks + void msgChunk.invalid_tool_calls?.[0].args; + } + if (isSystemMessageChunk(msgChunk)) { + // @ts-expect-error Typing should not allow access from system message chunks + void msgChunk.tool_calls?.[0].args; + // @ts-expect-error Typing should not allow access from system message chunks + void msgChunk.tool_call_chunks?.[0].args; + // @ts-expect-error Typing should not allow access from system message chunks + void msgChunk.invalid_tool_calls?.[0].args; + } + if (isToolMessageChunk(msgChunk)) { + // @ts-expect-error Typing should not allow access from tool message chunks + void msgChunk.tool_calls?.[0].args; + // @ts-expect-error Typing should not allow access from tool message chunks + void msgChunk.tool_call_chunks?.[0].args; + // @ts-expect-error Typing should not allow access from tool message chunks + void msgChunk.invalid_tool_calls?.[0].args; + } + if (isChatMessageChunk(msgChunk)) { + // @ts-expect-error Typing should not allow access from chat message chunks + void msgChunk.tool_calls?.[0].args; + // @ts-expect-error Typing should not allow access from chat message chunks + void msgChunk.tool_call_chunks?.[0].args; + // @ts-expect-error Typing should not allow access from chat message chunks + void msgChunk.invalid_tool_calls?.[0].args; + } +}); diff --git a/langchain-core/src/messages/tool.ts b/langchain-core/src/messages/tool.ts index 1b3f555a7349..22d9d266aaae 100644 --- a/langchain-core/src/messages/tool.ts +++ b/langchain-core/src/messages/tool.ts @@ -7,6 +7,9 @@ import { type MessageType, _mergeObj, _mergeStatus, + ToolCall, + ToolCallChunk, + InvalidToolCall, } from "./base.js"; export interface ToolMessageFieldsWithToolCallId extends BaseMessageFields { @@ -51,6 +54,8 @@ export function isDirectToolOutput(x: unknown): x is DirectToolOutput { * Represents a tool message in a conversation. */ export class ToolMessage extends BaseMessage implements DirectToolOutput { + declare tool_calls?: never[]; + static lc_name() { return "ToolMessage"; } @@ -125,6 +130,12 @@ export class ToolMessage extends BaseMessage implements DirectToolOutput { * with other tool message chunks. */ export class ToolMessageChunk extends BaseMessageChunk { + declare tool_calls?: never[]; + + declare tool_call_chunks?: never[]; + + declare invalid_tool_calls?: never[]; + tool_call_id: string; /** @@ -185,92 +196,7 @@ export class ToolMessageChunk extends BaseMessageChunk { } } -/** - * A call to a tool. - * @property {string} name - The name of the tool to be called - * @property {Record} args - The arguments to the tool call - * @property {string} [id] - If provided, an identifier associated with the tool call - */ -export type ToolCall = { - name: string; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - args: Record; - - id?: string; - - type?: "tool_call"; -}; - -/** - * A chunk of a tool call (e.g., as part of a stream). - * When merging ToolCallChunks (e.g., via AIMessageChunk.__add__), - * all string attributes are concatenated. Chunks are only merged if their - * values of `index` are equal and not None. - * - * @example - * ```ts - * const leftChunks = [ - * { - * name: "foo", - * args: '{"a":', - * index: 0 - * } - * ]; - * - * const leftAIMessageChunk = new AIMessageChunk({ - * content: "", - * tool_call_chunks: leftChunks - * }); - * - * const rightChunks = [ - * { - * name: undefined, - * args: '1}', - * index: 0 - * } - * ]; - * - * const rightAIMessageChunk = new AIMessageChunk({ - * content: "", - * tool_call_chunks: rightChunks - * }); - * - * const result = leftAIMessageChunk.concat(rightAIMessageChunk); - * // result.tool_call_chunks is equal to: - * // [ - * // { - * // name: "foo", - * // args: '{"a":1}' - * // index: 0 - * // } - * // ] - * ``` - * - * @property {string} [name] - If provided, a substring of the name of the tool to be called - * @property {string} [args] - If provided, a JSON substring of the arguments to the tool call - * @property {string} [id] - If provided, a substring of an identifier for the tool call - * @property {number} [index] - If provided, the index of the tool call in a sequence - */ -export type ToolCallChunk = { - name?: string; - - args?: string; - - id?: string; - - index?: number; - - type?: "tool_call_chunk"; -}; - -export type InvalidToolCall = { - name?: string; - args?: string; - id?: string; - error?: string; - type?: "invalid_tool_call"; -}; +export type { ToolCall, ToolCallChunk, InvalidToolCall }; export function defaultToolCallParser( // eslint-disable-next-line @typescript-eslint/no-explicit-any