From 96c2525bb0001b0ad43d47bd4484cc4ef03f141d Mon Sep 17 00:00:00 2001 From: Sarangan Rajamanickam Date: Mon, 13 May 2024 17:00:53 -0700 Subject: [PATCH] openai[minor]: Update OpenAI with Azure Specific Code (#5323) * Update OpenAI with Azure Specific Code * Export Embeddings * Add deployment parameters * Check to fix build issue * Update libs/langchain-openai/src/azure/embeddings.ts Co-authored-by: Brace Sproul * Update libs/langchain-openai/src/azure/embeddings.ts Co-authored-by: Brace Sproul * nit changes * Update Test Descriptions * Format * Fix types --------- Co-authored-by: Brace Sproul Co-authored-by: jacoblee93 --- langchain/package.json | 2 +- .../experimental/openai_assistant/index.ts | 9 +- libs/langchain-openai/package.json | 3 +- .../langchain-openai/src/azure/chat_models.ts | 62 +- libs/langchain-openai/src/azure/embeddings.ts | 98 +++ libs/langchain-openai/src/azure/llms.ts | 62 +- libs/langchain-openai/src/chat_models.ts | 18 +- libs/langchain-openai/src/embeddings.ts | 19 +- libs/langchain-openai/src/index.ts | 1 + libs/langchain-openai/src/llms.ts | 18 +- .../src/tests/azure/chat_models.int.test.ts | 829 ++++++++++++++++++ .../src/tests/azure/embeddings.int.test.ts | 75 ++ .../src/tests/azure/llms.int.test.ts | 333 +++++++ libs/langchain-openai/src/types.ts | 6 + yarn.lock | 62 +- 15 files changed, 1552 insertions(+), 45 deletions(-) create mode 100644 libs/langchain-openai/src/azure/embeddings.ts create mode 100644 libs/langchain-openai/src/tests/azure/chat_models.int.test.ts create mode 100644 libs/langchain-openai/src/tests/azure/embeddings.int.test.ts create mode 100644 libs/langchain-openai/src/tests/azure/llms.int.test.ts diff --git a/langchain/package.json b/langchain/package.json index 1998fe70a199..704c24c477de 100644 --- a/langchain/package.json +++ b/langchain/package.json @@ -1291,7 +1291,7 @@ "node-llama-cpp": "2.7.3", "notion-to-md": "^3.1.0", "officeparser": "^4.0.4", - "openai": "^4.32.1", + "openai": "^4.41.1", "pdf-parse": "1.1.1", "peggy": "^3.0.2", "playwright": "^1.32.1", diff --git a/langchain/src/experimental/openai_assistant/index.ts b/langchain/src/experimental/openai_assistant/index.ts index 68e051391eed..57bd7574d768 100644 --- a/langchain/src/experimental/openai_assistant/index.ts +++ b/langchain/src/experimental/openai_assistant/index.ts @@ -83,7 +83,8 @@ export class OpenAIAssistantRunnable< tools: formattedTools, model, file_ids: fileIds, - }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); return new this({ client: oaiClient, @@ -130,7 +131,8 @@ export class OpenAIAssistantRunnable< role: "user", file_ids: input.file_ids, metadata: input.messagesMetadata, - }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); run = await this._createRun(input); } else { // Submitting tool outputs to an existing run, outside the AgentExecutor @@ -189,7 +191,8 @@ export class OpenAIAssistantRunnable< instructions, model, file_ids: fileIds, - }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any); } private async _parseStepsInput(input: RunInput): Promise { diff --git a/libs/langchain-openai/package.json b/libs/langchain-openai/package.json index 6740a19c2256..503ba7e61433 100644 --- a/libs/langchain-openai/package.json +++ b/libs/langchain-openai/package.json @@ -41,11 +41,12 @@ "dependencies": { "@langchain/core": "~0.1.56", "js-tiktoken": "^1.0.7", - "openai": "^4.32.1", + "openai": "^4.41.1", "zod": "^3.22.4", "zod-to-json-schema": "^3.22.3" }, "devDependencies": { + 
"@azure/identity": "^4.2.0", "@jest/globals": "^29.5.0", "@langchain/scripts": "~0.0", "@swc/core": "^1.3.90", diff --git a/libs/langchain-openai/src/azure/chat_models.ts b/libs/langchain-openai/src/azure/chat_models.ts index ef297d245889..fc1a36f46581 100644 --- a/libs/langchain-openai/src/azure/chat_models.ts +++ b/libs/langchain-openai/src/azure/chat_models.ts @@ -1,10 +1,12 @@ -import { type ClientOptions } from "openai"; +import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai"; import { type BaseChatModelParams } from "@langchain/core/language_models/chat_models"; import { ChatOpenAI } from "../chat_models.js"; +import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js"; import { AzureOpenAIInput, LegacyOpenAIInput, OpenAIChatInput, + OpenAICoreRequestOptions, } from "../types.js"; export class AzureChatOpenAI extends ChatOpenAI { @@ -31,15 +33,8 @@ export class AzureChatOpenAI extends ChatOpenAI { configuration?: ClientOptions & LegacyOpenAIInput; } ) { - // assume the base URL does not contain "openai" nor "deployments" prefix - let basePath = fields?.openAIBasePath ?? ""; - if (!basePath.endsWith("/")) basePath += "/"; - if (!basePath.endsWith("openai/deployments")) - basePath += "openai/deployments"; - const newFields = fields ? { ...fields } : fields; if (newFields) { - newFields.azureOpenAIBasePath = basePath; newFields.azureOpenAIApiDeploymentName = newFields.deploymentName; newFields.azureOpenAIApiKey = newFields.openAIApiKey; newFields.azureOpenAIApiVersion = newFields.openAIApiVersion; @@ -48,6 +43,57 @@ export class AzureChatOpenAI extends ChatOpenAI { super(newFields); } + protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) { + if (!this.client) { + const openAIEndpointConfig: OpenAIEndpointConfig = { + azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, + azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, + azureOpenAIApiKey: this.azureOpenAIApiKey, + azureOpenAIBasePath: this.azureOpenAIBasePath, + baseURL: this.clientConfig.baseURL, + }; + + const endpoint = getEndpoint(openAIEndpointConfig); + + const params = { + ...this.clientConfig, + baseURL: endpoint, + timeout: this.timeout, + maxRetries: 0, + }; + + if (!this.azureADTokenProvider) { + params.apiKey = openAIEndpointConfig.azureOpenAIApiKey; + } + + if (!params.baseURL) { + delete params.baseURL; + } + + this.client = new AzureOpenAIClient({ + apiVersion: this.azureOpenAIApiVersion, + azureADTokenProvider: this.azureADTokenProvider, + deployment: this.azureOpenAIApiDeploymentName, + ...params, + }); + } + const requestOptions = { + ...this.clientConfig, + ...options, + } as OpenAICoreRequestOptions; + if (this.azureOpenAIApiKey) { + requestOptions.headers = { + "api-key": this.azureOpenAIApiKey, + ...requestOptions.headers, + }; + requestOptions.query = { + "api-version": this.azureOpenAIApiVersion, + ...requestOptions.query, + }; + } + return requestOptions; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any toJSON(): any { const json = super.toJSON() as unknown; diff --git a/libs/langchain-openai/src/azure/embeddings.ts b/libs/langchain-openai/src/azure/embeddings.ts new file mode 100644 index 000000000000..d5023bfbada9 --- /dev/null +++ b/libs/langchain-openai/src/azure/embeddings.ts @@ -0,0 +1,98 @@ +import { + type ClientOptions, + AzureOpenAI as AzureOpenAIClient, + OpenAI as OpenAIClient, +} from "openai"; +import { OpenAIEmbeddings, OpenAIEmbeddingsParams } from "../embeddings.js"; +import { + AzureOpenAIInput, 
+ OpenAICoreRequestOptions, + LegacyOpenAIInput, +} from "../types.js"; +import { getEndpoint, OpenAIEndpointConfig } from "../utils/azure.js"; +import { wrapOpenAIClientError } from "../utils/openai.js"; + +export class AzureOpenAIEmbeddings extends OpenAIEmbeddings { + constructor( + fields?: Partial & + Partial & { + verbose?: boolean; + /** The OpenAI API key to use. */ + apiKey?: string; + configuration?: ClientOptions; + deploymentName?: string; + openAIApiVersion?: string; + }, + configuration?: ClientOptions & LegacyOpenAIInput + ) { + const newFields = { ...fields }; + if (Object.entries(newFields).length) { + newFields.azureOpenAIApiDeploymentName = newFields.deploymentName; + newFields.azureOpenAIApiKey = newFields.apiKey; + newFields.azureOpenAIApiVersion = newFields.openAIApiVersion; + } + + super(newFields, configuration); + } + + protected async embeddingWithRetry( + request: OpenAIClient.EmbeddingCreateParams + ) { + if (!this.client) { + const openAIEndpointConfig: OpenAIEndpointConfig = { + azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, + azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, + azureOpenAIApiKey: this.azureOpenAIApiKey, + azureOpenAIBasePath: this.azureOpenAIBasePath, + baseURL: this.clientConfig.baseURL, + }; + + const endpoint = getEndpoint(openAIEndpointConfig); + + const params = { + ...this.clientConfig, + baseURL: endpoint, + timeout: this.timeout, + maxRetries: 0, + }; + + if (!this.azureADTokenProvider) { + params.apiKey = openAIEndpointConfig.azureOpenAIApiKey; + } + + if (!params.baseURL) { + delete params.baseURL; + } + + this.client = new AzureOpenAIClient({ + apiVersion: this.azureOpenAIApiVersion, + azureADTokenProvider: this.azureADTokenProvider, + deployment: this.azureOpenAIApiDeploymentName, + ...params, + }); + } + const requestOptions: OpenAICoreRequestOptions = {}; + if (this.azureOpenAIApiKey) { + requestOptions.headers = { + "api-key": this.azureOpenAIApiKey, + ...requestOptions.headers, + }; + requestOptions.query = { + "api-version": this.azureOpenAIApiVersion, + ...requestOptions.query, + }; + } + return this.caller.call(async () => { + try { + const res = await this.client.embeddings.create( + request, + requestOptions + ); + return res; + } catch (e) { + const error = wrapOpenAIClientError(e); + throw error; + } + }); + } +} diff --git a/libs/langchain-openai/src/azure/llms.ts b/libs/langchain-openai/src/azure/llms.ts index be51eed01e1f..546893113a06 100644 --- a/libs/langchain-openai/src/azure/llms.ts +++ b/libs/langchain-openai/src/azure/llms.ts @@ -1,9 +1,11 @@ -import { type ClientOptions } from "openai"; +import { type ClientOptions, AzureOpenAI as AzureOpenAIClient } from "openai"; import { type BaseLLMParams } from "@langchain/core/language_models/llms"; import { OpenAI } from "../llms.js"; +import { OpenAIEndpointConfig, getEndpoint } from "../utils/azure.js"; import type { OpenAIInput, AzureOpenAIInput, + OpenAICoreRequestOptions, LegacyOpenAIInput, } from "../types.js"; @@ -27,15 +29,8 @@ export class AzureOpenAI extends OpenAI { configuration?: ClientOptions & LegacyOpenAIInput; } ) { - // assume the base URL does not contain "openai" nor "deployments" prefix - let basePath = fields?.openAIBasePath ?? ""; - if (!basePath.endsWith("/")) basePath += "/"; - if (!basePath.endsWith("openai/deployments")) - basePath += "openai/deployments"; - const newFields = fields ? 
{ ...fields } : fields; if (newFields) { - newFields.azureOpenAIBasePath = basePath; newFields.azureOpenAIApiDeploymentName = newFields.deploymentName; newFields.azureOpenAIApiKey = newFields.openAIApiKey; newFields.azureOpenAIApiVersion = newFields.openAIApiVersion; @@ -44,6 +39,57 @@ export class AzureOpenAI extends OpenAI { super(newFields); } + protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) { + if (!this.client) { + const openAIEndpointConfig: OpenAIEndpointConfig = { + azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, + azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName, + azureOpenAIApiKey: this.azureOpenAIApiKey, + azureOpenAIBasePath: this.azureOpenAIBasePath, + baseURL: this.clientConfig.baseURL, + }; + + const endpoint = getEndpoint(openAIEndpointConfig); + + const params = { + ...this.clientConfig, + baseURL: endpoint, + timeout: this.timeout, + maxRetries: 0, + }; + + if (!this.azureADTokenProvider) { + params.apiKey = openAIEndpointConfig.azureOpenAIApiKey; + } + + if (!params.baseURL) { + delete params.baseURL; + } + + this.client = new AzureOpenAIClient({ + apiVersion: this.azureOpenAIApiVersion, + azureADTokenProvider: this.azureADTokenProvider, + ...params, + }); + } + + const requestOptions = { + ...this.clientConfig, + ...options, + } as OpenAICoreRequestOptions; + if (this.azureOpenAIApiKey) { + requestOptions.headers = { + "api-key": this.azureOpenAIApiKey, + ...requestOptions.headers, + }; + requestOptions.query = { + "api-version": this.azureOpenAIApiVersion, + ...requestOptions.query, + }; + } + return requestOptions; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any toJSON(): any { const json = super.toJSON() as unknown; diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index d3d1465ae142..8f4d607ebd9f 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -378,6 +378,8 @@ export class ChatOpenAI< azureOpenAIApiKey?: string; + azureADTokenProvider?: () => Promise; + azureOpenAIApiInstanceName?: string; azureOpenAIApiDeploymentName?: string; @@ -386,9 +388,9 @@ export class ChatOpenAI< organization?: string; - private client: OpenAIClient; + protected client: OpenAIClient; - private clientConfig: ClientOptions; + protected clientConfig: ClientOptions; constructor( fields?: Partial & @@ -411,8 +413,12 @@ export class ChatOpenAI< fields?.azureOpenAIApiKey ?? getEnvironmentVariable("AZURE_OPENAI_API_KEY"); - if (!this.azureOpenAIApiKey && !this.apiKey) { - throw new Error("OpenAI or Azure OpenAI API key not found"); + this.azureADTokenProvider = fields?.azureADTokenProvider ?? undefined; + + if (!this.azureOpenAIApiKey && !this.apiKey && !this.azureADTokenProvider) { + throw new Error( + "OpenAI or Azure OpenAI API key or Token Provider not found" + ); } this.azureOpenAIApiInstanceName = @@ -455,7 +461,7 @@ export class ChatOpenAI< this.streaming = fields?.streaming ?? 
false; - if (this.azureOpenAIApiKey) { + if (this.azureOpenAIApiKey || this.azureADTokenProvider) { if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) { throw new Error("Azure OpenAI API instance name not found"); } @@ -898,7 +904,7 @@ export class ChatOpenAI< }); } - private _getClientOptions(options: OpenAICoreRequestOptions | undefined) { + protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) { if (!this.client) { const openAIEndpointConfig: OpenAIEndpointConfig = { azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, diff --git a/libs/langchain-openai/src/embeddings.ts b/libs/langchain-openai/src/embeddings.ts index 4e0e0cd2227d..115c9fe2d4a4 100644 --- a/libs/langchain-openai/src/embeddings.ts +++ b/libs/langchain-openai/src/embeddings.ts @@ -88,6 +88,8 @@ export class OpenAIEmbeddings azureOpenAIApiKey?: string; + azureADTokenProvider?: () => Promise; + azureOpenAIApiInstanceName?: string; azureOpenAIApiDeploymentName?: string; @@ -96,9 +98,9 @@ export class OpenAIEmbeddings organization?: string; - private client: OpenAIClient; + protected client: OpenAIClient; - private clientConfig: ClientOptions; + protected clientConfig: ClientOptions; constructor( fields?: Partial & @@ -127,8 +129,13 @@ export class OpenAIEmbeddings const azureApiKey = fieldsWithDefaults?.azureOpenAIApiKey ?? getEnvironmentVariable("AZURE_OPENAI_API_KEY"); - if (!azureApiKey && !apiKey) { - throw new Error("OpenAI or Azure OpenAI API key not found"); + + this.azureADTokenProvider = fields?.azureADTokenProvider ?? undefined; + + if (!azureApiKey && !apiKey && !this.azureADTokenProvider) { + throw new Error( + "OpenAI or Azure OpenAI API key or Token Provider not found" + ); } const azureApiInstanceName = @@ -168,7 +175,7 @@ export class OpenAIEmbeddings this.azureOpenAIApiInstanceName = azureApiInstanceName; this.azureOpenAIApiDeploymentName = azureApiDeploymentName; - if (this.azureOpenAIApiKey) { + if (this.azureOpenAIApiKey || this.azureADTokenProvider) { if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) { throw new Error("Azure OpenAI API instance name not found"); } @@ -254,7 +261,7 @@ export class OpenAIEmbeddings * @param request Request to send to the OpenAI API. * @returns Promise that resolves to the response from the API. 
*/ - private async embeddingWithRetry( + protected async embeddingWithRetry( request: OpenAIClient.EmbeddingCreateParams ) { if (!this.client) { diff --git a/libs/langchain-openai/src/index.ts b/libs/langchain-openai/src/index.ts index 79287ecf5e69..76e8df482138 100644 --- a/libs/langchain-openai/src/index.ts +++ b/libs/langchain-openai/src/index.ts @@ -3,6 +3,7 @@ export * from "./chat_models.js"; export * from "./azure/chat_models.js"; export * from "./llms.js"; export * from "./azure/llms.js"; +export * from "./azure/embeddings.js"; export * from "./embeddings.js"; export * from "./types.js"; export * from "./utils/openai.js"; diff --git a/libs/langchain-openai/src/llms.ts b/libs/langchain-openai/src/llms.ts index b9cb55fac806..871a51500b2f 100644 --- a/libs/langchain-openai/src/llms.ts +++ b/libs/langchain-openai/src/llms.ts @@ -147,6 +147,8 @@ export class OpenAI azureOpenAIApiKey?: string; + azureADTokenProvider?: () => Promise; + azureOpenAIApiInstanceName?: string; azureOpenAIApiDeploymentName?: string; @@ -155,9 +157,9 @@ export class OpenAI organization?: string; - private client: OpenAIClient; + protected client: OpenAIClient; - private clientConfig: ClientOptions; + protected clientConfig: ClientOptions; constructor( fields?: Partial & @@ -192,8 +194,12 @@ export class OpenAI fields?.azureOpenAIApiKey ?? getEnvironmentVariable("AZURE_OPENAI_API_KEY"); - if (!this.azureOpenAIApiKey && !this.apiKey) { - throw new Error("OpenAI or Azure OpenAI API key not found"); + this.azureADTokenProvider = fields?.azureADTokenProvider ?? undefined; + + if (!this.azureOpenAIApiKey && !this.apiKey && !this.azureADTokenProvider) { + throw new Error( + "OpenAI or Azure OpenAI API key or Token Provider not found" + ); } this.azureOpenAIApiInstanceName = @@ -242,7 +248,7 @@ export class OpenAI throw new Error("Cannot stream results when bestOf > 1"); } - if (this.azureOpenAIApiKey) { + if (this.azureOpenAIApiKey || this.azureADTokenProvider) { if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) { throw new Error("Azure OpenAI API instance name not found"); } @@ -530,7 +536,7 @@ export class OpenAI * @param options Optional configuration for the API call. * @returns The response from the OpenAI API. 
*/ - private _getClientOptions(options: OpenAICoreRequestOptions | undefined) { + protected _getClientOptions(options: OpenAICoreRequestOptions | undefined) { if (!this.client) { const openAIEndpointConfig: OpenAIEndpointConfig = { azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName, diff --git a/libs/langchain-openai/src/tests/azure/chat_models.int.test.ts b/libs/langchain-openai/src/tests/azure/chat_models.int.test.ts new file mode 100644 index 000000000000..bb9a99d61b74 --- /dev/null +++ b/libs/langchain-openai/src/tests/azure/chat_models.int.test.ts @@ -0,0 +1,829 @@ +import { test, jest, expect } from "@jest/globals"; +import { + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +} from "@langchain/core/messages"; +import { ChatGeneration, LLMResult } from "@langchain/core/outputs"; +import { ChatPromptValue } from "@langchain/core/prompt_values"; +import { + PromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +} from "@langchain/core/prompts"; +import { CallbackManager } from "@langchain/core/callbacks/manager"; +import { NewTokenIndices } from "@langchain/core/callbacks/base"; +import { InMemoryCache } from "@langchain/core/caches"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { + ClientSecretCredential, + getBearerTokenProvider, +} from "@azure/identity"; +import { AzureChatOpenAI } from "../../azure/chat_models.js"; + +test("Test Azure ChatOpenAI call method", async () => { + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + }); + const message = new HumanMessage("Hello!"); + const res = await chat.call([message]); + console.log({ res }); +}); + +test("Test Azure ChatOpenAI with SystemChatMessage", async () => { + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + }); + const system_message = new SystemMessage("You are to chat with a user."); + const message = new HumanMessage("Hello!"); + const res = await chat.call([system_message, message]); + console.log({ res }); +}); + +test("Test Azure ChatOpenAI Generate", async () => { + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + n: 2, + }); + const message = new HumanMessage("Hello!"); + const res = await chat.generate([[message], [message]]); + expect(res.generations.length).toBe(2); + for (const generation of res.generations) { + expect(generation.length).toBe(2); + for (const message of generation) { + console.log(message.text); + expect(typeof message.text).toBe("string"); + } + } + console.log({ res }); +}); + +test("Test Azure ChatOpenAI Generate throws when one of the calls fails", async () => { + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + n: 2, + }); + const message = new HumanMessage("Hello!"); + await expect(() => + chat.generate([[message], [message]], { + signal: AbortSignal.timeout(10), + }) + ).rejects.toThrow(); +}); + +test("Test Azure ChatOpenAI tokenUsage", async () => { + let tokenUsage = { + completionTokens: 0, + promptTokens: 0, + totalTokens: 0, + }; + + const model = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + callbackManager: CallbackManager.fromHandlers({ + async handleLLMEnd(output: LLMResult) { + console.log(output); + tokenUsage = output.llmOutput?.tokenUsage; + }, + }), + }); + const message = new HumanMessage("Hello"); + const res = await model.invoke([message]); + console.log({ res }); + + 
expect(tokenUsage.promptTokens).toBeGreaterThan(0); +}); + +test("Test Azure ChatOpenAI tokenUsage with a batch", async () => { + let tokenUsage = { + completionTokens: 0, + promptTokens: 0, + totalTokens: 0, + }; + + const model = new AzureChatOpenAI({ + temperature: 0, + modelName: "gpt-3.5-turbo", + callbackManager: CallbackManager.fromHandlers({ + async handleLLMEnd(output: LLMResult) { + tokenUsage = output.llmOutput?.tokenUsage; + }, + }), + }); + const res = await model.generate([ + [new HumanMessage("Hello")], + [new HumanMessage("Hi")], + ]); + console.log(res); + + expect(tokenUsage.promptTokens).toBeGreaterThan(0); +}); + +test("Test Azure ChatOpenAI in streaming mode", async () => { + let nrNewTokens = 0; + let streamedCompletion = ""; + + const model = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: true, + maxTokens: 10, + callbacks: [ + { + async handleLLMNewToken(token: string) { + nrNewTokens += 1; + streamedCompletion += token; + }, + }, + ], + }); + const message = new HumanMessage("Hello!"); + const result = await model.invoke([message]); + + expect(nrNewTokens > 0).toBe(true); + expect(result.content).toBe(streamedCompletion); +}, 10000); + +test("Test Azure ChatOpenAI in streaming mode with n > 1 and multiple prompts", async () => { + let nrNewTokens = 0; + const streamedCompletions = [ + ["", ""], + ["", ""], + ]; + + const model = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: true, + maxTokens: 10, + n: 2, + callbacks: [ + { + async handleLLMNewToken(token: string, idx: NewTokenIndices) { + nrNewTokens += 1; + streamedCompletions[idx.prompt][idx.completion] += token; + }, + }, + ], + }); + const message1 = new HumanMessage("Hello!"); + const message2 = new HumanMessage("Bye!"); + const result = await model.generate([[message1], [message2]]); + + expect(nrNewTokens > 0).toBe(true); + expect(result.generations.map((g) => g.map((gg) => gg.text))).toEqual( + streamedCompletions + ); +}, 10000); + +test("Test Azure ChatOpenAI prompt value", async () => { + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + n: 2, + }); + const message = new HumanMessage("Hello!"); + const res = await chat.generatePrompt([new ChatPromptValue([message])]); + expect(res.generations.length).toBe(1); + for (const generation of res.generations) { + expect(generation.length).toBe(2); + for (const g of generation) { + console.log(g.text); + } + } + console.log({ res }); +}); + +test("Test Azure OpenAI Chat, docs, prompt templates", async () => { + const chat = new AzureChatOpenAI({ temperature: 0, maxTokens: 10 }); + + const systemPrompt = PromptTemplate.fromTemplate( + "You are a helpful assistant that translates {input_language} to {output_language}." 
+ ); + + const chatPrompt = ChatPromptTemplate.fromMessages([ + new SystemMessagePromptTemplate(systemPrompt), + HumanMessagePromptTemplate.fromTemplate("{text}"), + ]); + + const responseA = await chat.generatePrompt([ + await chatPrompt.formatPromptValue({ + input_language: "English", + output_language: "French", + text: "I love programming.", + }), + ]); + + console.log(responseA.generations); +}, 5000); + +test("Test Azure ChatOpenAI with stop", async () => { + const model = new AzureChatOpenAI({ maxTokens: 5 }); + const res = await model.call( + [new HumanMessage("Print hello world")], + ["world"] + ); + console.log({ res }); +}); + +test("Test Azure ChatOpenAI with stop in object", async () => { + const model = new AzureChatOpenAI({ maxTokens: 5 }); + const res = await model.invoke([new HumanMessage("Print hello world")], { + stop: ["world"], + }); + console.log({ res }); +}); + +test("Test Azure ChatOpenAI with timeout in call options", async () => { + const model = new AzureChatOpenAI({ maxTokens: 5 }); + await expect(() => + model.invoke([new HumanMessage("Print hello world")], { timeout: 10 }) + ).rejects.toThrow(); +}, 5000); + +test("Test Azure ChatOpenAI with timeout in call options and node adapter", async () => { + const model = new AzureChatOpenAI({ maxTokens: 5 }); + await expect(() => + model.invoke([new HumanMessage("Print hello world")], { timeout: 10 }) + ).rejects.toThrow(); +}, 5000); + +test("Test Azure ChatOpenAI with signal in call options", async () => { + const model = new AzureChatOpenAI({ maxTokens: 5 }); + const controller = new AbortController(); + await expect(() => { + const ret = model.invoke([new HumanMessage("Print hello world")], { + signal: controller.signal, + }); + + controller.abort(); + + return ret; + }).rejects.toThrow(); +}, 5000); + +test("Test Azure ChatOpenAI with signal in call options and node adapter", async () => { + const model = new AzureChatOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + const controller = new AbortController(); + await expect(() => { + const ret = model.invoke([new HumanMessage("Print hello world")], { + signal: controller.signal, + }); + + controller.abort(); + + return ret; + }).rejects.toThrow(); +}, 5000); + +test("Test Azure ChatOpenAI with specific roles in ChatMessage", async () => { + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + }); + const system_message = new ChatMessage( + "You are to chat with a user.", + "system" + ); + const user_message = new ChatMessage("Hello!", "user"); + const res = await chat.call([system_message, user_message]); + console.log({ res }); +}); + +test("Test Azure ChatOpenAI stream method", async () => { + const model = new AzureChatOpenAI({ + maxTokens: 50, + modelName: "gpt-3.5-turbo", + }); + const stream = await model.stream("Print hello world."); + const chunks = []; + for await (const chunk of stream) { + console.log(chunk); + chunks.push(chunk); + } + expect(chunks.length).toBeGreaterThan(1); +}); + +test("Test Azure ChatOpenAI stream method with abort", async () => { + await expect(async () => { + const model = new AzureChatOpenAI({ + maxTokens: 100, + modelName: "gpt-3.5-turbo", + }); + const stream = await model.stream( + "How is your day going? 
Be extremely verbose.", + { + signal: AbortSignal.timeout(500), + } + ); + for await (const chunk of stream) { + console.log(chunk); + } + }).rejects.toThrow(); +}); + +test("Test Azure ChatOpenAI stream method with early break", async () => { + const model = new AzureChatOpenAI({ + maxTokens: 50, + modelName: "gpt-3.5-turbo", + }); + const stream = await model.stream( + "How is your day going? Be extremely verbose." + ); + let i = 0; + for await (const chunk of stream) { + console.log(chunk); + i += 1; + if (i > 10) { + break; + } + } +}); + +test("Test Azure ChatOpenAI stream method, timeout error thrown from SDK", async () => { + await expect(async () => { + const model = new AzureChatOpenAI({ + maxTokens: 50, + modelName: "gpt-3.5-turbo", + timeout: 1, + }); + const stream = await model.stream( + "How is your day going? Be extremely verbose." + ); + for await (const chunk of stream) { + console.log(chunk); + } + }).rejects.toThrow(); +}); + +test("Test Azure ChatOpenAI Function calling with streaming", async () => { + let finalResult: BaseMessage | undefined; + const modelForFunctionCalling = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + temperature: 0, + callbacks: [ + { + handleLLMEnd(output: LLMResult) { + finalResult = (output.generations[0][0] as ChatGeneration).message; + }, + }, + ], + }); + + const stream = await modelForFunctionCalling.stream( + "What is the weather in New York?", + { + functions: [ + { + name: "get_current_weather", + description: "Get the current weather in a given location", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and state, e.g. San Francisco, CA", + }, + unit: { type: "string", enum: ["celsius", "fahrenheit"] }, + }, + required: ["location"], + }, + }, + ], + function_call: { + name: "get_current_weather", + }, + } + ); + + const chunks = []; + let streamedOutput; + for await (const chunk of stream) { + chunks.push(chunk); + if (!streamedOutput) { + streamedOutput = chunk; + } else if (chunk) { + streamedOutput = streamedOutput.concat(chunk); + } + } + + expect(finalResult).toEqual(streamedOutput); + expect(chunks.length).toBeGreaterThan(1); + expect(finalResult?.additional_kwargs?.function_call?.name).toBe( + "get_current_weather" + ); + console.log( + JSON.parse(finalResult?.additional_kwargs?.function_call?.arguments ?? 
"") + .location + ); +}); + +test("Test Azure ChatOpenAI can cache generations", async () => { + const memoryCache = new InMemoryCache(); + const lookupSpy = jest.spyOn(memoryCache, "lookup"); + const updateSpy = jest.spyOn(memoryCache, "update"); + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 10, + n: 2, + cache: memoryCache, + }); + const message = new HumanMessage("Hello"); + const res = await chat.generate([[message], [message]]); + expect(res.generations.length).toBe(2); + + expect(lookupSpy).toHaveBeenCalledTimes(2); + expect(updateSpy).toHaveBeenCalledTimes(2); + + lookupSpy.mockRestore(); + updateSpy.mockRestore(); +}); + +test("Test Azure ChatOpenAI can write and read cached generations", async () => { + const memoryCache = new InMemoryCache(); + const lookupSpy = jest.spyOn(memoryCache, "lookup"); + const updateSpy = jest.spyOn(memoryCache, "update"); + + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 100, + n: 1, + cache: memoryCache, + }); + const generateUncachedSpy = jest.spyOn(chat, "_generateUncached"); + + const messages = [ + [ + new HumanMessage("what color is the sky?"), + new HumanMessage("what color is the ocean?"), + ], + [new HumanMessage("hello")], + ]; + + const response1 = await chat.generate(messages); + expect(generateUncachedSpy).toHaveBeenCalledTimes(1); + generateUncachedSpy.mockRestore(); + + const response2 = await chat.generate(messages); + expect(generateUncachedSpy).toHaveBeenCalledTimes(0); // Request should be cached, no need to generate. + generateUncachedSpy.mockRestore(); + + expect(response1.generations.length).toBe(2); + expect(response2.generations).toEqual(response1.generations); + expect(lookupSpy).toHaveBeenCalledTimes(4); + expect(updateSpy).toHaveBeenCalledTimes(2); + + lookupSpy.mockRestore(); + updateSpy.mockRestore(); +}); + +test("Test Azure ChatOpenAI should not reuse cache if function call args have changed", async () => { + const memoryCache = new InMemoryCache(); + const lookupSpy = jest.spyOn(memoryCache, "lookup"); + const updateSpy = jest.spyOn(memoryCache, "update"); + + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 100, + n: 1, + cache: memoryCache, + }); + + const generateUncachedSpy = jest.spyOn(chat, "_generateUncached"); + + const messages = [ + [ + new HumanMessage("what color is the sky?"), + new HumanMessage("what color is the ocean?"), + ], + [new HumanMessage("hello")], + ]; + + const response1 = await chat.generate(messages); + expect(generateUncachedSpy).toHaveBeenCalledTimes(1); + generateUncachedSpy.mockRestore(); + + const response2 = await chat.generate(messages, { + functions: [ + { + name: "extractor", + description: "Extract fields from the input", + parameters: { + type: "object", + properties: { + tone: { + type: "string", + description: "the tone of the input", + }, + }, + required: ["tone"], + }, + }, + ], + function_call: { + name: "extractor", + }, + }); + + expect(generateUncachedSpy).toHaveBeenCalledTimes(0); // Request should not be cached since it's being called with different function call args + + expect(response1.generations.length).toBe(2); + expect( + (response2.generations[0][0] as ChatGeneration).message.additional_kwargs + .function_call?.name ?? 
"" + ).toEqual("extractor"); + + const response3 = await chat.generate(messages, { + functions: [ + { + name: "extractor", + description: "Extract fields from the input", + parameters: { + type: "object", + properties: { + tone: { + type: "string", + description: "the tone of the input", + }, + }, + required: ["tone"], + }, + }, + ], + function_call: { + name: "extractor", + }, + }); + + expect(response2.generations).toEqual(response3.generations); + + expect(lookupSpy).toHaveBeenCalledTimes(6); + expect(updateSpy).toHaveBeenCalledTimes(4); + + lookupSpy.mockRestore(); + updateSpy.mockRestore(); +}); + +function createSampleMessages(): BaseMessage[] { + // same example as in https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb + return [ + createSystemChatMessage( + "You are a helpful, pattern-following assistant that translates corporate jargon into plain English." + ), + createSystemChatMessage( + "New synergies will help drive top-line growth.", + "example_user" + ), + createSystemChatMessage( + "Things working well together will increase revenue.", + "example_assistant" + ), + createSystemChatMessage( + "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.", + "example_user" + ), + createSystemChatMessage( + "Let's talk later when we're less busy about how to do better.", + "example_assistant" + ), + new HumanMessage( + "This late pivot means we don't have time to boil the ocean for the client deliverable." + ), + ]; +} + +function createSystemChatMessage(text: string, name?: string) { + const msg = new SystemMessage(text); + msg.name = name; + return msg; +} + +test("Test Azure ChatOpenAI getNumTokensFromMessages gpt-3.5-turbo-0301 model for sample input", async () => { + const messages: BaseMessage[] = createSampleMessages(); + + const chat = new AzureChatOpenAI({ + azureOpenAIApiKey: "dummy", + modelName: "gpt-3.5-turbo-0301", + }); + + const { totalCount } = await chat.getNumTokensFromMessages(messages); + + expect(totalCount).toBe(127); +}); + +test("Test Azure ChatOpenAI getNumTokensFromMessages gpt-4-0314 model for sample input", async () => { + const messages: BaseMessage[] = createSampleMessages(); + + const chat = new AzureChatOpenAI({ + azureOpenAIApiKey: "dummy", + modelName: "gpt-4-0314", + }); + + const { totalCount } = await chat.getNumTokensFromMessages(messages); + + expect(totalCount).toBe(129); +}); + +test("Test Azure ChatOpenAI token usage reporting for streaming function calls", async () => { + let streamingTokenUsed = -1; + let nonStreamingTokenUsed = -1; + + const humanMessage = "What a beautiful day!"; + const extractionFunctionSchema = { + name: "extractor", + description: "Extracts fields from the input.", + parameters: { + type: "object", + properties: { + tone: { + type: "string", + enum: ["positive", "negative"], + description: "The overall tone of the input", + }, + word_count: { + type: "number", + description: "The number of words in the input", + }, + chat_response: { + type: "string", + description: "A response to the human's input", + }, + }, + required: ["tone", "word_count", "chat_response"], + }, + }; + + const streamingModel = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: true, + maxRetries: 10, + maxConcurrency: 10, + temperature: 0, + topP: 0, + callbacks: [ + { + handleLLMEnd: async (output) => { + streamingTokenUsed = + output.llmOutput?.estimatedTokenUsage?.totalTokens; + console.log("streaming usage", 
output.llmOutput?.estimatedTokenUsage); + }, + handleLLMError: async (err) => { + console.error(err); + }, + }, + ], + }).bind({ + seed: 42, + functions: [extractionFunctionSchema], + function_call: { name: "extractor" }, + }); + + const nonStreamingModel = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: false, + maxRetries: 10, + maxConcurrency: 10, + temperature: 0, + topP: 0, + callbacks: [ + { + handleLLMEnd: async (output) => { + nonStreamingTokenUsed = output.llmOutput?.tokenUsage?.totalTokens; + console.log("non-streaming usage", output.llmOutput?.tokenUsage); + }, + handleLLMError: async (err) => { + console.error(err); + }, + }, + ], + }).bind({ + functions: [extractionFunctionSchema], + function_call: { name: "extractor" }, + }); + + const [nonStreamingResult, streamingResult] = await Promise.all([ + nonStreamingModel.invoke([new HumanMessage(humanMessage)]), + streamingModel.invoke([new HumanMessage(humanMessage)]), + ]); + + if ( + nonStreamingResult.additional_kwargs.function_call?.arguments && + streamingResult.additional_kwargs.function_call?.arguments + ) { + console.log( + `Function Call: ${JSON.stringify( + nonStreamingResult.additional_kwargs.function_call + )}` + ); + const nonStreamingArguments = JSON.stringify( + JSON.parse(nonStreamingResult.additional_kwargs.function_call.arguments) + ); + const streamingArguments = JSON.stringify( + JSON.parse(streamingResult.additional_kwargs.function_call.arguments) + ); + if (nonStreamingArguments === streamingArguments) { + expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed); + } + } + + expect(streamingTokenUsed).toBeGreaterThan(-1); +}); + +test("Test Azure ChatOpenAI token usage reporting for streaming calls", async () => { + let streamingTokenUsed = -1; + let nonStreamingTokenUsed = -1; + const systemPrompt = "You are a helpful assistant"; + const question = "What is the color of the night sky?"; + + const streamingModel = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: true, + maxRetries: 10, + maxConcurrency: 10, + temperature: 0, + topP: 0, + callbacks: [ + { + handleLLMEnd: async (output) => { + streamingTokenUsed = + output.llmOutput?.estimatedTokenUsage?.totalTokens; + console.log("streaming usage", output.llmOutput?.estimatedTokenUsage); + }, + handleLLMError: async (err) => { + console.error(err); + }, + }, + ], + }); + + const nonStreamingModel = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + streaming: false, + maxRetries: 10, + maxConcurrency: 10, + temperature: 0, + topP: 0, + callbacks: [ + { + handleLLMEnd: async (output) => { + nonStreamingTokenUsed = output.llmOutput?.tokenUsage?.totalTokens; + console.log("non-streaming usage", output.llmOutput?.tokenUsage); + }, + handleLLMError: async (err) => { + console.error(err); + }, + }, + ], + }); + + const [nonStreamingResult, streamingResult] = await Promise.all([ + nonStreamingModel.generate([ + [new SystemMessage(systemPrompt), new HumanMessage(question)], + ]), + streamingModel.generate([ + [new SystemMessage(systemPrompt), new HumanMessage(question)], + ]), + ]); + + expect(streamingTokenUsed).toBeGreaterThan(-1); + if ( + nonStreamingResult.generations[0][0].text === + streamingResult.generations[0][0].text + ) { + expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed); + } +}); + +test("Test Azure ChatOpenAI with bearer token provider", async () => { + const tenantId: string = getEnvironmentVariable("AZURE_TENANT_ID") ?? ""; + const clientId: string = getEnvironmentVariable("AZURE_CLIENT_ID") ?? 
""; + const clientSecret: string = + getEnvironmentVariable("AZURE_CLIENT_SECRET") ?? ""; + + const credentials = new ClientSecretCredential( + tenantId, + clientId, + clientSecret + ); + const azureADTokenProvider = getBearerTokenProvider( + credentials, + "https://cognitiveservices.azure.com/.default" + ); + + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 5, + azureADTokenProvider, + }); + const message = new HumanMessage("Hello!"); + const res = await chat.invoke([["system", "Say hi"], message]); + console.log(res); +}); diff --git a/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts b/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts new file mode 100644 index 000000000000..7362fe1a73e8 --- /dev/null +++ b/libs/langchain-openai/src/tests/azure/embeddings.int.test.ts @@ -0,0 +1,75 @@ +import { test, expect } from "@jest/globals"; +import { AzureOpenAIEmbeddings as OpenAIEmbeddings } from "../../azure/embeddings.js"; + +test("Test AzureOpenAIEmbeddings.embedQuery", async () => { + const embeddings = new OpenAIEmbeddings(); + const res = await embeddings.embedQuery("Hello world"); + expect(typeof res[0]).toBe("number"); +}); + +test("Test AzureOpenAIEmbeddings.embedDocuments", async () => { + const embeddings = new OpenAIEmbeddings(); + const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]); + expect(res).toHaveLength(2); + expect(typeof res[0][0]).toBe("number"); + expect(typeof res[1][0]).toBe("number"); +}); + +test("Test AzureOpenAIEmbeddings concurrency", async () => { + const embeddings = new OpenAIEmbeddings({ + batchSize: 1, + maxConcurrency: 2, + }); + const res = await embeddings.embedDocuments([ + "Hello world", + "Bye bye", + "Hello world", + "Bye bye", + "Hello world", + "Bye bye", + ]); + expect(res).toHaveLength(6); + expect(res.find((embedding) => typeof embedding[0] !== "number")).toBe( + undefined + ); +}); + +test("Test timeout error thrown from SDK", async () => { + await expect(async () => { + const model = new OpenAIEmbeddings({ + timeout: 1, + maxRetries: 0, + }); + await model.embedDocuments([ + "Hello world", + "Bye bye", + "Hello world", + "Bye bye", + "Hello world", + "Bye bye", + ]); + }).rejects.toThrow(); +}); + +test("Test AzureOpenAIEmbeddings.embedQuery with v3 and dimensions", async () => { + const embeddings = new OpenAIEmbeddings({ + modelName: "text-embedding-3-small", + dimensions: 127, + }); + const res = await embeddings.embedQuery("Hello world"); + expect(typeof res[0]).toBe("number"); + expect(res.length).toBe(127); +}); + +test("Test AzureOpenAIEmbeddings.embedDocuments with v3 and dimensions", async () => { + const embeddings = new OpenAIEmbeddings({ + modelName: "text-embedding-3-small", + dimensions: 127, + }); + const res = await embeddings.embedDocuments(["Hello world", "Bye bye"]); + expect(res).toHaveLength(2); + expect(typeof res[0][0]).toBe("number"); + expect(typeof res[1][0]).toBe("number"); + expect(res[0].length).toBe(127); + expect(res[1].length).toBe(127); +}); diff --git a/libs/langchain-openai/src/tests/azure/llms.int.test.ts b/libs/langchain-openai/src/tests/azure/llms.int.test.ts new file mode 100644 index 000000000000..c4c1baff9d62 --- /dev/null +++ b/libs/langchain-openai/src/tests/azure/llms.int.test.ts @@ -0,0 +1,333 @@ +import { test, expect } from "@jest/globals"; +import { LLMResult } from "@langchain/core/outputs"; +import { StringPromptValue } from "@langchain/core/prompt_values"; +import { CallbackManager } from "@langchain/core/callbacks/manager"; 
+import { NewTokenIndices } from "@langchain/core/callbacks/base"; +import { + ClientSecretCredential, + getBearerTokenProvider, +} from "@azure/identity"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { AzureOpenAI } from "../../azure/llms.js"; + +test("Test Azure OpenAI invoke", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + const res = await model.invoke("Print hello world"); + console.log({ res }); +}); + +test("Test Azure OpenAI call", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + const res = await model.call("Print hello world", ["world"]); + console.log({ res }); +}); + +test("Test Azure OpenAI with stop in object", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + const res = await model.invoke("Print hello world", { stop: ["world"] }); + console.log({ res }); +}); + +test("Test Azure OpenAI with timeout in call options", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + await expect(() => + model.invoke("Print hello world", { + timeout: 10, + }) + ).rejects.toThrow(); +}, 5000); + +test("Test Azure OpenAI with timeout in call options and node adapter", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + await expect(() => + model.invoke("Print hello world", { + timeout: 10, + }) + ).rejects.toThrow(); +}, 5000); + +test("Test Azure OpenAI with signal in call options", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + const controller = new AbortController(); + await expect(() => { + const ret = model.invoke("Print hello world", { + signal: controller.signal, + }); + + controller.abort(); + + return ret; + }).rejects.toThrow(); +}, 5000); + +test("Test Azure OpenAI with signal in call options and node adapter", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + const controller = new AbortController(); + await expect(() => { + const ret = model.invoke("Print hello world", { + signal: controller.signal, + }); + + controller.abort(); + + return ret; + }).rejects.toThrow(); +}, 5000); + +test("Test Azure OpenAI with concurrency == 1", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + maxConcurrency: 1, + }); + const res = await Promise.all([ + model.invoke("Print hello world"), + model.invoke("Print hello world"), + ]); + console.log({ res }); +}); + +test("Test Azure OpenAI with maxTokens -1", async () => { + const model = new AzureOpenAI({ + maxTokens: -1, + modelName: "gpt-3.5-turbo-instruct", + }); + const res = await model.call("Print hello world", ["world"]); + console.log({ res }); +}); + +test("Test Azure OpenAI with model name", async () => { + const model = new AzureOpenAI({ modelName: "gpt-3.5-turbo-instruct" }); + expect(model).toBeInstanceOf(AzureOpenAI); + const res = await model.invoke("Print hello world"); + console.log({ res }); + expect(typeof res).toBe("string"); +}); + +test("Test Azure OpenAI with versioned instruct model returns Azure OpenAI", async () => { + const model = new AzureOpenAI({ + modelName: "gpt-3.5-turbo-instruct-0914", + }); + expect(model).toBeInstanceOf(AzureOpenAI); + const res = await model.invoke("Print hello world"); + console.log({ 
res }); + expect(typeof res).toBe("string"); +}); + +test("Test Azure OpenAI tokenUsage", async () => { + let tokenUsage = { + completionTokens: 0, + promptTokens: 0, + totalTokens: 0, + }; + + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + callbackManager: CallbackManager.fromHandlers({ + async handleLLMEnd(output: LLMResult) { + tokenUsage = output.llmOutput?.tokenUsage; + }, + }), + }); + const res = await model.invoke("Hello"); + console.log({ res }); + + expect(tokenUsage.promptTokens).toBe(1); +}); + +test("Test Azure OpenAI in streaming mode", async () => { + let nrNewTokens = 0; + let streamedCompletion = ""; + + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + streaming: true, + callbacks: CallbackManager.fromHandlers({ + async handleLLMNewToken(token: string) { + nrNewTokens += 1; + streamedCompletion += token; + }, + }), + }); + const res = await model.invoke("Print hello world"); + console.log({ res }); + + expect(nrNewTokens > 0).toBe(true); + expect(res).toBe(streamedCompletion); +}); + +test("Test Azure OpenAI in streaming mode with multiple prompts", async () => { + let nrNewTokens = 0; + const completions = [ + ["", ""], + ["", ""], + ]; + + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + streaming: true, + n: 2, + callbacks: CallbackManager.fromHandlers({ + async handleLLMNewToken(token: string, idx: NewTokenIndices) { + nrNewTokens += 1; + completions[idx.prompt][idx.completion] += token; + }, + }), + }); + const res = await model.generate(["Print hello world", "print hello sea"]); + console.log( + res.generations, + res.generations.map((g) => g[0].generationInfo) + ); + + expect(nrNewTokens > 0).toBe(true); + expect(res.generations.length).toBe(2); + expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual( + completions + ); +}); + +test("Test Azure OpenAI in streaming mode with multiple prompts and a chat model", async () => { + let nrNewTokens = 0; + const completions = [[""], [""]]; + + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo", + streaming: true, + n: 1, + callbacks: CallbackManager.fromHandlers({ + async handleLLMNewToken(token: string, idx: NewTokenIndices) { + nrNewTokens += 1; + completions[idx.prompt][idx.completion] += token; + }, + }), + }); + const res = await model.generate(["Print hello world", "print hello sea"]); + console.log( + res.generations, + res.generations.map((g) => g[0].generationInfo) + ); + + expect(nrNewTokens > 0).toBe(true); + expect(res.generations.length).toBe(2); + expect(res.generations.map((g) => g.map((gg) => gg.text))).toEqual( + completions + ); +}); + +test("Test Azure OpenAI prompt value", async () => { + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "gpt-3.5-turbo-instruct", + }); + const res = await model.generatePrompt([ + new StringPromptValue("Print hello world"), + ]); + expect(res.generations.length).toBe(1); + for (const generation of res.generations) { + expect(generation.length).toBe(1); + for (const g of generation) { + console.log(g.text); + } + } + console.log({ res }); +}); + +test("Test Azure OpenAI stream method", async () => { + const model = new AzureOpenAI({ + maxTokens: 50, + modelName: "gpt-3.5-turbo-instruct", + }); + const stream = await model.stream("Print hello world."); + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + expect(chunks.length).toBeGreaterThan(1); +}); + +test("Test Azure OpenAI stream 
method with abort", async () => { + await expect(async () => { + const model = new AzureOpenAI({ + maxTokens: 250, + modelName: "gpt-3.5-turbo-instruct", + }); + const stream = await model.stream( + "How is your day going? Be extremely verbose.", + { + signal: AbortSignal.timeout(1000), + } + ); + for await (const chunk of stream) { + console.log(chunk); + } + }).rejects.toThrow(); +}); + +test("Test Azure OpenAI stream method with early break", async () => { + const model = new AzureOpenAI({ + maxTokens: 50, + modelName: "gpt-3.5-turbo-instruct", + }); + const stream = await model.stream( + "How is your day going? Be extremely verbose." + ); + let i = 0; + for await (const chunk of stream) { + console.log(chunk); + i += 1; + if (i > 5) { + break; + } + } +}); + +test("Test Azure OpenAI with bearer token credentials", async () => { + const tenantId: string = getEnvironmentVariable("AZURE_TENANT_ID") ?? ""; + const clientId: string = getEnvironmentVariable("AZURE_CLIENT_ID") ?? ""; + const clientSecret: string = + getEnvironmentVariable("AZURE_CLIENT_SECRET") ?? ""; + + const credentials = new ClientSecretCredential( + tenantId, + clientId, + clientSecret + ); + const azureADTokenProvider = getBearerTokenProvider( + credentials, + "https://cognitiveservices.azure.com/.default" + ); + + const model = new AzureOpenAI({ + maxTokens: 5, + modelName: "davinci-002", + azureADTokenProvider, + }); + const res = await model.invoke("Print hello world"); + console.log({ res }); +}); diff --git a/libs/langchain-openai/src/types.ts b/libs/langchain-openai/src/types.ts index 6580df660dbd..135b2c63e394 100644 --- a/libs/langchain-openai/src/types.ts +++ b/libs/langchain-openai/src/types.ts @@ -197,4 +197,10 @@ export declare interface AzureOpenAIInput { * will be result in the endpoint URL: https://westeurope.api.cognitive.microsoft.com/openai/deployments/{DeploymentName}/ */ azureOpenAIBasePath?: string; + + /** + * A function that returns an access token for Microsoft Entra (formerly known as Azure Active Directory), + * which will be invoked on every request. 
+ */ + azureADTokenProvider?: () => Promise; } diff --git a/yarn.lock b/yarn.lock index c02b0e1867bf..be0fb1e845ff 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4691,6 +4691,28 @@ __metadata: languageName: node linkType: hard +"@azure/identity@npm:^4.2.0": + version: 4.2.0 + resolution: "@azure/identity@npm:4.2.0" + dependencies: + "@azure/abort-controller": ^1.0.0 + "@azure/core-auth": ^1.5.0 + "@azure/core-client": ^1.4.0 + "@azure/core-rest-pipeline": ^1.1.0 + "@azure/core-tracing": ^1.0.0 + "@azure/core-util": ^1.3.0 + "@azure/logger": ^1.0.0 + "@azure/msal-browser": ^3.11.1 + "@azure/msal-node": ^2.6.6 + events: ^3.0.0 + jws: ^4.0.0 + open: ^8.0.0 + stoppable: ^1.1.0 + tslib: ^2.2.0 + checksum: b1b336113c944abf89376f366bf8e82958617465c91e561e922c165a10aaa1789e83a78b7baa070671247d0f97c63b4cc89cf6cabc72258f3d9cbe12fe799e2a + languageName: node + linkType: hard + "@azure/logger@npm:^1.0.0, @azure/logger@npm:^1.0.3": version: 1.0.4 resolution: "@azure/logger@npm:1.0.4" @@ -4700,6 +4722,15 @@ __metadata: languageName: node linkType: hard +"@azure/msal-browser@npm:^3.11.1": + version: 3.14.0 + resolution: "@azure/msal-browser@npm:3.14.0" + dependencies: + "@azure/msal-common": 14.10.0 + checksum: 747cd3df32f082e515c5e268d64f0d16afa0ce21ab5154e235ee0eb0fd0e2902504d12bac1f94839afaf9cc94c823d961775c3f57c3f20f12864d13a5ed0fa44 + languageName: node + linkType: hard + "@azure/msal-browser@npm:^3.5.0": version: 3.7.1 resolution: "@azure/msal-browser@npm:3.7.1" @@ -4709,6 +4740,13 @@ __metadata: languageName: node linkType: hard +"@azure/msal-common@npm:14.10.0": + version: 14.10.0 + resolution: "@azure/msal-common@npm:14.10.0" + checksum: 50994f54cdce7425bef42d44b3b15e756ede11efa5d2f84440da41ea13f0b0c00e8c262925e42b2f6d5f6f850233dccf99d810ced7b2cf372b9a645ed53489f8 + languageName: node + linkType: hard + "@azure/msal-common@npm:14.6.1": version: 14.6.1 resolution: "@azure/msal-common@npm:14.6.1" @@ -4727,6 +4765,17 @@ __metadata: languageName: node linkType: hard +"@azure/msal-node@npm:^2.6.6": + version: 2.8.0 + resolution: "@azure/msal-node@npm:2.8.0" + dependencies: + "@azure/msal-common": 14.10.0 + jsonwebtoken: ^9.0.0 + uuid: ^8.3.0 + checksum: 778b7f8a9088ec264b7b59c1dd06ca9f4d9acfe7dd9e9fa92fb1fbcb8da3f9917ebf641715d4fe2338da17eae00a571cb8f1b6431c419609c8eff372654ce378 + languageName: node + linkType: hard + "@azure/openai@npm:1.0.0-beta.11": version: 1.0.0-beta.11 resolution: "@azure/openai@npm:1.0.0-beta.11" @@ -9764,6 +9813,7 @@ __metadata: version: 0.0.0-use.local resolution: "@langchain/openai@workspace:libs/langchain-openai" dependencies: + "@azure/identity": ^4.2.0 "@jest/globals": ^29.5.0 "@langchain/core": ~0.1.56 "@langchain/scripts": ~0.0 @@ -9780,7 +9830,7 @@ __metadata: jest: ^29.5.0 jest-environment-node: ^29.6.4 js-tiktoken: ^1.0.7 - openai: ^4.32.1 + openai: ^4.41.1 prettier: ^2.8.3 release-it: ^15.10.1 rimraf: ^5.0.1 @@ -26787,7 +26837,7 @@ __metadata: node-llama-cpp: 2.7.3 notion-to-md: ^3.1.0 officeparser: ^4.0.4 - openai: ^4.32.1 + openai: ^4.41.1 openapi-types: ^12.1.3 p-retry: 4 pdf-parse: 1.1.1 @@ -29484,9 +29534,9 @@ __metadata: languageName: node linkType: hard -"openai@npm:^4.32.1": - version: 4.32.1 - resolution: "openai@npm:4.32.1" +"openai@npm:^4.41.1": + version: 4.42.0 + resolution: "openai@npm:4.42.0" dependencies: "@types/node": ^18.11.18 "@types/node-fetch": ^2.6.4 @@ -29498,7 +29548,7 @@ __metadata: web-streams-polyfill: ^3.2.1 bin: openai: bin/cli - checksum: 
da60909a86196ff88b5983c302c3da1dcbe0660ec87652dda62866e38cf6005c4974eaeb731446de9a9f7f13146ef3018aca51b8bfe1486374226bb102494dcb + checksum: 28eae59cb4b51b3d7224647e3be369453522b5be3f1530a51bf8ffe2aefa165cf4156a7621803782014c022943b16ff27690159ef3720214fdbb137a21e47bf2 languageName: node linkType: hard
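
Usage note (not part of the patch): the integration tests above exercise the new Microsoft Entra authentication path end to end, but a condensed sketch may be clearer. The TypeScript below is a minimal example under stated assumptions: it presumes AZURE_OPENAI_API_INSTANCE_NAME, AZURE_OPENAI_API_DEPLOYMENT_NAME, and AZURE_OPENAI_API_VERSION are set in the environment (the pre-existing constructors fall back to those variables), and the deployment names and API version shown are placeholders for whatever exists in your Azure resource.

import {
  DefaultAzureCredential,
  getBearerTokenProvider,
} from "@azure/identity";
import { AzureChatOpenAI, AzureOpenAIEmbeddings } from "@langchain/openai";

async function main() {
  // Exchange the ambient Azure credential (az CLI login, managed identity,
  // or service principal environment variables) for a token provider scoped
  // to Cognitive Services. The provider is invoked on every request, so
  // token refresh is handled automatically.
  const credential = new DefaultAzureCredential();
  const azureADTokenProvider = getBearerTokenProvider(
    credential,
    "https://cognitiveservices.azure.com/.default"
  );

  // Chat: no azureOpenAIApiKey is required. When the key is absent, the
  // AzureOpenAI client from the openai SDK attaches a bearer token from the
  // provider instead of the api-key header.
  const chat = new AzureChatOpenAI({
    azureADTokenProvider,
    modelName: "gpt-3.5-turbo", // placeholder deployment/model name
    maxTokens: 10,
  });
  const reply = await chat.invoke("Say hi");
  console.log(reply.content);

  // Embeddings: the new AzureOpenAIEmbeddings class accepts the same
  // provider, plus the deploymentName / openAIApiVersion aliases that this
  // patch maps onto azureOpenAIApiDeploymentName / azureOpenAIApiVersion.
  const embeddings = new AzureOpenAIEmbeddings({
    azureADTokenProvider,
    deploymentName: "text-embedding-3-small", // placeholder deployment name
    openAIApiVersion: "2024-02-01", // placeholder API version
  });
  const vector = await embeddings.embedQuery("Hello world");
  console.log(vector.length);
}

main().catch(console.error);

This mirrors the design choice visible in _getClientOptions: the api-key header and api-version query parameter are only attached when azureOpenAIApiKey is set, so the token-provider path and the key path never mix on a single request.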