-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Unify AI Support #1446
base: main
Are you sure you want to change the base?
Adding Unify AI Support #1446
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
llm: | ||
api_type: "unify" | ||
model: "llama-3-8b-chat@together-ai" # or Get a list of models here: https://docs.unify.ai/python/utils#list-models | ||
base_url: "https://api.unify.ai/v0" | ||
api_key: "Enter your Unify API key here" # or Get your API key from https://console.unify.ai | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,6 +50,7 @@ | |
LLMType.MISTRAL, | ||
LLMType.YI, | ||
LLMType.OPENROUTER, | ||
LLMType.UNIFY, | ||
] | ||
) | ||
class OpenAILLM(BaseLLM): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from typing import Optional, Dict, List, Union | ||
from openai.types import Completion, CompletionUsage | ||
from openai.types.chat import ChatCompletion | ||
|
||
from metagpt.configs.llm_config import LLMConfig, LLMType | ||
from metagpt.const import USE_CONFIG_TIMEOUT | ||
from metagpt.logs import log_llm_stream, logger | ||
from metagpt.provider.base_llm import BaseLLM | ||
from metagpt.provider.llm_provider_registry import register_provider | ||
from metagpt.utils.cost_manager import CostManager | ||
from metagpt.utils.token_counter import count_message_tokens, OPENAI_TOKEN_COSTS | ||
from unify.clients import Unify, AsyncUnify | ||
|
||
@register_provider([LLMType.UNIFY]) | ||
class UnifyLLM(BaseLLM): | ||
def __init__(self, config: LLMConfig): | ||
self.config = config | ||
self._init_client() | ||
self.cost_manager = CostManager(token_costs=OPENAI_TOKEN_COSTS) # Using OpenAI costs as Unify is compatible | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about non-openai models? |
||
|
||
def _init_client(self): | ||
self.model = self.config.model | ||
self.client = Unify( | ||
api_key=self.config.api_key, | ||
endpoint=f"{self.config.model}@{self.config.provider}", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no provider field in LLMConfig There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggest to only add async client |
||
) | ||
self.async_client = AsyncUnify( | ||
api_key=self.config.api_key, | ||
endpoint=f"{self.config.model}@{self.config.provider}", | ||
) | ||
|
||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: | ||
return { | ||
"messages": messages, | ||
"max_tokens": self.config.max_token, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
"temperature": self.config.temperature, | ||
"stream": stream, | ||
} | ||
|
||
def get_choice_text(self, resp: Union[ChatCompletion, str]) -> str: | ||
if isinstance(resp, str): | ||
return resp | ||
return resp.choices[0].message.content if resp.choices else "" | ||
|
||
def _update_costs(self, usage: dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to add due to implemented under BaseLLM |
||
prompt_tokens = usage.get("prompt_tokens", 0) | ||
completion_tokens = usage.get("completion_tokens", 0) | ||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) | ||
|
||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion: | ||
try: | ||
response = await self.async_client.generate( | ||
messages=messages, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. _const_kwargs not used ? |
||
max_tokens=self.config.max_token, | ||
temperature=self.config.temperature, | ||
stream=False, | ||
) | ||
# Construct a ChatCompletion object to match OpenAI's format | ||
chat_completion = ChatCompletion( | ||
id="unify_chat_completion", | ||
object="chat.completion", | ||
created=0, # Unify doesn't provide this, so we use 0 | ||
model=self.model, | ||
choices=[{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": response, | ||
}, | ||
"finish_reason": "stop", | ||
}], | ||
usage=CompletionUsage( | ||
prompt_tokens=count_message_tokens(messages, self.model), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. currently main branch has no |
||
completion_tokens=count_message_tokens([{"role": "assistant", "content": response}], self.model), | ||
total_tokens=0, # Will be calculated below | ||
), | ||
) | ||
chat_completion.usage.total_tokens = chat_completion.usage.prompt_tokens + chat_completion.usage.completion_tokens | ||
self._update_costs(chat_completion.usage.model_dump()) | ||
return chat_completion | ||
except Exception as e: | ||
logger.error(f"Error in Unify chat completion: {str(e)}") | ||
raise | ||
|
||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: | ||
try: | ||
stream = self.client.generate( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should use async_client |
||
messages=messages, | ||
max_tokens=self.config.max_token, | ||
temperature=self.config.temperature, | ||
stream=True, | ||
) | ||
collected_content = [] | ||
for chunk in stream: | ||
log_llm_stream(chunk) | ||
collected_content.append(chunk) | ||
|
||
full_content = "".join(collected_content) | ||
usage = { | ||
"prompt_tokens": count_message_tokens(messages, self.model), | ||
"completion_tokens": count_message_tokens([{"role": "assistant", "content": full_content}], self.model), | ||
} | ||
self._update_costs(usage) | ||
return full_content | ||
except Exception as e: | ||
logger.error(f"Error in Unify chat completion stream: {str(e)}") | ||
raise | ||
|
||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion: | ||
return await self._achat_completion(messages, timeout=timeout) | ||
|
||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str: | ||
if stream: | ||
return await self._achat_completion_stream(messages, timeout=timeout) | ||
response = await self._achat_completion(messages, timeout=timeout) | ||
return self.get_choice_text(response) | ||
|
||
def get_model_name(self): | ||
return self.model | ||
|
||
def get_usage(self) -> Optional[Dict[str, int]]: | ||
return self.cost_manager.get_latest_usage() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keep some value with
YOUR_API_KEY