Skip to content

Commit

Permalink
Allow checking tool_calls on any BaseMessage without type casts
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Jan 7, 2025
1 parent 2b43011 commit f9bf4e3
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 96 deletions.
7 changes: 2 additions & 5 deletions langchain-core/src/messages/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ import {
type MessageType,
BaseMessageFields,
_mergeLists,
} from "./base.js";
import {
InvalidToolCall,
ToolCall,
ToolCallChunk,
defaultToolCallParser,
} from "./tool.js";
} from "./base.js";
import { InvalidToolCall, defaultToolCallParser } from "./tool.js";

export type AIMessageFields = BaseMessageFields & {
tool_calls?: ToolCall[];
Expand Down
93 changes: 93 additions & 0 deletions langchain-core/src/messages/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,23 @@ function stringifyWithDepthLimit(obj: any, depthLimit: number): string {
return JSON.stringify(helper(obj, 0), null, 2);
}

/**
* A call to a tool.
* @property {string} name - The name of the tool to be called
* @property {Record<string, any>} args - The arguments to the tool call
* @property {string} [id] - If provided, an identifier associated with the tool call
*/
export type ToolCall = {
name: string;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
args: Record<string, any>;

id?: string;

type?: "tool_call";
};

/**
* Base class for all types of messages in a conversation. It includes
* properties like `content`, `name`, and `additional_kwargs`. It also
Expand Down Expand Up @@ -219,6 +236,8 @@ export abstract class BaseMessage
*/
id?: string;

tool_calls?: never[] | ToolCall[];

/**
* @deprecated Use .getType() instead or import the proper typeguard.
* For example:
Expand Down Expand Up @@ -457,6 +476,76 @@ export function _mergeObj<T = any>(
}
}

/**
* A chunk of a tool call (e.g., as part of a stream).
* When merging ToolCallChunks (e.g., via AIMessageChunk.__add__),
* all string attributes are concatenated. Chunks are only merged if their
* values of `index` are equal and not None.
*
* @example
* ```ts
* const leftChunks = [
* {
* name: "foo",
* args: '{"a":',
* index: 0
* }
* ];
*
* const leftAIMessageChunk = new AIMessageChunk({
* content: "",
* tool_call_chunks: leftChunks
* });
*
* const rightChunks = [
* {
* name: undefined,
* args: '1}',
* index: 0
* }
* ];
*
* const rightAIMessageChunk = new AIMessageChunk({
* content: "",
* tool_call_chunks: rightChunks
* });
*
* const result = leftAIMessageChunk.concat(rightAIMessageChunk);
* // result.tool_call_chunks is equal to:
* // [
* // {
* // name: "foo",
* // args: '{"a":1}'
* // index: 0
* // }
* // ]
* ```
*
* @property {string} [name] - If provided, a substring of the name of the tool to be called
* @property {string} [args] - If provided, a JSON substring of the arguments to the tool call
* @property {string} [id] - If provided, a substring of an identifier for the tool call
* @property {number} [index] - If provided, the index of the tool call in a sequence
*/
export type ToolCallChunk = {
name?: string;

args?: string;

id?: string;

index?: number;

type?: "tool_call_chunk";
};

export type InvalidToolCall = {
name?: string;
args?: string;
id?: string;
error?: string;
type?: "invalid_tool_call";
};

/**
* Represents a chunk of a message, which can be concatenated with other
* message chunks. It includes a method `_merge_kwargs_dict()` for merging
Expand All @@ -465,6 +554,10 @@ export function _mergeObj<T = any>(
* of `BaseMessageChunk` instances.
*/
export abstract class BaseMessageChunk extends BaseMessage {
tool_call_chunks?: never[] | ToolCallChunk[];

invalid_tool_calls?: never[] | InvalidToolCall[];

abstract concat(chunk: BaseMessageChunk): BaseMessageChunk;
}

Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ export class ChatMessage
extends BaseMessage
implements ChatMessageFieldsWithRole
{
declare tool_calls?: never[];

static lc_name() {
return "ChatMessage";
}
Expand Down Expand Up @@ -62,6 +64,12 @@ export class ChatMessage
* other chat message chunks.
*/
export class ChatMessageChunk extends BaseMessageChunk {
declare tool_calls?: never[];

declare tool_call_chunks?: never[];

declare invalid_tool_calls?: never[];

static lc_name() {
return "ChatMessageChunk";
}
Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export interface FunctionMessageFieldsWithName extends BaseMessageFields {
* Represents a function message in a conversation.
*/
export class FunctionMessage extends BaseMessage {
declare tool_calls?: never[];

static lc_name() {
return "FunctionMessage";
}
Expand Down Expand Up @@ -49,6 +51,12 @@ export class FunctionMessage extends BaseMessage {
* with other function message chunks.
*/
export class FunctionMessageChunk extends BaseMessageChunk {
declare tool_calls?: never[];

declare tool_call_chunks?: never[];

declare invalid_tool_calls?: never[];

static lc_name() {
return "FunctionMessageChunk";
}
Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/human.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {
* Represents a human message in a conversation.
*/
export class HumanMessage extends BaseMessage {
declare tool_calls?: never[];

static lc_name() {
return "HumanMessage";
}
Expand All @@ -24,6 +26,12 @@ export class HumanMessage extends BaseMessage {
* other human message chunks.
*/
export class HumanMessageChunk extends BaseMessageChunk {
declare tool_calls?: never[];

declare tool_call_chunks?: never[];

declare invalid_tool_calls?: never[];

static lc_name() {
return "HumanMessageChunk";
}
Expand Down
2 changes: 2 additions & 0 deletions langchain-core/src/messages/modifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ export interface RemoveMessageFields
* Message responsible for deleting other messages.
*/
export class RemoveMessage extends BaseMessage {
declare tool_calls?: never[];

/**
* The ID of the message to remove.
*/
Expand Down
8 changes: 8 additions & 0 deletions langchain-core/src/messages/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {
* Represents a system message in a conversation.
*/
export class SystemMessage extends BaseMessage {
declare tool_calls?: never[];

static lc_name() {
return "SystemMessage";
}
Expand All @@ -24,6 +26,12 @@ export class SystemMessage extends BaseMessage {
* other system message chunks.
*/
export class SystemMessageChunk extends BaseMessageChunk {
declare tool_calls?: never[];

declare tool_call_chunks?: never[];

declare invalid_tool_calls?: never[];

static lc_name() {
return "SystemMessageChunk";
}
Expand Down
104 changes: 99 additions & 5 deletions langchain-core/src/messages/tests/message_utils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@ import {
mergeMessageRuns,
trimMessages,
} from "../transformers.js";
import { AIMessage } from "../ai.js";
import { ChatMessage } from "../chat.js";
import { HumanMessage } from "../human.js";
import { SystemMessage } from "../system.js";
import { BaseMessage } from "../base.js";
import { AIMessage, isAIMessage, isAIMessageChunk } from "../ai.js";
import { ChatMessage, isChatMessage, isChatMessageChunk } from "../chat.js";
import { HumanMessage, isHumanMessage, isHumanMessageChunk } from "../human.js";
import {
isSystemMessage,
isSystemMessageChunk,
SystemMessage,
SystemMessageChunk,
} from "../system.js";
import { BaseMessage, BaseMessageChunk } from "../base.js";
import {
getBufferString,
mapChatMessagesToStoredMessages,
mapStoredMessagesToChatMessages,
} from "../utils.js";
import { isToolMessage, isToolMessageChunk } from "../tool.js";

describe("filterMessage", () => {
const getMessages = () => [
Expand Down Expand Up @@ -520,3 +526,91 @@ describe("chat message conversions", () => {
expect(convertedBackMessages).toEqual(originalMessages);
});
});

it("Should narrow tool call typing when accessing a base message array", async () => {
const messages: BaseMessage[] = [new SystemMessage("test")];
if (messages[0].tool_calls?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messages[0].tool_calls[0].args;
}

const msg = messages[0];
if (isAIMessage(msg)) {
// Should allow access from AI messages
void msg.tool_calls?.[0].args;
}
if (isHumanMessage(msg)) {
// @ts-expect-error Typing should not allow access from human messages
void msg.tool_calls?.[0].args;
}
if (isSystemMessage(msg)) {
// @ts-expect-error Typing should not allow access from system messages
void msg.tool_calls?.[0].args;
}
if (isToolMessage(msg)) {
// @ts-expect-error Typing should not allow access from tool messages
void msg.tool_calls?.[0].args;
}
if (isChatMessage(msg)) {
// @ts-expect-error Typing should not allow access from chat messages
void msg.tool_calls?.[0].args;
}

const messageChunks: BaseMessageChunk[] = [new SystemMessageChunk("test")];
if (messageChunks[0].tool_calls?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messageChunks[0].tool_calls[0].args;
}

if (messageChunks[0].tool_call_chunks?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messageChunks[0].tool_call_chunks[0].args;
}

if (messageChunks[0].invalid_tool_calls?.[0] !== undefined) {
// Allow checking existence on BaseMessage with no errors
void messageChunks[0].invalid_tool_calls[0].args;
}

const msgChunk = messageChunks[0];
if (isAIMessageChunk(msgChunk)) {
// Typing should allow access from AI message chunks
void msgChunk.tool_calls?.[0].args;
// Typing should allow access from AI message chunks
void msgChunk.tool_call_chunks?.[0].args;
// Typing should allow access from AI message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isHumanMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from human message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from human message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from human message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isSystemMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from system message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from system message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from system message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isToolMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from tool message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from tool message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from tool message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
if (isChatMessageChunk(msgChunk)) {
// @ts-expect-error Typing should not allow access from chat message chunks
void msgChunk.tool_calls?.[0].args;
// @ts-expect-error Typing should not allow access from chat message chunks
void msgChunk.tool_call_chunks?.[0].args;
// @ts-expect-error Typing should not allow access from chat message chunks
void msgChunk.invalid_tool_calls?.[0].args;
}
});
Loading

0 comments on commit f9bf4e3

Please sign in to comment.