Skip to content

Commit

Permalink
allow model name in AOAI
Browse files Browse the repository at this point in the history
  • Loading branch information
darthtrevino committed Apr 2, 2024
1 parent f0df736 commit b48b885
Showing 1 changed file with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ class Section(str, Enum):
LLM_KEY_REQUIRED = "API Key is required for Completion API. Please set either the OPENAI_API_KEY, GRAPHRAG_BASE_API_KEY or GRAPHRAG_LLM_API_KEY environment variable."
EMBEDDING_KEY_REQUIRED = "API Key is required for Embedding API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable."
AZURE_LLM_DEPLOYMENT_NAME_REQUIRED = (
"GRAPHRAG_LLM_DEPLOYMENT_NAME is required for Azure OpenAI."
"GRAPHRAG_LLM_MODEL or GRAPHRAG_LLM_DEPLOYMENT_NAME is required for Azure OpenAI."
)
AZURE_LLM_API_BASE_REQUIRED = (
"GRAPHRAG_BASE_API_BASE or GRAPHRAG_LLM_API_BASE is required for Azure OpenAI."
)
AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED = (
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME is required for Azure OpenAI."
"GRAPHRAG_EMBEDDING_MODEL or GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME is required for Azure OpenAI."
)
AZURE_EMBEDDING_API_BASE_REQUIRED = "GRAPHRAG_BASE_API_BASE or GRAPHRAG_EMBEDDING_API_BASE is required for Azure OpenAI."

Expand Down Expand Up @@ -174,17 +174,19 @@ def section(key: Section):
llm_type = _str(Fragment.type)
llm_type = LLMType(llm_type)
deployment_name = str(Fragment.deployment_name)
model = _str(Fragment.model_supports_json)

is_azure = _is_azure(llm_type)
api_base = _str(Fragment.api_base, _api_base)
if is_azure and deployment_name is None:
if is_azure and deployment_name is None and model is None:
raise ValueError(AZURE_LLM_DEPLOYMENT_NAME_REQUIRED)
if is_azure and api_base is None:
raise ValueError(AZURE_LLM_API_BASE_REQUIRED)

llm_parameters = LLMParametersModel(
api_key=api_key,
type=llm_type,
model=_str(Fragment.model),
model=model,
max_tokens=_int(Fragment.max_tokens),
model_supports_json=_bool(Fragment.model_supports_json),
request_timeout=_float(Fragment.request_timeout),
Expand Down Expand Up @@ -217,12 +219,13 @@ def section(key: Section):
async_mode = _str(Fragment.async_mode)
async_mode_enum = AsyncType(async_mode) if async_mode else None
deployment_name = _str(Fragment.deployment_name)
model = _str(Fragment.model)
llm_type = _str(Fragment.type)
llm_type = LLMType(llm_type)
is_azure = _is_azure(llm_type)
api_base = _str(Fragment.api_base, _api_base)

if is_azure and deployment_name is None:
if is_azure and deployment_name is None and model is None:
raise ValueError(AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED)
if is_azure and api_base is None:
raise ValueError(AZURE_EMBEDDING_API_BASE_REQUIRED)
Expand All @@ -240,7 +243,7 @@ def section(key: Section):
llm=LLMParametersModel(
api_key=_str(Fragment.api_key, _api_key),
type=llm_type,
model=_str(Fragment.model),
model=model,
request_timeout=_float(Fragment.request_timeout),
api_base=api_base,
api_version=_str(Fragment.api_version, _api_version),
Expand Down

0 comments on commit b48b885

Please sign in to comment.