diff --git a/src/schema/schema.py b/src/schema/schema.py index f6c7206..c9fb8c0 100644 --- a/src/schema/schema.py +++ b/src/schema/schema.py @@ -9,6 +9,9 @@ message_to_dict, messages_from_dict, ) +from langchain_core.messages import ( + ChatMessage as LangchainChatMessage, +) from pydantic import BaseModel, Field @@ -56,9 +59,9 @@ class StreamInput(UserInput): class ChatMessage(BaseModel): """Message in a chat.""" - type: Literal["human", "ai", "tool"] = Field( + type: Literal["human", "ai", "tool", "custom"] = Field( description="Role of the message.", - examples=["human", "ai", "tool"], + examples=["human", "ai", "tool", "custom"], ) content: str = Field( description="Content of the message.", @@ -82,6 +85,10 @@ class ChatMessage(BaseModel): description="Original LangChain message in serialized form.", default={}, ) + custom_data: dict[str, Any] = Field( + description="Custom message data.", + default={}, + ) @classmethod def from_langchain(cls, message: BaseMessage) -> "ChatMessage": @@ -112,9 +119,28 @@ def from_langchain(cls, message: BaseMessage) -> "ChatMessage": original=original, ) return tool_message + case LangchainChatMessage(): + if message.role == "custom": + custom_message = cls( + type="custom", + content="", + custom_data=message.content[0], + original=original, + ) + return custom_message + else: + raise ValueError(f"Unsupported chat message role: {message.role}") case _: raise ValueError(f"Unsupported message type: {message.__class__.__name__}") + @classmethod + def from_custom_data(cls, data: dict[str, Any]) -> "ChatMessage": + return cls( + type="custom", + content="", + custom_data=data, + ) + def to_langchain(self) -> BaseMessage: """Convert the ChatMessage to a LangChain message.""" if self.original: @@ -124,6 +150,11 @@ def to_langchain(self) -> BaseMessage: match self.type: case "human": return HumanMessage(content=self.content) + case "custom": + return LangchainChatMessage( + content=[self.custom_data], + role="custom", + ) case _: raise NotImplementedError(f"Unsupported message type: {self.type}")