From beed3b8988b52ad9252d64981be61c1118514626 Mon Sep 17 00:00:00 2001 From: Ovidijus Parsiunas Date: Sat, 11 Nov 2023 20:20:59 +0000 Subject: [PATCH] ability to send files to OpenAI Assistants --- .../src/services/openAI/openAIAssistantIO.ts | 49 +++++++++---------- component/src/services/openAI/openAIChatIO.ts | 3 +- .../src/services/openAI/openAIImagesIO.ts | 2 +- .../services/openAI/openAISpeechToTextIO.ts | 2 +- .../src/services/openAI/utils/openAIUtils.ts | 40 +++++++++++++++ .../stabilityAI/stabilityAIImageToImageIO.ts | 2 +- .../stabilityAIImageToImageMaskingIO.ts | 2 +- .../stabilityAIImageToImageUpscaleIO.ts | 2 +- component/src/services/utils/baseServiceIO.ts | 4 +- component/src/utils/HTTP/requestUtils.ts | 17 +++++-- .../chat/input/buttons/submit/submitButton.ts | 1 + 11 files changed, 82 insertions(+), 42 deletions(-) diff --git a/component/src/services/openAI/openAIAssistantIO.ts b/component/src/services/openAI/openAIAssistantIO.ts index ba9e6959c..ca71edcd2 100644 --- a/component/src/services/openAI/openAIAssistantIO.ts +++ b/component/src/services/openAI/openAIAssistantIO.ts @@ -4,7 +4,6 @@ import {MessageUtils} from '../../views/chat/messages/messageUtils'; import {DirectConnection} from '../../types/directConnection'; import {MessageLimitUtils} from '../utils/messageLimitUtils'; import {Messages} from '../../views/chat/messages/messages'; -import {RequestUtils} from '../../utils/HTTP/requestUtils'; import {HTTPRequest} from '../../utils/HTTP/HTTPRequest'; import {DirectServiceIO} from '../utils/directServiceIO'; import {MessageContent} from '../../types/messages'; @@ -33,8 +32,7 @@ export class OpenAIAssistantIO extends DirectServiceIO { constructor(deepChat: DeepChat) { const directConnectionCopy = JSON.parse(JSON.stringify(deepChat.directConnection)) as DirectConnection; const apiKey = directConnectionCopy.openAI; - const imageFiles = deepChat.images ? {images: {files: {maxNumberOfFiles: 10}}} : {}; - super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey, imageFiles); + super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey); const config = directConnectionCopy.openAI?.assistant; // can be undefined as this is the default service if (typeof config === 'object') { this.rawBody.assistant_id = config.assistant_id; @@ -46,39 +44,44 @@ export class OpenAIAssistantIO extends DirectServiceIO { this.maxMessages = 1; // messages are stored in OpenAI threads and can't create new thread with 'assistant' messages } - private processMessages(pMessages: MessageContent[]) { + private processMessages(pMessages: MessageContent[], file_ids?: string[]) { const totalMessagesMaxCharLength = this.totalMessagesMaxCharLength || OpenAIUtils.CONVERSE_MAX_CHAR_LENGTH; return MessageLimitUtils.getCharacterLimitMessages(pMessages, totalMessagesMaxCharLength).map((message) => { - return {content: message.text, role: message.role === MessageUtils.AI_ROLE ? 'assistant' : 'user'}; + return {content: message.text || '', role: message.role === MessageUtils.AI_ROLE ? 'assistant' : 'user', file_ids}; }); } - private createNewThreadMessages(body: OpenAIConverseBodyInternal, pMessages: MessageContent[]) { + private createNewThreadMessages(body: OpenAIConverseBodyInternal, pMessages: MessageContent[], file_ids?: string[]) { const bodyCopy = JSON.parse(JSON.stringify(body)); - const processedMessages = this.processMessages(pMessages); + const processedMessages = this.processMessages(pMessages, file_ids); bodyCopy.thread = {messages: processedMessages}; return bodyCopy; } - override async callServiceAPI(messages: Messages, pMessages: MessageContent[]) { - if (!this.requestSettings) throw new Error('Request settings have not been set up'); - // here instead of constructor as messages may be loaded later - if (!this.searchedForThreadId) this.searchPreviousMessagesForThreadId(messages.messages); - this.requestSettings.method = 'POST'; + private callService(messages: Messages, pMessages: MessageContent[], file_ids?: string[]) { if (this.sessionId) { // https://platform.openai.com/docs/api-reference/messages/createMessage this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${this.sessionId}/messages`; - const body = this.processMessages([pMessages[pMessages.length - 1]])[0]; + const body = this.processMessages(pMessages, file_ids)[0]; HTTPRequest.request(this, body, messages); } else { // https://platform.openai.com/docs/api-reference/runs/createThreadAndRun this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/runs`; - const body = this.createNewThreadMessages(this.rawBody, pMessages); + const body = this.createNewThreadMessages(this.rawBody, pMessages, file_ids); HTTPRequest.request(this, body, messages); } this.messages = messages; } + override async callServiceAPI(messages: Messages, pMessages: MessageContent[], files?: File[]) { + if (!this.requestSettings) throw new Error('Request settings have not been set up'); + // here instead of constructor as messages may be loaded later + if (!this.searchedForThreadId) this.searchPreviousMessagesForThreadId(messages.messages); + const file_ids = files ? await OpenAIUtils.storeFiles(this, messages, files) : undefined; + this.requestSettings.method = 'POST'; + this.callService(messages, pMessages, file_ids); + } + private searchPreviousMessagesForThreadId(messages: MessageContent[]) { const messageWithSession = messages.find((message) => message.sessionId); if (messageWithSession) this.sessionId = messageWithSession.sessionId; @@ -99,11 +102,13 @@ export class OpenAIAssistantIO extends DirectServiceIO { if (this.sessionId) { // https://platform.openai.com/docs/api-reference/runs/createRun this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${this.sessionId}/runs`; - const runObj = await this.directFetch(JSON.parse(JSON.stringify(this.rawBody))); + const runObj = await OpenAIUtils.directFetch(this, JSON.parse(JSON.stringify(this.rawBody)), 'POST'); this.run_id = runObj.id; } else { this.sessionId = result.thread_id; this.run_id = result.id; + // updates the user sent message with the session id (the message event sent did not have this id) + if (this.messages) this.messages.messages[this.messages.messages.length - 1].sessionId = this.sessionId; } } @@ -112,7 +117,7 @@ export class OpenAIAssistantIO extends DirectServiceIO { if (status === 'queued' || status === 'in_progress') return {timeoutMS: OpenAIAssistantIO.POLLING_TIMEOUT_MS}; if (status === 'completed' && this.messages) { this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${result.thread_id}/messages`; - const threadMessages = (await this.directFetch({}, 'GET')) as OpenAIAssistantMessagesResult; + const threadMessages = (await OpenAIUtils.directFetch(this, {}, 'GET')) as OpenAIAssistantMessagesResult; const lastMessage = threadMessages.data[0]; return {text: lastMessage.content[0].text.value, sessionId: this.sessionId}; } @@ -137,17 +142,7 @@ export class OpenAIAssistantIO extends DirectServiceIO { }); // https://platform.openai.com/docs/api-reference/runs/submitToolOutputs this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${this.sessionId}/runs/${this.run_id}/submit_tool_outputs`; - await this.directFetch({tool_outputs}); + await OpenAIUtils.directFetch(this, {tool_outputs}, 'POST'); return {timeoutMS: OpenAIAssistantIO.POLLING_TIMEOUT_MS}; } - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - private async directFetch(body: any, method = 'POST') { - this.requestSettings.method = method; - const result = await RequestUtils.fetch(this, this.requestSettings.headers, true, body).then((resp) => - RequestUtils.processResponseByType(resp) - ); - if (result.error) throw result.error.message; - return result; - } } diff --git a/component/src/services/openAI/openAIChatIO.ts b/component/src/services/openAI/openAIChatIO.ts index 1966b320d..70e1f6101 100644 --- a/component/src/services/openAI/openAIChatIO.ts +++ b/component/src/services/openAI/openAIChatIO.ts @@ -32,8 +32,7 @@ export class OpenAIChatIO extends DirectServiceIO { constructor(deepChat: DeepChat) { const directConnectionCopy = JSON.parse(JSON.stringify(deepChat.directConnection)) as DirectConnection; const apiKey = directConnectionCopy.openAI; - const imageFiles = deepChat.images ? {images: {files: {maxNumberOfFiles: 10}}} : {}; - super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey, imageFiles); + super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey); const config = directConnectionCopy.openAI?.chat; // can be undefined as this is the default service if (typeof config === 'object') { if (config.system_prompt) this._systemMessage = OpenAIChatIO.generateSystemMessage(config.system_prompt); diff --git a/component/src/services/openAI/openAIImagesIO.ts b/component/src/services/openAI/openAIImagesIO.ts index d4fa592e1..af935e029 100644 --- a/component/src/services/openAI/openAIImagesIO.ts +++ b/component/src/services/openAI/openAIImagesIO.ts @@ -82,7 +82,7 @@ export class OpenAIImagesIO extends DirectServiceIO { formData = OpenAIImagesIO.createFormDataBody(this.rawBody, files[0]); } // need to pass stringifyBody boolean separately as binding is throwing an error for some reason - RequestUtils.temporarilyRemoveHeader(this.requestSettings, + RequestUtils.tempRemoveContentHeader(this.requestSettings, HTTPRequest.request.bind(this, this, formData, messages), false); } diff --git a/component/src/services/openAI/openAISpeechToTextIO.ts b/component/src/services/openAI/openAISpeechToTextIO.ts index 9127a7904..d0802e306 100644 --- a/component/src/services/openAI/openAISpeechToTextIO.ts +++ b/component/src/services/openAI/openAISpeechToTextIO.ts @@ -86,7 +86,7 @@ export class OpenAISpeechToTextIO extends DirectServiceIO { const body = this.preprocessBody(this.rawBody, pMessages); const formData = OpenAISpeechToTextIO.createFormDataBody(body, files[0]); // need to pass stringifyBody boolean separately as binding is throwing an error for some reason - RequestUtils.temporarilyRemoveHeader(this.requestSettings, + RequestUtils.tempRemoveContentHeader(this.requestSettings, HTTPRequest.request.bind(this, this, formData, messages), false); } diff --git a/component/src/services/openAI/utils/openAIUtils.ts b/component/src/services/openAI/utils/openAIUtils.ts index 1445ec4b9..78f8fadda 100644 --- a/component/src/services/openAI/utils/openAIUtils.ts +++ b/component/src/services/openAI/utils/openAIUtils.ts @@ -1,6 +1,9 @@ import {KeyVerificationDetails} from '../../../types/keyVerificationDetails'; import {ErrorMessages} from '../../../utils/errorMessages/errorMessages'; import {OpenAIConverseResult} from '../../../types/openAIResult'; +import {Messages} from '../../../views/chat/messages/messages'; +import {RequestUtils} from '../../../utils/HTTP/requestUtils'; +import {ServiceIO} from '../../serviceIO'; export class OpenAIUtils { // 13352 roughly adds up to 3,804 tokens just to be safe @@ -36,4 +39,41 @@ export class OpenAIUtils { handleVerificationResult: OpenAIUtils.handleVerificationResult, }; } + + public static async storeFiles(serviceIO: ServiceIO, messages: Messages, files: File[]) { + const headers = serviceIO.requestSettings.headers; + if (!headers) return; + serviceIO.url = `https://api.openai.com/v1/files`; // stores files + const previousContetType = headers[RequestUtils.CONTENT_TYPE]; + delete headers[RequestUtils.CONTENT_TYPE]; + const requests = files.map(async (file) => { + const formData = new FormData(); + formData.append('purpose', 'assistants'); + formData.append('file', file); + return new Promise<{id: string}>((resolve) => { + resolve(OpenAIUtils.directFetch(serviceIO, formData, 'POST', false)); // should perhaps use await but works without + }); + }); + try { + const fileIds = (await Promise.all(requests)).map((result) => result.id); + headers[RequestUtils.CONTENT_TYPE] = previousContetType; + return fileIds; + } catch (err) { + headers[RequestUtils.CONTENT_TYPE] = previousContetType; + // error handled here as files not sent using HTTPRequest.request to not trigger the interceptors + RequestUtils.displayError(messages, err as object); + serviceIO.completionsHandlers.onFinish(); + throw err; + } + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + public static async directFetch(serviceIO: ServiceIO, body: any, method: 'POST' | 'GET', stringify = true) { + serviceIO.requestSettings.method = method; + const result = await RequestUtils.fetch(serviceIO, serviceIO.requestSettings.headers, stringify, body).then((resp) => + RequestUtils.processResponseByType(resp) + ); + if (result.error) throw result.error.message; + return result; + } } diff --git a/component/src/services/stabilityAI/stabilityAIImageToImageIO.ts b/component/src/services/stabilityAI/stabilityAIImageToImageIO.ts index 01a4a4d8c..c328256de 100644 --- a/component/src/services/stabilityAI/stabilityAIImageToImageIO.ts +++ b/component/src/services/stabilityAI/stabilityAIImageToImageIO.ts @@ -70,7 +70,7 @@ export class StabilityAIImageToImageIO extends StabilityAIIO { const lastMessage = pMessages[pMessages.length - 1]?.text?.trim(); const formData = this.createFormDataBody(this.rawBody, files[0], lastMessage); // need to pass stringifyBody boolean separately as binding is throwing an error for some reason - RequestUtils.temporarilyRemoveHeader(this.requestSettings, + RequestUtils.tempRemoveContentHeader(this.requestSettings, HTTPRequest.request.bind(this, this, formData, messages), false); } diff --git a/component/src/services/stabilityAI/stabilityAIImageToImageMaskingIO.ts b/component/src/services/stabilityAI/stabilityAIImageToImageMaskingIO.ts index 03a697c77..4ce5521bc 100644 --- a/component/src/services/stabilityAI/stabilityAIImageToImageMaskingIO.ts +++ b/component/src/services/stabilityAI/stabilityAIImageToImageMaskingIO.ts @@ -74,7 +74,7 @@ export class StabilityAIImageToImageMaskingIO extends StabilityAIIO { const lastMessage = pMessages[pMessages.length - 1]?.text?.trim(); const formData = this.createFormDataBody(this.rawBody, files[0], files[1], lastMessage); // need to pass stringifyBody boolean separately as binding is throwing an error for some reason - RequestUtils.temporarilyRemoveHeader(this.requestSettings, + RequestUtils.tempRemoveContentHeader(this.requestSettings, HTTPRequest.request.bind(this, this, formData, messages), false); } diff --git a/component/src/services/stabilityAI/stabilityAIImageToImageUpscaleIO.ts b/component/src/services/stabilityAI/stabilityAIImageToImageUpscaleIO.ts index ce1f0dd99..3315453dd 100644 --- a/component/src/services/stabilityAI/stabilityAIImageToImageUpscaleIO.ts +++ b/component/src/services/stabilityAI/stabilityAIImageToImageUpscaleIO.ts @@ -57,7 +57,7 @@ export class StabilityAIImageToImageUpscaleIO extends StabilityAIIO { if (!files) throw new Error('Image was not found'); const formData = this.createFormDataBody(this.rawBody, files[0]); // need to pass stringifyBody boolean separately as binding is throwing an error for some reason - RequestUtils.temporarilyRemoveHeader(this.requestSettings, + RequestUtils.tempRemoveContentHeader(this.requestSettings, HTTPRequest.request.bind(this, this, formData, messages), false); } diff --git a/component/src/services/utils/baseServiceIO.ts b/component/src/services/utils/baseServiceIO.ts index ff1babd97..2612702ad 100644 --- a/component/src/services/utils/baseServiceIO.ts +++ b/component/src/services/utils/baseServiceIO.ts @@ -34,7 +34,6 @@ export class BaseServiceIO implements ServiceIO { recordAudio?: MicrophoneFilesServiceConfig; totalMessagesMaxCharLength?: number; maxMessages?: number; - private readonly _directServiceRequiresFiles: boolean; demo?: DemoT; // these are placeholders that are later populated in submitButton.ts completionsHandlers: CompletionsHandlers = {} as CompletionsHandlers; @@ -49,7 +48,6 @@ export class BaseServiceIO implements ServiceIO { SetFileTypes.set(deepChat, this, existingFileTypes); if (deepChat.request) this.requestSettings = deepChat.request; if (this.demo) this.requestSettings.url ??= Demo.URL; - this._directServiceRequiresFiles = !!existingFileTypes && Object.keys(existingFileTypes).length > 0; if (this.requestSettings.websocket) Websocket.setup(this); } @@ -126,7 +124,7 @@ export class BaseServiceIO implements ServiceIO { if (this.requestSettings.websocket) { const body = {messages: processedMessages, ...this.rawBody}; Websocket.sendWebsocket(this, body, messages, false); - } else if (requestContents.files && !this._directServiceRequiresFiles) { + } else if (requestContents.files && !this.isDirectConnection()) { this.callApiWithFiles(this.rawBody, messages, processedMessages, requestContents.files); } else { this.callServiceAPI(messages, processedMessages, requestContents.files); diff --git a/component/src/utils/HTTP/requestUtils.ts b/component/src/utils/HTTP/requestUtils.ts index 60c5c3528..3f7098199 100644 --- a/component/src/utils/HTTP/requestUtils.ts +++ b/component/src/utils/HTTP/requestUtils.ts @@ -17,13 +17,20 @@ export class RequestUtils { // need to pass stringifyBody boolean separately as binding is throwing an error for some reason // prettier-ignore - public static async temporarilyRemoveHeader(requestSettings: Request | undefined, - request: (stringifyBody?: boolean) => Promise, stringifyBody: boolean) { + public static async tempRemoveContentHeader(requestSettings: Request | undefined, + request: (stringifyBody?: boolean) => Promise, stringifyBody: boolean) { if (!requestSettings?.headers) throw new Error('Request settings have not been set up'); - const previousHeader = requestSettings.headers[RequestUtils.CONTENT_TYPE]; + const previousContetType = requestSettings.headers[RequestUtils.CONTENT_TYPE]; delete requestSettings.headers[RequestUtils.CONTENT_TYPE]; - await request(stringifyBody); - requestSettings.headers[RequestUtils.CONTENT_TYPE] = previousHeader; + let result; + try { + result = await request(stringifyBody); + } catch (e) { + requestSettings.headers[RequestUtils.CONTENT_TYPE] = previousContetType; + throw e; + } + requestSettings.headers[RequestUtils.CONTENT_TYPE] = previousContetType; + return result; } public static displayError(messages: Messages, err: object, defMessage = 'Service error, please try again.') { diff --git a/component/src/views/chat/input/buttons/submit/submitButton.ts b/component/src/views/chat/input/buttons/submit/submitButton.ts index fe754faa5..d74cb63d7 100644 --- a/component/src/views/chat/input/buttons/submit/submitButton.ts +++ b/component/src/views/chat/input/buttons/submit/submitButton.ts @@ -173,6 +173,7 @@ export class SubmitButton extends InputButton { const data: Response = {}; if (userText !== '') data.text = userText; if (uploadedFilesData) data.files = await this._messages.addMultipleFiles(uploadedFilesData); + if (this._serviceIO.sessionId) data.sessionId = this._serviceIO.sessionId; if (Object.keys(data).length > 0) this._messages.addNewMessage(data, false); }