Skip to content

Commit

Permalink
feat: support config temperature in bot
Browse files Browse the repository at this point in the history
  • Loading branch information
RaoHai committed Nov 21, 2024
1 parent a8ad640 commit 12cadb2
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 84 deletions.
7 changes: 6 additions & 1 deletion server/agent/bot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ class Bot:
def __init__(self, bot: BotModel, llm_token: LLMTokenLike):
self._bot = bot
self._llm_token = llm_token
self._llm = LLM(llm_token=llm_token)
self._llm = LLM(
llm_token=llm_token,
temperature=bot.temperature,
n=bot.n,
top_p=bot.top_p,
)

@property
def id(self):
Expand Down
80 changes: 51 additions & 29 deletions server/agent/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,55 +3,77 @@
from typing import Dict, Optional, Type
from agent.llm.base import BaseLLMClient

class LLMTokenLike():

class LLMTokenLike:
token: str
llm: str

llm_client_registry: Dict[str, Type['BaseLLMClient']] = {}

llm_client_registry: Dict[str, Type["BaseLLMClient"]] = {}


def register_llm_client(name: str):
"""Decorator to register a new LLM client class."""

def decorator(cls):
if name in llm_client_registry:
raise ValueError(f"Client '{name}' is already registered.")
llm_client_registry[name] = cls
return cls

return decorator


def get_registered_llm_client():
return llm_client_registry
return llm_client_registry

def import_clients(directory: str = 'clients'):

def import_clients(directory: str = "clients"):
"""Dynamically import all Python modules in the given directory."""
# 获取当前文件的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
clients_dir = os.path.join(current_dir, directory)

for filename in os.listdir(clients_dir):
if filename.endswith('.py') and not filename.startswith('__'):
if filename.endswith(".py") and not filename.startswith("__"):
module_name = f"agent.llm.{directory}.{filename[:-3]}" # 去掉 .py 后缀
importlib.import_module(module_name)

class LLM():
llm_token: LLMTokenLike
client: Optional[BaseLLMClient]

def __init__(self, llm_token: LLMTokenLike):
self._llm_token = llm_token
self._client = self.get_llm_client(llm_token.llm, api_key=llm_token.token)

def get_llm_client(
self,
llm: str = 'openai',
api_key: Optional[str | None] = None,
temperature: Optional[int] = 0.2,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False
) -> BaseLLMClient:

"""Get an LLM client based on the specified name."""
if llm in llm_client_registry:
client_class = llm_client_registry[llm]
return client_class(temperature=temperature, api_key=api_key, streaming=streaming, max_tokens=max_tokens)

return None

class LLM:
llm_token: LLMTokenLike
client: Optional[BaseLLMClient]

def __init__(
self,
llm_token: LLMTokenLike,
temperature: Optional[float] = 0.2,
n: Optional[int] = 1,
top_p: Optional[float] = None
):
self._llm_token = llm_token
self._client = self.get_llm_client(llm_token.llm, api_key=llm_token.token, temperature=temperature, n=n, top_p=top_p)

def get_llm_client(
self,
llm: str = "openai",
api_key: Optional[str | None] = None,
temperature: Optional[float] = 0.2,
n: Optional[int] = 1,
top_p: Optional[float] = None,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False,
) -> BaseLLMClient:
"""Get an LLM client based on the specified name."""
if llm in llm_client_registry:
client_class = llm_client_registry[llm]
return client_class(
temperature=temperature,
n=n,
top_p=top_p,
api_key=api_key,
streaming=streaming,
max_tokens=max_tokens,
)

return None
45 changes: 23 additions & 22 deletions server/agent/llm/base.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@

from abc import abstractmethod
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseChatModel

from petercat_utils.data_class import MessageContent

class BaseLLMClient():
def __init__(self,
temperature: Optional[int] = 0.2,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False,
api_key: Optional[str] = ''
):
pass

@abstractmethod
def get_client() -> BaseChatModel:
pass

@abstractmethod
def get_tools(self, tool: List[Any]) -> list[Dict[str, Any]]:
pass

@abstractmethod
def parse_content(self, content: List[MessageContent]) -> List[MessageContent]:
pass


class BaseLLMClient:
def __init__(
self,
temperature: Optional[float] = 0.2,
n: Optional[int] = 1,
top_p: Optional[float] = None,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False,
api_key: Optional[str] = "",
):
pass

@abstractmethod
def get_client() -> BaseChatModel:
pass

@abstractmethod
def get_tools(self, tool: List[Any]) -> list[Dict[str, Any]]:
pass

@abstractmethod
def parse_content(self, content: List[MessageContent]) -> List[MessageContent]:
pass
71 changes: 40 additions & 31 deletions server/agent/llm/clients/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,47 @@

GEMINI_API_KEY = get_env_variable("GEMINI_API_KEY")


def parse_gemini_input(message: MessageContent):
match message.type:
case "image_url":
return ImageRawURLContentBlock(image_url=message.image_url.url, type="image_url")
case _:
return message
match message.type:
case "image_url":
return ImageRawURLContentBlock(
image_url=message.image_url.url, type="image_url"
)
case _:
return message


@register_llm_client("gemini")
class GeminiClient(BaseLLMClient):
_client: ChatOpenAI

def __init__(self,
temperature: Optional[int] = 0.2,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False,
api_key: Optional[str] = GEMINI_API_KEY,
):
self._client = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
temperature=temperature,
streaming=streaming,
max_tokens=max_tokens,
google_api_key=api_key,
)

def get_client(self):
return self._client

def get_tools(self, tools: List[Any]):
return [convert_to_genai_function_declarations(tool) for tool in tools]

def parse_content(self, content: List[MessageContent]):
result = [parse_gemini_input(message=message) for message in content]
print(f"parse_content, content={content}, result={result}")
return result
_client: ChatOpenAI

def __init__(
self,
temperature: Optional[float] = 0.2,
n: Optional[int] = 1,
top_p: Optional[float] = None,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False,
api_key: Optional[str] = GEMINI_API_KEY,
):
self._client = ChatGoogleGenerativeAI(
model="gemini-1.5-flash",
temperature=temperature,
top_p=top_p,
n=n,
streaming=streaming,
max_tokens=max_tokens,
google_api_key=api_key,
)

def get_client(self):
return self._client

def get_tools(self, tools: List[Any]):
return [convert_to_genai_function_declarations(tool) for tool in tools]

def parse_content(self, content: List[MessageContent]):
result = [parse_gemini_input(message=message) for message in content]
print(f"parse_content, content={content}, result={result}")
return result
6 changes: 5 additions & 1 deletion server/agent/llm/clients/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@ class OpenAIClient(BaseLLMClient):

def __init__(
self,
temperature: Optional[int] = 0.2,
temperature: Optional[float] = 0.2,
n: Optional[int] = 1,
top_p: Optional[float] = None,
max_tokens: Optional[int] = 1500,
streaming: Optional[bool] = False,
api_key: Optional[str] = OPEN_API_KEY,
):
self._client = ChatOpenAI(
model_name="gpt-4o",
temperature=temperature,
n=n,
top_p=top_p,
streaming=streaming,
max_tokens=max_tokens,
openai_api_key=api_key,
Expand Down
3 changes: 3 additions & 0 deletions server/core/models/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class BotModel(BaseModel):
token_id: Optional[str] = ""
created_at: datetime = datetime.now()
domain_whitelist: Optional[list[str]] = []
temperature: Optional[float] = 0.2
n: Optional[int] = 1
top_p: Optional[float]


class RepoBindBotConfigVO(BaseModel):
Expand Down

0 comments on commit 12cadb2

Please sign in to comment.