Skip to content

Commit

Permalink
Add custom ChatMessage type (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaC215 authored Oct 28, 2024
1 parent 1f2312e commit 9266d9e
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions src/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
message_to_dict,
messages_from_dict,
)
from langchain_core.messages import (
ChatMessage as LangchainChatMessage,
)
from pydantic import BaseModel, Field


Expand Down Expand Up @@ -56,9 +59,9 @@ class StreamInput(UserInput):
class ChatMessage(BaseModel):
"""Message in a chat."""

type: Literal["human", "ai", "tool"] = Field(
type: Literal["human", "ai", "tool", "custom"] = Field(
description="Role of the message.",
examples=["human", "ai", "tool"],
examples=["human", "ai", "tool", "custom"],
)
content: str = Field(
description="Content of the message.",
Expand All @@ -82,6 +85,10 @@ class ChatMessage(BaseModel):
description="Original LangChain message in serialized form.",
default={},
)
custom_data: dict[str, Any] = Field(
description="Custom message data.",
default={},
)

@classmethod
def from_langchain(cls, message: BaseMessage) -> "ChatMessage":
Expand Down Expand Up @@ -112,9 +119,28 @@ def from_langchain(cls, message: BaseMessage) -> "ChatMessage":
original=original,
)
return tool_message
case LangchainChatMessage():
if message.role == "custom":
custom_message = cls(
type="custom",
content="",
custom_data=message.content[0],
original=original,
)
return custom_message
else:
raise ValueError(f"Unsupported chat message role: {message.role}")
case _:
raise ValueError(f"Unsupported message type: {message.__class__.__name__}")

@classmethod
def from_custom_data(cls, data: dict[str, Any]) -> "ChatMessage":
return cls(
type="custom",
content="",
custom_data=data,
)

def to_langchain(self) -> BaseMessage:
"""Convert the ChatMessage to a LangChain message."""
if self.original:
Expand All @@ -124,6 +150,11 @@ def to_langchain(self) -> BaseMessage:
match self.type:
case "human":
return HumanMessage(content=self.content)
case "custom":
return LangchainChatMessage(
content=[self.custom_data],
role="custom",
)
case _:
raise NotImplementedError(f"Unsupported message type: {self.type}")

Expand Down

0 comments on commit 9266d9e

Please sign in to comment.