Skip to content

Commit

Permalink
openai[minor]: Update OpenAI with Azure Specific Code (#5323) (#5350)
Browse files Browse the repository at this point in the history
* Update OpenAI with Azure Specific Code

* Export Embeddings

* Add deployment parameters

* Check to fix build issue

* Update libs/langchain-openai/src/azure/embeddings.ts



* Update libs/langchain-openai/src/azure/embeddings.ts



* nit changes

* Update Test Descriptions

* Format

* Fix types

---------

Co-authored-by: Sarangan Rajamanickam <[email protected]>
Co-authored-by: Brace Sproul <[email protected]>
  • Loading branch information
3 people authored May 14, 2024
1 parent 5da2466 commit 58c2655
Show file tree
Hide file tree
Showing 15 changed files with 1,552 additions and 45 deletions.
2 changes: 1 addition & 1 deletion langchain/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@
"node-llama-cpp": "2.7.3",
"notion-to-md": "^3.1.0",
"officeparser": "^4.0.4",
"openai": "^4.32.1",
"openai": "^4.41.1",
"pdf-parse": "1.1.1",
"peggy": "^3.0.2",
"playwright": "^1.32.1",
Expand Down
9 changes: 6 additions & 3 deletions langchain/src/experimental/openai_assistant/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ export class OpenAIAssistantRunnable<
tools: formattedTools,
model,
file_ids: fileIds,
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);

return new this({
client: oaiClient,
Expand Down Expand Up @@ -130,7 +131,8 @@ export class OpenAIAssistantRunnable<
role: "user",
file_ids: input.file_ids,
metadata: input.messagesMetadata,
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
run = await this._createRun(input);
} else {
// Submitting tool outputs to an existing run, outside the AgentExecutor
Expand Down Expand Up @@ -189,7 +191,8 @@ export class OpenAIAssistantRunnable<
instructions,
model,
file_ids: fileIds,
});
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
}

private async _parseStepsInput(input: RunInput): Promise<RunInput> {
Expand Down
3 changes: 2 additions & 1 deletion libs/langchain-openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@
"dependencies": {
"@langchain/core": "~0.1.56",
"js-tiktoken": "^1.0.7",
"openai": "^4.32.1",
"openai": "^4.41.1",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
"devDependencies": {
"@azure/identity": "^4.2.0",
"@jest/globals": "^29.5.0",
"@langchain/scripts": "~0.0",
"@swc/core": "^1.3.90",
Expand Down
62 changes: 54 additions & 8 deletions libs/langchain-openai/src/azure/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { type ClientOptions } from "openai";
import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai";
import { type BaseChatModelParams } from "@langchain/core/language_models/chat_models";
import { ChatOpenAI } from "../chat_models.js";
import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js";
import {
AzureOpenAIInput,
LegacyOpenAIInput,
OpenAIChatInput,
OpenAICoreRequestOptions,
} from "../types.js";

export class AzureChatOpenAI extends ChatOpenAI {
Expand All @@ -31,15 +33,8 @@ export class AzureChatOpenAI extends ChatOpenAI {
configuration?: ClientOptions & LegacyOpenAIInput;
}
) {
// assume the base URL does not contain "openai" nor "deployments" prefix
let basePath = fields?.openAIBasePath ?? "";
if (!basePath.endsWith("/")) basePath += "/";
if (!basePath.endsWith("openai/deployments"))
basePath += "openai/deployments";

const newFields = fields ? { ...fields } : fields;
if (newFields) {
newFields.azureOpenAIBasePath = basePath;
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName;
newFields.azureOpenAIApiKey = newFields.openAIApiKey;
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion;
Expand All @@ -48,6 +43,57 @@ export class AzureChatOpenAI extends ChatOpenAI {
super(newFields);
}

protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
azureOpenAIApiKey: this.azureOpenAIApiKey,
azureOpenAIBasePath: this.azureOpenAIBasePath,
baseURL: this.clientConfig.baseURL,
};

const endpoint = getEndpoint(openAIEndpointConfig);

const params = {
...this.clientConfig,
baseURL: endpoint,
timeout: this.timeout,
maxRetries: 0,
};

if (!this.azureADTokenProvider) {
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey;
}

if (!params.baseURL) {
delete params.baseURL;
}

this.client = new AzureOpenAIClient({
apiVersion: this.azureOpenAIApiVersion,
azureADTokenProvider: this.azureADTokenProvider,
deployment: this.azureOpenAIApiDeploymentName,
...params,
});
}
const requestOptions = {
...this.clientConfig,
...options,
} as OpenAICoreRequestOptions;
if (this.azureOpenAIApiKey) {
requestOptions.headers = {
"api-key": this.azureOpenAIApiKey,
...requestOptions.headers,
};
requestOptions.query = {
"api-version": this.azureOpenAIApiVersion,
...requestOptions.query,
};
}
return requestOptions;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
toJSON(): any {
const json = super.toJSON() as unknown;
Expand Down
98 changes: 98 additions & 0 deletions libs/langchain-openai/src/azure/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import {
type ClientOptions,
AzureOpenAI as AzureOpenAIClient,
OpenAI as OpenAIClient,
} from "openai";
import { OpenAIEmbeddings, OpenAIEmbeddingsParams } from "../embeddings.js";
import {
AzureOpenAIInput,
OpenAICoreRequestOptions,
LegacyOpenAIInput,
} from "../types.js";
import { getEndpoint, OpenAIEndpointConfig } from "../utils/azure.js";
import { wrapOpenAIClientError } from "../utils/openai.js";

export class AzureOpenAIEmbeddings extends OpenAIEmbeddings {
constructor(
fields?: Partial<OpenAIEmbeddingsParams> &
Partial<AzureOpenAIInput> & {
verbose?: boolean;
/** The OpenAI API key to use. */
apiKey?: string;
configuration?: ClientOptions;
deploymentName?: string;
openAIApiVersion?: string;
},
configuration?: ClientOptions & LegacyOpenAIInput
) {
const newFields = { ...fields };
if (Object.entries(newFields).length) {
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName;
newFields.azureOpenAIApiKey = newFields.apiKey;
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion;
}

super(newFields, configuration);
}

protected async embeddingWithRetry(
request: OpenAIClient.EmbeddingCreateParams
) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
azureOpenAIApiKey: this.azureOpenAIApiKey,
azureOpenAIBasePath: this.azureOpenAIBasePath,
baseURL: this.clientConfig.baseURL,
};

const endpoint = getEndpoint(openAIEndpointConfig);

const params = {
...this.clientConfig,
baseURL: endpoint,
timeout: this.timeout,
maxRetries: 0,
};

if (!this.azureADTokenProvider) {
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey;
}

if (!params.baseURL) {
delete params.baseURL;
}

this.client = new AzureOpenAIClient({
apiVersion: this.azureOpenAIApiVersion,
azureADTokenProvider: this.azureADTokenProvider,
deployment: this.azureOpenAIApiDeploymentName,
...params,
});
}
const requestOptions: OpenAICoreRequestOptions = {};
if (this.azureOpenAIApiKey) {
requestOptions.headers = {
"api-key": this.azureOpenAIApiKey,
...requestOptions.headers,
};
requestOptions.query = {
"api-version": this.azureOpenAIApiVersion,
...requestOptions.query,
};
}
return this.caller.call(async () => {
try {
const res = await this.client.embeddings.create(
request,
requestOptions
);
return res;
} catch (e) {
const error = wrapOpenAIClientError(e);
throw error;
}
});
}
}
62 changes: 54 additions & 8 deletions libs/langchain-openai/src/azure/llms.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { type ClientOptions } from "openai";
import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai";
import { type BaseLLMParams } from "@langchain/core/language_models/llms";
import { OpenAI } from "../llms.js";
import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js";
import type {
OpenAIInput,
AzureOpenAIInput,
OpenAICoreRequestOptions,
LegacyOpenAIInput,
} from "../types.js";

Expand All @@ -27,15 +29,8 @@ export class AzureOpenAI extends OpenAI {
configuration?: ClientOptions & LegacyOpenAIInput;
}
) {
// assume the base URL does not contain "openai" nor "deployments" prefix
let basePath = fields?.openAIBasePath ?? "";
if (!basePath.endsWith("/")) basePath += "/";
if (!basePath.endsWith("openai/deployments"))
basePath += "openai/deployments";

const newFields = fields ? { ...fields } : fields;
if (newFields) {
newFields.azureOpenAIBasePath = basePath;
newFields.azureOpenAIApiDeploymentName = newFields.deploymentName;
newFields.azureOpenAIApiKey = newFields.openAIApiKey;
newFields.azureOpenAIApiVersion = newFields.openAIApiVersion;
Expand All @@ -44,6 +39,57 @@ export class AzureOpenAI extends OpenAI {
super(newFields);
}

protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
azureOpenAIApiKey: this.azureOpenAIApiKey,
azureOpenAIBasePath: this.azureOpenAIBasePath,
baseURL: this.clientConfig.baseURL,
};

const endpoint = getEndpoint(openAIEndpointConfig);

const params = {
...this.clientConfig,
baseURL: endpoint,
timeout: this.timeout,
maxRetries: 0,
};

if (!this.azureADTokenProvider) {
params.apiKey = openAIEndpointConfig.azureOpenAIApiKey;
}

if (!params.baseURL) {
delete params.baseURL;
}

this.client = new AzureOpenAIClient({
apiVersion: this.azureOpenAIApiVersion,
azureADTokenProvider: this.azureADTokenProvider,
...params,
});
}

const requestOptions = {
...this.clientConfig,
...options,
} as OpenAICoreRequestOptions;
if (this.azureOpenAIApiKey) {
requestOptions.headers = {
"api-key": this.azureOpenAIApiKey,
...requestOptions.headers,
};
requestOptions.query = {
"api-version": this.azureOpenAIApiVersion,
...requestOptions.query,
};
}
return requestOptions;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
toJSON(): any {
const json = super.toJSON() as unknown;
Expand Down
18 changes: 12 additions & 6 deletions libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ export class ChatOpenAI<

azureOpenAIApiKey?: string;

azureADTokenProvider?: () => Promise<string>;

azureOpenAIApiInstanceName?: string;

azureOpenAIApiDeploymentName?: string;
Expand All @@ -386,9 +388,9 @@ export class ChatOpenAI<

organization?: string;

private client: OpenAIClient;
protected client: OpenAIClient;

private clientConfig: ClientOptions;
protected clientConfig: ClientOptions;

constructor(
fields?: Partial<OpenAIChatInput> &
Expand All @@ -411,8 +413,12 @@ export class ChatOpenAI<
fields?.azureOpenAIApiKey ??
getEnvironmentVariable("AZURE_OPENAI_API_KEY");

if (!this.azureOpenAIApiKey && !this.apiKey) {
throw new Error("OpenAI or Azure OpenAI API key not found");
this.azureADTokenProvider = fields?.azureADTokenProvider ?? undefined;

if (!this.azureOpenAIApiKey && !this.apiKey && !this.azureADTokenProvider) {
throw new Error(
"OpenAI or Azure OpenAI API key or Token Provider not found"
);
}

this.azureOpenAIApiInstanceName =
Expand Down Expand Up @@ -455,7 +461,7 @@ export class ChatOpenAI<

this.streaming = fields?.streaming ?? false;

if (this.azureOpenAIApiKey) {
if (this.azureOpenAIApiKey || this.azureADTokenProvider) {
if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) {
throw new Error("Azure OpenAI API instance name not found");
}
Expand Down Expand Up @@ -898,7 +904,7 @@ export class ChatOpenAI<
});
}

private _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) {
if (!this.client) {
const openAIEndpointConfig: OpenAIEndpointConfig = {
azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
Expand Down
Loading

0 comments on commit 58c2655

Please sign in to comment.