-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
openai[minor]: Update OpenAI with Azure Specific Code #5323
Changes from 4 commits
0fbf836
720bfa8
dfd8620
824894c
7644c6c
80c3a76
53857f7
23f9f70
59d3bf4
42d608a
9c40285
58da40e
31d06ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
import { type ClientOptions } from "openai"; | ||
import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai"; | ||
import { type BaseChatModelParams } from "@langchain/core/language_models/chat_models"; | ||
import { ChatOpenAI } from "../chat_models.js"; | ||
import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js"; | ||
import { | ||
AzureOpenAIInput, | ||
LegacyOpenAIInput, | ||
OpenAIChatInput, | ||
OpenAICoreRequestOptions, | ||
} from "../types.js"; | ||
|
||
export class AzureChatOpenAI extends ChatOpenAI { | ||
|
@@ -31,15 +33,8 @@ export class AzureChatOpenAI extends ChatOpenAI { | |
configuration?: ClientOptions & LegacyOpenAIInput; | ||
} | ||
) { | ||
// assume the base URL does not contain "openai" nor "deployments" prefix | ||
let basePath = fields?.openAIBasePath ?? ""; | ||
if (!basePath.endsWith("/")) basePath += "/"; | ||
if (!basePath.endsWith("openai/deployments")) | ||
basePath += "openai/deployments"; | ||
|
||
const newFields = fields ? { ...fields } : fields; | ||
if (newFields) { | ||
newFields.azureOpenAIBasePath = basePath; | ||
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName; | ||
newFields.azureOpenAIApiKey = newFields.openAIApiKey; | ||
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion; | ||
|
@@ -48,6 +43,57 @@ export class AzureChatOpenAI extends ChatOpenAI { | |
super(newFields); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey there! I've reviewed the code changes and it looks like a new HTTP request setup using the AzureOpenAIClient has been added in the |
||
|
||
protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) { | ||
if (!this.client) { | ||
const openAIEndpointConfig: OpenAIEndpointConfig = { | ||
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, | ||
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, | ||
azureOpenAIApiKey: this.azureOpenAIApiKey, | ||
azureOpenAIBasePath: this.azureOpenAIBasePath, | ||
baseURL: this.clientConfig.baseURL, | ||
}; | ||
|
||
const endpoint = getEndpoint(openAIEndpointConfig); | ||
|
||
const params = { | ||
...this.clientConfig, | ||
baseURL: endpoint, | ||
timeout: this.timeout, | ||
maxRetries: 0, | ||
}; | ||
|
||
if (!this.azureADTokenProvider) { | ||
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey; | ||
} | ||
|
||
if (!params.baseURL) { | ||
delete params.baseURL; | ||
} | ||
|
||
this.client = new AzureOpenAIClient({ | ||
apiVersion: this.azureOpenAIApiVersion, | ||
azureADTokenProvider: this.azureADTokenProvider, | ||
deployment: this.azureOpenAIApiDeploymentName, | ||
...params, | ||
}); | ||
} | ||
const requestOptions = { | ||
...this.clientConfig, | ||
...options, | ||
} as OpenAICoreRequestOptions; | ||
if (this.azureOpenAIApiKey) { | ||
requestOptions.headers = { | ||
"api-key": this.azureOpenAIApiKey, | ||
...requestOptions.headers, | ||
}; | ||
requestOptions.query = { | ||
"api-version": this.azureOpenAIApiVersion, | ||
...requestOptions.query, | ||
}; | ||
} | ||
return requestOptions; | ||
} | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
toJSON(): any { | ||
const json = super.toJSON() as unknown; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import { | ||
type ClientOptions, | ||
AzureOpenAI as AzureOpenAIClient, | ||
OpenAI as OpenAIClient, | ||
} from "openai"; | ||
import { OpenAIEmbeddings, OpenAIEmbeddingsParams } from "../embeddings.js"; | ||
import { | ||
AzureOpenAIInput, | ||
OpenAICoreRequestOptions, | ||
LegacyOpenAIInput, | ||
} from "../types.js"; | ||
import { getEndpoint, OpenAIEndpointConfig } from "../utils/azure.js"; | ||
import { wrapOpenAIClientError } from "../utils/openai.js"; | ||
|
||
export class AzureOpenAIEmbeddings extends OpenAIEmbeddings { | ||
constructor( | ||
fields?: Partial<OpenAIEmbeddingsParams> & | ||
Partial<AzureOpenAIInput> & { | ||
verbose?: boolean; | ||
/** | ||
* The OpenAI API key to use. | ||
* Alias for `apiKey`. | ||
*/ | ||
openAIApiKey?: string; | ||
sarangan12 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/** The OpenAI API key to use. */ | ||
apiKey?: string; | ||
configuration?: ClientOptions; | ||
deploymentName?: string; | ||
openAIApiVersion?: string; | ||
}, | ||
configuration?: ClientOptions & LegacyOpenAIInput | ||
) { | ||
const newFields = fields ? { ...fields } : fields; | ||
if (newFields) { | ||
sarangan12 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName; | ||
newFields.azureOpenAIApiKey = newFields.openAIApiKey; | ||
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion; | ||
} | ||
|
||
super(newFields, configuration); | ||
} | ||
|
||
protected async embeddingWithRetry( | ||
request: OpenAIClient.EmbeddingCreateParams | ||
) { | ||
if (!this.client) { | ||
const openAIEndpointConfig: OpenAIEndpointConfig = { | ||
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, | ||
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, | ||
azureOpenAIApiKey: this.azureOpenAIApiKey, | ||
azureOpenAIBasePath: this.azureOpenAIBasePath, | ||
baseURL: this.clientConfig.baseURL, | ||
}; | ||
|
||
const endpoint = getEndpoint(openAIEndpointConfig); | ||
|
||
const params = { | ||
...this.clientConfig, | ||
baseURL: endpoint, | ||
timeout: this.timeout, | ||
maxRetries: 0, | ||
}; | ||
|
||
if (!this.azureADTokenProvider) { | ||
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey; | ||
} | ||
|
||
if (!params.baseURL) { | ||
delete params.baseURL; | ||
} | ||
|
||
this.client = new AzureOpenAIClient({ | ||
apiVersion: this.azureOpenAIApiVersion, | ||
azureADTokenProvider: this.azureADTokenProvider, | ||
deployment: this.azureOpenAIApiDeploymentName, | ||
...params, | ||
}); | ||
} | ||
const requestOptions: OpenAICoreRequestOptions = {}; | ||
if (this.azureOpenAIApiKey) { | ||
requestOptions.headers = { | ||
"api-key": this.azureOpenAIApiKey, | ||
...requestOptions.headers, | ||
}; | ||
requestOptions.query = { | ||
"api-version": this.azureOpenAIApiVersion, | ||
...requestOptions.query, | ||
}; | ||
} | ||
return this.caller.call(async () => { | ||
try { | ||
const res = await this.client.embeddings.create( | ||
request, | ||
requestOptions | ||
); | ||
return res; | ||
} catch (e) { | ||
const error = wrapOpenAIClientError(e); | ||
throw error; | ||
} | ||
}); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
import { type ClientOptions } from "openai"; | ||
import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai"; | ||
import { type BaseLLMParams } from "@langchain/core/language_models/llms"; | ||
import { OpenAI } from "../llms.js"; | ||
import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js"; | ||
import type { | ||
OpenAIInput, | ||
AzureOpenAIInput, | ||
OpenAICoreRequestOptions, | ||
LegacyOpenAIInput, | ||
} from "../types.js"; | ||
|
||
|
@@ -27,15 +29,8 @@ export class AzureOpenAI extends OpenAI { | |
configuration?: ClientOptions & LegacyOpenAIInput; | ||
} | ||
) { | ||
// assume the base URL does not contain "openai" nor "deployments" prefix | ||
let basePath = fields?.openAIBasePath ?? ""; | ||
if (!basePath.endsWith("/")) basePath += "/"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, this is breaking, isn't it? I don't think this is heavily used yet and there are no examples in our docs, so I'm ok with a change now. |
||
if (!basePath.endsWith("openai/deployments")) | ||
basePath += "openai/deployments"; | ||
|
||
const newFields = fields ? { ...fields } : fields; | ||
if (newFields) { | ||
newFields.azureOpenAIBasePath = basePath; | ||
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName; | ||
newFields.azureOpenAIApiKey = newFields.openAIApiKey; | ||
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion; | ||
|
@@ -44,6 +39,57 @@ export class AzureOpenAI extends OpenAI { | |
super(newFields); | ||
} | ||
|
||
protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) { | ||
if (!this.client) { | ||
const openAIEndpointConfig: OpenAIEndpointConfig = { | ||
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, | ||
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, | ||
azureOpenAIApiKey: this.azureOpenAIApiKey, | ||
azureOpenAIBasePath: this.azureOpenAIBasePath, | ||
baseURL: this.clientConfig.baseURL, | ||
}; | ||
|
||
const endpoint = getEndpoint(openAIEndpointConfig); | ||
|
||
const params = { | ||
...this.clientConfig, | ||
baseURL: endpoint, | ||
timeout: this.timeout, | ||
maxRetries: 0, | ||
}; | ||
|
||
if (!this.azureADTokenProvider) { | ||
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey; | ||
} | ||
|
||
if (!params.baseURL) { | ||
delete params.baseURL; | ||
} | ||
|
||
this.client = new AzureOpenAIClient({ | ||
apiVersion: this.azureOpenAIApiVersion, | ||
azureADTokenProvider: this.azureADTokenProvider, | ||
...params, | ||
}); | ||
} | ||
|
||
const requestOptions = { | ||
...this.clientConfig, | ||
...options, | ||
} as OpenAICoreRequestOptions; | ||
if (this.azureOpenAIApiKey) { | ||
requestOptions.headers = { | ||
"api-key": this.azureOpenAIApiKey, | ||
...requestOptions.headers, | ||
}; | ||
requestOptions.query = { | ||
"api-version": this.azureOpenAIApiVersion, | ||
...requestOptions.query, | ||
}; | ||
} | ||
return requestOptions; | ||
} | ||
|
||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
toJSON(): any { | ||
const json = super.toJSON() as unknown; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey there! 👋 I noticed that the package.json file has an update to the "openai" dependency version and a new dev dependency "@azure/identity" added. This is just a heads up for the maintainers to review the changes in dependencies. Keep up the great work! 🚀