diff --git a/component/src/services/openAI/openAIChatIO.ts b/component/src/services/openAI/openAIChatIO.ts index 8652d25b1..dce140d6f 100644 --- a/component/src/services/openAI/openAIChatIO.ts +++ b/component/src/services/openAI/openAIChatIO.ts @@ -149,7 +149,9 @@ export class OpenAIChatIO extends DirectServiceIO { if (!Array.isArray(handlerResponse)) { if (handlerResponse.text) { const response = {text: handlerResponse.text}; - return await this.deepChat.responseInterceptor?.(response) || response; + const processedResponse = await this.deepChat.responseInterceptor?.(response) || response; + if (Array.isArray(processedResponse)) throw Error(OpenAIUtils.FUNCTION_TOOL_RESP_ARR_ERROR); + return processedResponse; } throw Error(OpenAIUtils.FUNCTION_TOOL_RESP_ERROR); } diff --git a/component/src/services/openAI/utils/openAIUtils.ts b/component/src/services/openAI/utils/openAIUtils.ts index 01e2d9deb..bd5e70ea9 100644 --- a/component/src/services/openAI/utils/openAIUtils.ts +++ b/component/src/services/openAI/utils/openAIUtils.ts @@ -8,6 +8,7 @@ export class OpenAIUtils { public static readonly FUNCTION_TOOL_RESP_ERROR = 'Response object must either be {response: string}[] for each individual function ' + 'or {text: string} for a direct response, see https://deepchat.dev/docs/directConnection/OpenAI#FunctionHandler.'; + public static readonly FUNCTION_TOOL_RESP_ARR_ERROR = 'Arrays are not accepted in handler responses'; public static buildHeaders(key: string) { return { diff --git a/component/src/services/utils/baseServiceIO.ts b/component/src/services/utils/baseServiceIO.ts index 055e367cf..5b4435bed 100644 --- a/component/src/services/utils/baseServiceIO.ts +++ b/component/src/services/utils/baseServiceIO.ts @@ -145,7 +145,7 @@ export class BaseServiceIO implements ServiceIO { if (result.error) throw result.error; if (result.result) return Legacy.handleResponseProperty(result); // if invalid - process later in HTTPRequest.request - if 
(!RequestUtils.validateResponseFormat(result)) return undefined as unknown as Response; + if (!RequestUtils.validateResponseFormat(result, !!this.stream)) return undefined as unknown as Response; return result; } diff --git a/component/src/types/handler.ts b/component/src/types/handler.ts index e45e91a8e..1d18a42fd 100644 --- a/component/src/types/handler.ts +++ b/component/src/types/handler.ts @@ -2,7 +2,7 @@ import {Response} from './response'; export interface Signals { - onResponse: (response: Response) => Promise; + onResponse: (response: Response | Response[]) => Promise; onOpen: () => void; onClose: () => void; stopClicked: {listener: () => void}; diff --git a/component/src/types/interceptors.ts b/component/src/types/interceptors.ts index 531412755..e4c47d7a3 100644 --- a/component/src/types/interceptors.ts +++ b/component/src/types/interceptors.ts @@ -12,11 +12,10 @@ export type ResponseDetails = RequestDetails | {error: string}; export type RequestInterceptor = (details: RequestDetails) => ResponseDetails | Promise; -// not enabled for streaming // the response type is subject to what type of connection you are using: // if you are using a custom service via the 'connect' property - see Response // if you are directly connecting to an API via the 'directConnection' property - the response type will // dependend to the defined service // https://deepchat.dev/docs/interceptors#responseInterceptor // eslint-disable-next-line @typescript-eslint/no-explicit-any -export type ResponseInterceptor = (response: any) => Response | Promise; +export type ResponseInterceptor = (response: any) => Response | Response[] | Promise; diff --git a/component/src/utils/HTTP/HTTPRequest.ts b/component/src/utils/HTTP/HTTPRequest.ts index 79b8317d4..29eb80431 100644 --- a/component/src/utils/HTTP/HTTPRequest.ts +++ b/component/src/utils/HTTP/HTTPRequest.ts @@ -34,13 +34,14 @@ export class HTTPRequest { const resultData = await io.extractResultData(finalResult, fetchFunc, 
interceptedBody); // the reason why throwing here is to allow extractResultData to attempt extract error message and throw it if (!responseValid) throw result; - if (!resultData || typeof resultData !== 'object') + if (!resultData || (typeof resultData !== 'object' && !Array.isArray(resultData))) throw Error(ErrorMessages.INVALID_RESPONSE(result, 'response', !!io.deepChat.responseInterceptor, finalResult)); if (resultData.makingAnotherRequest) return; if (Stream.isSimulatable(io.stream, resultData)) { Stream.simulate(messages, io.streamHandlers, resultData); } else { - messages.addNewMessage(resultData); + const messageDataArr = Array.isArray(resultData) ? resultData : [resultData]; + messageDataArr.forEach((message) => messages.addNewMessage(message)); onFinish(); } }) diff --git a/component/src/utils/HTTP/customHandler.ts b/component/src/utils/HTTP/customHandler.ts index 9e1a5f859..0482e1649 100644 --- a/component/src/utils/HTTP/customHandler.ts +++ b/component/src/utils/HTTP/customHandler.ts @@ -18,23 +18,27 @@ export interface IWebsocketHandler { export class CustomHandler { public static async request(io: ServiceIO, body: RequestDetails['body'], messages: Messages) { let isHandlerActive = true; - const onResponse = async (response: Response) => { + const onResponse = async (response: Response | Response[]) => { if (!isHandlerActive) return; isHandlerActive = false; // need to set it here due to asynchronous code below const result = (await io.deepChat.responseInterceptor?.(response)) || response; - if (!RequestUtils.validateResponseFormat(result)) { + if (!RequestUtils.validateResponseFormat(result, !!io.stream)) { console.error(ErrorMessages.INVALID_RESPONSE(response, 'server', !!io.deepChat.responseInterceptor, result)); messages.addNewErrorMessage('service', 'Error in server message'); io.completionsHandlers.onFinish(); - } else if (typeof result.error === 'string') { - console.error(result.error); - messages.addNewErrorMessage('service', result.error); - 
io.completionsHandlers.onFinish(); - } else if (Stream.isSimulatable(io.stream, result)) { - Stream.simulate(messages, io.streamHandlers, result); } else { - messages.addNewMessage(result); - io.completionsHandlers.onFinish(); + const messageDataArr = Array.isArray(result) ? result : [result]; + const errorMessage = messageDataArr.find((message) => typeof message.error === 'string'); + if (errorMessage) { + console.error(errorMessage.error); + messages.addNewErrorMessage('service', errorMessage.error); + io.completionsHandlers.onFinish(); + } else if (Stream.isSimulatable(io.stream, result as Response)) { + Stream.simulate(messages, io.streamHandlers, result as Response); + } else { + messageDataArr.forEach((message) => messages.addNewMessage(message)); + io.completionsHandlers.onFinish(); + } } }; const signals = CustomHandler.generateOptionalSignals(); @@ -72,16 +76,15 @@ export class CustomHandler { io.streamHandlers.onClose(); isHandlerActive = false; }; - const onResponse = async (response: Response) => { + const onResponse = async (response: Response | Response[]) => { if (!isHandlerActive) return; - const result = (await io.deepChat.responseInterceptor?.(response)) || response; - if (!RequestUtils.validateResponseFormat(result)) { - console.error(ErrorMessages.INVALID_RESPONSE(response, 'server', !!io.deepChat.responseInterceptor, result)); + const result = ((await io.deepChat.responseInterceptor?.(response)) || response) as Response; // array not supported + if (!RequestUtils.validateResponseFormat(result, !!io.stream)) { + const errorMessage = ErrorMessages.INVALID_RESPONSE(response, 'server', !!io.deepChat.responseInterceptor, result); + CustomHandler.streamError(errorMessage, stream, io, messages); + isHandlerActive = false; } else if (result.error) { - console.error(result.error); - stream.finaliseStreamedMessage(); - messages.addNewErrorMessage('service', result.error); - io.streamHandlers.onClose(); + CustomHandler.streamError(result.error, stream, 
io, messages); isHandlerActive = false; } else { Stream.upsertWFiles(messages, stream.upsertStreamedMessage.bind(stream), stream, result); @@ -97,6 +100,13 @@ export class CustomHandler { {...signals, onOpen, onResponse, onClose, stopClicked: io.streamHandlers.stopClicked}); } + private static streamError(errorMessage: string, stream: MessageStream, io: ServiceIO, messages: Messages) { + console.error(errorMessage); + stream.finaliseStreamedMessage(); + messages.addNewErrorMessage('service', errorMessage); + io.streamHandlers.onClose(); + } + // prettier-ignore public static websocket(io: ServiceIO, messages: Messages) { const internalConfig = {isOpen: false, newUserMessage: {listener: () => {}}, roleToStream: {}}; @@ -108,21 +118,26 @@ export class CustomHandler { const onClose = () => { internalConfig.isOpen = false; }; - const onResponse = async (response: Response) => { + const onResponse = async (response: Response | Response[]) => { if (!internalConfig.isOpen) return; const result = (await io.deepChat.responseInterceptor?.(response)) || response; - if (!RequestUtils.validateResponseFormat(result)) { + if (!RequestUtils.validateResponseFormat(result, !!io.stream)) { console.error(ErrorMessages.INVALID_RESPONSE(response, 'server', !!io.deepChat.responseInterceptor, result)); messages.addNewErrorMessage('service', 'Error in server message'); - } else if (typeof result.error === 'string') { - console.error(result.error); - if (!messages.isLastMessageError()) messages.addNewErrorMessage('service', result.error); - } else if (Stream.isSimulation(io.stream)) { - const upsertFunc = Websocket.stream.bind(this, io, messages, internalConfig.roleToStream); - const stream = (internalConfig.roleToStream as RoleToStream)[response.role || MessageUtils.AI_ROLE]; - Stream.upsertWFiles(messages, upsertFunc, stream, response); } else { - messages.addNewMessage(result); + const messageDataArr = Array.isArray(result) ? 
result : [result]; + const errorMessage = messageDataArr.find((message) => typeof message.error === 'string'); + if (errorMessage) { + console.error(errorMessage.error); + if (!messages.isLastMessageError()) messages.addNewErrorMessage('service', errorMessage.error); + } else if (Stream.isSimulation(io.stream)) { + const message = result as Response; // array not supported for streaming + const upsertFunc = Websocket.stream.bind(this, io, messages, internalConfig.roleToStream); + const stream = (internalConfig.roleToStream as RoleToStream)[message.role || MessageUtils.AI_ROLE]; + Stream.upsertWFiles(messages, upsertFunc, stream, message); + } else { + messageDataArr.forEach((message) => messages.addNewMessage(message)); + } } }; const signals = CustomHandler.generateOptionalSignals(); diff --git a/component/src/utils/HTTP/requestUtils.ts b/component/src/utils/HTTP/requestUtils.ts index 68508542a..2021804c2 100644 --- a/component/src/utils/HTTP/requestUtils.ts +++ b/component/src/utils/HTTP/requestUtils.ts @@ -1,3 +1,4 @@ +import {ErrorMessages} from '../errorMessages/errorMessages'; import {Messages} from '../../views/chat/messages/messages'; import {Response as ResponseI} from '../../types/response'; import {RequestDetails} from '../../types/interceptors'; @@ -80,15 +81,24 @@ export class RequestUtils { return {body: resReqDetails.body, headers: resReqDetails.headers, error: resErrDetails.error}; } - public static validateResponseFormat(response: ResponseI) { - return ( - response && - typeof response === 'object' && - (typeof response.error === 'string' || - typeof response.text === 'string' || - typeof response.html === 'string' || - Array.isArray(response.files)) + public static validateResponseFormat(response: ResponseI | ResponseI[], isStreaming: boolean) { + if (!response) return false; + const dataArr = Array.isArray(response) ? 
response : [response]; + if (isStreaming && dataArr.length > 1) { + console.error(ErrorMessages.INVALID_STREAM_ARRAY_RESPONSE); + return false; + } + const invalidFound = dataArr.some( + (data) => + !data || typeof data !== 'object' || + !( + typeof data.error === 'string' || + typeof data.text === 'string' || + typeof data.html === 'string' || + Array.isArray(data.files) + ) ); + return !invalidFound; } public static onInterceptorError(messages: Messages, error: string, onFinish?: () => void) { diff --git a/component/src/utils/HTTP/websocket.ts b/component/src/utils/HTTP/websocket.ts index 62ff356a7..f4b35aac1 100644 --- a/component/src/utils/HTTP/websocket.ts +++ b/component/src/utils/HTTP/websocket.ts @@ -71,14 +71,15 @@ export class Websocket { const result: Response = JSON.parse(message.data); const finalResult = (await io.deepChat.responseInterceptor?.(result)) || result; const resultData = await io.extractResultData(finalResult); - if (!resultData || typeof resultData !== 'object') + if (!resultData || (typeof resultData !== 'object' && !Array.isArray(resultData))) throw Error(ErrorMessages.INVALID_RESPONSE(result, 'server', !!io.deepChat.responseInterceptor, finalResult)); if (Stream.isSimulation(io.stream)) { const upsertFunc = Websocket.stream.bind(this, io, messages, roleToStream); const stream = roleToStream[result.role || MessageUtils.AI_ROLE]; Stream.upsertWFiles(messages, upsertFunc, stream, resultData); } else { - messages.addNewMessage(resultData); + const messageDataArr = Array.isArray(resultData) ? 
resultData : [resultData]; + messageDataArr.forEach((data) => messages.addNewMessage(data)); } } catch (error) { RequestUtils.displayError(messages, error as object, 'Error in server message'); diff --git a/component/src/utils/demo/demo.ts b/component/src/utils/demo/demo.ts index 29a0a6880..eb9f29536 100644 --- a/component/src/utils/demo/demo.ts +++ b/component/src/utils/demo/demo.ts @@ -52,14 +52,16 @@ export class Demo { public static request(io: ServiceIO, messages: Messages) { const response = Demo.getResponse(messages); setTimeout(async () => { - const processedResponse = (await io.deepChat.responseInterceptor?.(response)) || response; - if (processedResponse.error) { - messages.addNewErrorMessage('service', processedResponse.error); + const finalResult = (await io.deepChat.responseInterceptor?.(response)) || response; + const messageDataArr = Array.isArray(finalResult) ? finalResult : [finalResult]; + const errorMessage = messageDataArr.find((message) => typeof message.error === 'string'); + if (errorMessage) { + messages.addNewErrorMessage('service', errorMessage.error); io.completionsHandlers.onFinish(); - } else if (Stream.isSimulatable(io.stream, processedResponse)) { - Stream.simulate(messages, io.streamHandlers, processedResponse); + } else if (Stream.isSimulatable(io.stream, finalResult as Response)) { + Stream.simulate(messages, io.streamHandlers, finalResult as Response); } else { - messages.addNewMessage(processedResponse); + messageDataArr.forEach((data) => messages.addNewMessage(data)); io.completionsHandlers.onFinish(); } }, 400); diff --git a/component/src/utils/errorMessages/errorMessages.ts b/component/src/utils/errorMessages/errorMessages.ts index 9f3dcbd3d..6f39518aa 100644 --- a/component/src/utils/errorMessages/errorMessages.ts +++ b/component/src/utils/errorMessages/errorMessages.ts @@ -52,6 +52,7 @@ export const ErrorMessages = { INVALID_RESPONSE: getInvalidResponseMessage, INVALID_MODEL_REQUEST: getModelRequestMessage, 
INVALID_MODEL_RESPONSE: getModelResponseMessage, + INVALID_STREAM_ARRAY_RESPONSE: 'Multi-response arrays are not supported for streaming', INVALID_STREAM_EVENT, INVALID_STREAM_EVENT_MIX: 'Cannot mix {text: string} and {html: string} responses.', NO_VALID_STREAM_EVENTS_SENT: `No valid stream events were sent.\n${INVALID_STREAM_EVENT}`, diff --git a/component/src/webModel/webModel.ts b/component/src/webModel/webModel.ts index 0743ac8cf..5d09382fd 100644 --- a/component/src/webModel/webModel.ts +++ b/component/src/webModel/webModel.ts @@ -207,7 +207,7 @@ export class WebModel extends BaseServiceIO { private async immediateResp(messages: Messages, text: string, chat: WebLLM.ChatInterface) { const output = {text: await chat.generate(text, undefined, 0)}; // anything but 1 will not stream const response = await WebModel.processResponse(this.deepChat, messages, output); - if (response) messages.addNewMessage(response); + if (response) response.forEach((data) => messages.addNewMessage(data)); this.completionsHandlers.onFinish(); } @@ -219,7 +219,7 @@ export class WebModel extends BaseServiceIO { const stream = new MessageStream(messages); await chat.generate(text, async (_: number, message: string) => { const response = await WebModel.processResponse(this.deepChat, messages, {text: message}); - if (response) stream.upsertStreamedMessage({text: response.text, overwrite: true}); + if (response) stream.upsertStreamedMessage({text: response[0].text, overwrite: true}); }); stream.finaliseStreamedMessage(); this.streamHandlers.onClose(); @@ -282,16 +282,26 @@ export class WebModel extends BaseServiceIO { } private static async processResponse(deepChat: DeepChat, messages: Messages, output: ResponseI) { - const result: ResponseI = (await deepChat.responseInterceptor?.(output)) || output; - if (result.error) { - RequestUtils.displayError(messages, new Error(result.error)); - return; - } else if (!result || !result.text) { - const error = 
ErrorMessages.INVALID_MODEL_RESPONSE(output, !!deepChat.responseInterceptor, result); - RequestUtils.displayError(messages, new Error(error)); + const result = (await deepChat.responseInterceptor?.(output)) || output; + if (deepChat.connect?.stream) + if (Array.isArray(result) && result.length > 1) { + console.error(ErrorMessages.INVALID_STREAM_ARRAY_RESPONSE); + return; + } + const messageDataArr = Array.isArray(result) ? result : [result]; + const errorMessage = messageDataArr.find((message) => typeof message.error === 'string'); + if (errorMessage) { + RequestUtils.displayError(messages, new Error(errorMessage.error)); return; + } else { + const errorMessage = messageDataArr.find((message) => !message || !message.text); + if (errorMessage) { + const error = ErrorMessages.INVALID_MODEL_RESPONSE(output, !!deepChat.responseInterceptor, result); + RequestUtils.displayError(messages, new Error(error)); + return; + } } - return result; + return messageDataArr; } override isWebModel() {