Skip to content

Commit

Permalink
ability to send files to OpenAI Assistants
Browse files Browse the repository at this point in the history
  • Loading branch information
OvidijusParsiunas committed Nov 11, 2023
1 parent 9d12b6c commit beed3b8
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 42 deletions.
49 changes: 22 additions & 27 deletions component/src/services/openAI/openAIAssistantIO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {MessageUtils} from '../../views/chat/messages/messageUtils';
import {DirectConnection} from '../../types/directConnection';
import {MessageLimitUtils} from '../utils/messageLimitUtils';
import {Messages} from '../../views/chat/messages/messages';
import {RequestUtils} from '../../utils/HTTP/requestUtils';
import {HTTPRequest} from '../../utils/HTTP/HTTPRequest';
import {DirectServiceIO} from '../utils/directServiceIO';
import {MessageContent} from '../../types/messages';
Expand Down Expand Up @@ -33,8 +32,7 @@ export class OpenAIAssistantIO extends DirectServiceIO {
constructor(deepChat: DeepChat) {
const directConnectionCopy = JSON.parse(JSON.stringify(deepChat.directConnection)) as DirectConnection;
const apiKey = directConnectionCopy.openAI;
const imageFiles = deepChat.images ? {images: {files: {maxNumberOfFiles: 10}}} : {};
super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey, imageFiles);
super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey);
const config = directConnectionCopy.openAI?.assistant; // can be undefined as this is the default service
if (typeof config === 'object') {
this.rawBody.assistant_id = config.assistant_id;
Expand All @@ -46,39 +44,44 @@ export class OpenAIAssistantIO extends DirectServiceIO {
this.maxMessages = 1; // messages are stored in OpenAI threads and can't create new thread with 'assistant' messages
}

private processMessages(pMessages: MessageContent[]) {
private processMessages(pMessages: MessageContent[], file_ids?: string[]) {
const totalMessagesMaxCharLength = this.totalMessagesMaxCharLength || OpenAIUtils.CONVERSE_MAX_CHAR_LENGTH;
return MessageLimitUtils.getCharacterLimitMessages(pMessages, totalMessagesMaxCharLength).map((message) => {
return {content: message.text, role: message.role === MessageUtils.AI_ROLE ? 'assistant' : 'user'};
return {content: message.text || '', role: message.role === MessageUtils.AI_ROLE ? 'assistant' : 'user', file_ids};
});
}

private createNewThreadMessages(body: OpenAIConverseBodyInternal, pMessages: MessageContent[]) {
private createNewThreadMessages(body: OpenAIConverseBodyInternal, pMessages: MessageContent[], file_ids?: string[]) {
const bodyCopy = JSON.parse(JSON.stringify(body));
const processedMessages = this.processMessages(pMessages);
const processedMessages = this.processMessages(pMessages, file_ids);
bodyCopy.thread = {messages: processedMessages};
return bodyCopy;
}

override async callServiceAPI(messages: Messages, pMessages: MessageContent[]) {
if (!this.requestSettings) throw new Error('Request settings have not been set up');
// here instead of constructor as messages may be loaded later
if (!this.searchedForThreadId) this.searchPreviousMessagesForThreadId(messages.messages);
this.requestSettings.method = 'POST';
private callService(messages: Messages, pMessages: MessageContent[], file_ids?: string[]) {
if (this.sessionId) {
// https://platform.openai.com/docs/api-reference/messages/createMessage
this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${this.sessionId}/messages`;
const body = this.processMessages([pMessages[pMessages.length - 1]])[0];
const body = this.processMessages(pMessages, file_ids)[0];
HTTPRequest.request(this, body, messages);
} else {
// https://platform.openai.com/docs/api-reference/runs/createThreadAndRun
this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/runs`;
const body = this.createNewThreadMessages(this.rawBody, pMessages);
const body = this.createNewThreadMessages(this.rawBody, pMessages, file_ids);
HTTPRequest.request(this, body, messages);
}
this.messages = messages;
}

override async callServiceAPI(messages: Messages, pMessages: MessageContent[], files?: File[]) {
if (!this.requestSettings) throw new Error('Request settings have not been set up');
// here instead of constructor as messages may be loaded later
if (!this.searchedForThreadId) this.searchPreviousMessagesForThreadId(messages.messages);
const file_ids = files ? await OpenAIUtils.storeFiles(this, messages, files) : undefined;
this.requestSettings.method = 'POST';
this.callService(messages, pMessages, file_ids);
}

private searchPreviousMessagesForThreadId(messages: MessageContent[]) {
const messageWithSession = messages.find((message) => message.sessionId);
if (messageWithSession) this.sessionId = messageWithSession.sessionId;
Expand All @@ -99,11 +102,13 @@ export class OpenAIAssistantIO extends DirectServiceIO {
if (this.sessionId) {
// https://platform.openai.com/docs/api-reference/runs/createRun
this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${this.sessionId}/runs`;
const runObj = await this.directFetch(JSON.parse(JSON.stringify(this.rawBody)));
const runObj = await OpenAIUtils.directFetch(this, JSON.parse(JSON.stringify(this.rawBody)), 'POST');
this.run_id = runObj.id;
} else {
this.sessionId = result.thread_id;
this.run_id = result.id;
// updates the user sent message with the session id (the message event sent did not have this id)
if (this.messages) this.messages.messages[this.messages.messages.length - 1].sessionId = this.sessionId;
}
}

Expand All @@ -112,7 +117,7 @@ export class OpenAIAssistantIO extends DirectServiceIO {
if (status === 'queued' || status === 'in_progress') return {timeoutMS: OpenAIAssistantIO.POLLING_TIMEOUT_MS};
if (status === 'completed' && this.messages) {
this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${result.thread_id}/messages`;
const threadMessages = (await this.directFetch({}, 'GET')) as OpenAIAssistantMessagesResult;
const threadMessages = (await OpenAIUtils.directFetch(this, {}, 'GET')) as OpenAIAssistantMessagesResult;
const lastMessage = threadMessages.data[0];
return {text: lastMessage.content[0].text.value, sessionId: this.sessionId};
}
Expand All @@ -137,17 +142,7 @@ export class OpenAIAssistantIO extends DirectServiceIO {
});
// https://platform.openai.com/docs/api-reference/runs/submitToolOutputs
this.url = `${OpenAIAssistantIO.THREAD_PREFIX}/${this.sessionId}/runs/${this.run_id}/submit_tool_outputs`;
await this.directFetch({tool_outputs});
await OpenAIUtils.directFetch(this, {tool_outputs}, 'POST');
return {timeoutMS: OpenAIAssistantIO.POLLING_TIMEOUT_MS};
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
private async directFetch(body: any, method = 'POST') {
this.requestSettings.method = method;
const result = await RequestUtils.fetch(this, this.requestSettings.headers, true, body).then((resp) =>
RequestUtils.processResponseByType(resp)
);
if (result.error) throw result.error.message;
return result;
}
}
3 changes: 1 addition & 2 deletions component/src/services/openAI/openAIChatIO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ export class OpenAIChatIO extends DirectServiceIO {
constructor(deepChat: DeepChat) {
const directConnectionCopy = JSON.parse(JSON.stringify(deepChat.directConnection)) as DirectConnection;
const apiKey = directConnectionCopy.openAI;
const imageFiles = deepChat.images ? {images: {files: {maxNumberOfFiles: 10}}} : {};
super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey, imageFiles);
super(deepChat, OpenAIUtils.buildKeyVerificationDetails(), OpenAIUtils.buildHeaders, apiKey);
const config = directConnectionCopy.openAI?.chat; // can be undefined as this is the default service
if (typeof config === 'object') {
if (config.system_prompt) this._systemMessage = OpenAIChatIO.generateSystemMessage(config.system_prompt);
Expand Down
2 changes: 1 addition & 1 deletion component/src/services/openAI/openAIImagesIO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ export class OpenAIImagesIO extends DirectServiceIO {
formData = OpenAIImagesIO.createFormDataBody(this.rawBody, files[0]);
}
// need to pass stringifyBody boolean separately as binding is throwing an error for some reason
RequestUtils.temporarilyRemoveHeader(this.requestSettings,
RequestUtils.tempRemoveContentHeader(this.requestSettings,
HTTPRequest.request.bind(this, this, formData, messages), false);
}

Expand Down
2 changes: 1 addition & 1 deletion component/src/services/openAI/openAISpeechToTextIO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ export class OpenAISpeechToTextIO extends DirectServiceIO {
const body = this.preprocessBody(this.rawBody, pMessages);
const formData = OpenAISpeechToTextIO.createFormDataBody(body, files[0]);
// need to pass stringifyBody boolean separately as binding is throwing an error for some reason
RequestUtils.temporarilyRemoveHeader(this.requestSettings,
RequestUtils.tempRemoveContentHeader(this.requestSettings,
HTTPRequest.request.bind(this, this, formData, messages), false);
}

Expand Down
40 changes: 40 additions & 0 deletions component/src/services/openAI/utils/openAIUtils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import {KeyVerificationDetails} from '../../../types/keyVerificationDetails';
import {ErrorMessages} from '../../../utils/errorMessages/errorMessages';
import {OpenAIConverseResult} from '../../../types/openAIResult';
import {Messages} from '../../../views/chat/messages/messages';
import {RequestUtils} from '../../../utils/HTTP/requestUtils';
import {ServiceIO} from '../../serviceIO';

export class OpenAIUtils {
// 13352 roughly adds up to 3,804 tokens just to be safe
Expand Down Expand Up @@ -36,4 +39,41 @@ export class OpenAIUtils {
handleVerificationResult: OpenAIUtils.handleVerificationResult,
};
}

public static async storeFiles(serviceIO: ServiceIO, messages: Messages, files: File[]) {
const headers = serviceIO.requestSettings.headers;
if (!headers) return;
serviceIO.url = `https://api.openai.com/v1/files`; // stores files
const previousContetType = headers[RequestUtils.CONTENT_TYPE];
delete headers[RequestUtils.CONTENT_TYPE];
const requests = files.map(async (file) => {
const formData = new FormData();
formData.append('purpose', 'assistants');
formData.append('file', file);
return new Promise<{id: string}>((resolve) => {
resolve(OpenAIUtils.directFetch(serviceIO, formData, 'POST', false)); // should perhaps use await but works without
});
});
try {
const fileIds = (await Promise.all(requests)).map((result) => result.id);
headers[RequestUtils.CONTENT_TYPE] = previousContetType;
return fileIds;
} catch (err) {
headers[RequestUtils.CONTENT_TYPE] = previousContetType;
// error handled here as files not sent using HTTPRequest.request to not trigger the interceptors
RequestUtils.displayError(messages, err as object);
serviceIO.completionsHandlers.onFinish();
throw err;
}
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
public static async directFetch(serviceIO: ServiceIO, body: any, method: 'POST' | 'GET', stringify = true) {
serviceIO.requestSettings.method = method;
const result = await RequestUtils.fetch(serviceIO, serviceIO.requestSettings.headers, stringify, body).then((resp) =>
RequestUtils.processResponseByType(resp)
);
if (result.error) throw result.error.message;
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ export class StabilityAIImageToImageIO extends StabilityAIIO {
const lastMessage = pMessages[pMessages.length - 1]?.text?.trim();
const formData = this.createFormDataBody(this.rawBody, files[0], lastMessage);
// need to pass stringifyBody boolean separately as binding is throwing an error for some reason
RequestUtils.temporarilyRemoveHeader(this.requestSettings,
RequestUtils.tempRemoveContentHeader(this.requestSettings,
HTTPRequest.request.bind(this, this, formData, messages), false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export class StabilityAIImageToImageMaskingIO extends StabilityAIIO {
const lastMessage = pMessages[pMessages.length - 1]?.text?.trim();
const formData = this.createFormDataBody(this.rawBody, files[0], files[1], lastMessage);
// need to pass stringifyBody boolean separately as binding is throwing an error for some reason
RequestUtils.temporarilyRemoveHeader(this.requestSettings,
RequestUtils.tempRemoveContentHeader(this.requestSettings,
HTTPRequest.request.bind(this, this, formData, messages), false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export class StabilityAIImageToImageUpscaleIO extends StabilityAIIO {
if (!files) throw new Error('Image was not found');
const formData = this.createFormDataBody(this.rawBody, files[0]);
// need to pass stringifyBody boolean separately as binding is throwing an error for some reason
RequestUtils.temporarilyRemoveHeader(this.requestSettings,
RequestUtils.tempRemoveContentHeader(this.requestSettings,
HTTPRequest.request.bind(this, this, formData, messages), false);
}

Expand Down
4 changes: 1 addition & 3 deletions component/src/services/utils/baseServiceIO.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ export class BaseServiceIO implements ServiceIO {
recordAudio?: MicrophoneFilesServiceConfig;
totalMessagesMaxCharLength?: number;
maxMessages?: number;
private readonly _directServiceRequiresFiles: boolean;
demo?: DemoT;
// these are placeholders that are later populated in submitButton.ts
completionsHandlers: CompletionsHandlers = {} as CompletionsHandlers;
Expand All @@ -49,7 +48,6 @@ export class BaseServiceIO implements ServiceIO {
SetFileTypes.set(deepChat, this, existingFileTypes);
if (deepChat.request) this.requestSettings = deepChat.request;
if (this.demo) this.requestSettings.url ??= Demo.URL;
this._directServiceRequiresFiles = !!existingFileTypes && Object.keys(existingFileTypes).length > 0;
if (this.requestSettings.websocket) Websocket.setup(this);
}

Expand Down Expand Up @@ -126,7 +124,7 @@ export class BaseServiceIO implements ServiceIO {
if (this.requestSettings.websocket) {
const body = {messages: processedMessages, ...this.rawBody};
Websocket.sendWebsocket(this, body, messages, false);
} else if (requestContents.files && !this._directServiceRequiresFiles) {
} else if (requestContents.files && !this.isDirectConnection()) {
this.callApiWithFiles(this.rawBody, messages, processedMessages, requestContents.files);
} else {
this.callServiceAPI(messages, processedMessages, requestContents.files);
Expand Down
17 changes: 12 additions & 5 deletions component/src/utils/HTTP/requestUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,20 @@ export class RequestUtils {

// need to pass stringifyBody boolean separately as binding is throwing an error for some reason
// prettier-ignore
public static async temporarilyRemoveHeader(requestSettings: Request | undefined,
request: (stringifyBody?: boolean) => Promise<void>, stringifyBody: boolean) {
public static async tempRemoveContentHeader(requestSettings: Request | undefined,
request: (stringifyBody?: boolean) => Promise<unknown>, stringifyBody: boolean) {
if (!requestSettings?.headers) throw new Error('Request settings have not been set up');
const previousHeader = requestSettings.headers[RequestUtils.CONTENT_TYPE];
const previousContetType = requestSettings.headers[RequestUtils.CONTENT_TYPE];
delete requestSettings.headers[RequestUtils.CONTENT_TYPE];
await request(stringifyBody);
requestSettings.headers[RequestUtils.CONTENT_TYPE] = previousHeader;
let result;
try {
result = await request(stringifyBody);
} catch (e) {
requestSettings.headers[RequestUtils.CONTENT_TYPE] = previousContetType;
throw e;
}
requestSettings.headers[RequestUtils.CONTENT_TYPE] = previousContetType;
return result;
}

public static displayError(messages: Messages, err: object, defMessage = 'Service error, please try again.') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ export class SubmitButton extends InputButton<Styles> {
const data: Response = {};
if (userText !== '') data.text = userText;
if (uploadedFilesData) data.files = await this._messages.addMultipleFiles(uploadedFilesData);
if (this._serviceIO.sessionId) data.sessionId = this._serviceIO.sessionId;
if (Object.keys(data).length > 0) this._messages.addNewMessage(data, false);
}

Expand Down

0 comments on commit beed3b8

Please sign in to comment.