From 939557796a37aecc4fff4ce5c3a90ebf22428c1a Mon Sep 17 00:00:00 2001
From: fegloff
Date: Mon, 15 Jan 2024 18:01:01 -0500
Subject: [PATCH] update vision logic to comply with openai api

---
 src/modules/llms/index.ts         |  4 +-
 src/modules/open-ai/api/openAi.ts | 13 ++---
 src/modules/open-ai/helpers.ts    |  4 +-
 src/modules/open-ai/index.ts      | 86 +++++++++++++++++--------------
 src/modules/open-ai/types.ts      |  8 +++
 src/modules/payment/index.ts      |  1 -
 src/modules/types.ts              |  9 +++-
 7 files changed, 72 insertions(+), 53 deletions(-)

diff --git a/src/modules/llms/index.ts b/src/modules/llms/index.ts
index 35b381b1..a8087150 100644
--- a/src/modules/llms/index.ts
+++ b/src/modules/llms/index.ts
@@ -550,7 +550,7 @@ export class LlmsBot implements PayableBot {
       await ctx.api.editMessageText(
         ctx.chat.id,
         msgId,
-        response.completion.content
+        response.completion.content as string
       )
       conversation.push(response.completion)
       // const price = getPromptPrice(completion, data);
@@ -648,7 +648,7 @@ export class LlmsBot implements PayableBot {
       return
     }
     const chat: ChatConversation = {
-      content: limitPrompt(prompt),
+      content: limitPrompt(prompt as string),
       model
     }
     if (model === LlmsModelsEnum.BISON) {
diff --git a/src/modules/open-ai/api/openAi.ts b/src/modules/open-ai/api/openAi.ts
index a0b46b20..333b70c9 100644
--- a/src/modules/open-ai/api/openAi.ts
+++ b/src/modules/open-ai/api/openAi.ts
@@ -18,7 +18,7 @@ import {
   DalleGPTModels
 } from '../types'
 import type fs from 'fs'
-import { type ChatCompletionCreateParamsNonStreaming } from 'openai/resources/chat/completions'
+import { type ChatCompletionMessageParam, type ChatCompletionCreateParamsNonStreaming } from 'openai/resources/chat/completions'
 
 const openai = new OpenAI({ apiKey: config.openAiKey })
 
@@ -112,15 +112,12 @@ export async function chatCompletion (
   model = config.openAi.chatGpt.model,
   limitTokens = true
 ): Promise<ChatCompletion> {
-  const payload = {
+  const response = await openai.chat.completions.create({
     model,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature,
-    messages: conversation
-  }
-  const response = await openai.chat.completions.create(
-    payload as OpenAI.Chat.CompletionCreateParamsNonStreaming
-  )
+    messages: conversation as ChatCompletionMessageParam[]
+  })
   const chatModel = getChatModel(model)
   if (response.usage?.prompt_tokens === undefined) {
     throw new Error('Unknown number of prompt tokens used')
@@ -149,7 +146,7 @@ export const streamChatCompletion = async (
   let wordCountMinimum = 2
   const stream = await openai.chat.completions.create({
     model,
-    messages: conversation as OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
+    messages: conversation as ChatCompletionMessageParam[], // OpenAI.Chat.Completions.CreateChatCompletionRequestMessage[],
     stream: true,
     max_tokens: limitTokens ? config.openAi.chatGpt.maxTokens : undefined,
     temperature: config.openAi.dalle.completions.temperature || 0.8
diff --git a/src/modules/open-ai/helpers.ts b/src/modules/open-ai/helpers.ts
index a873124f..0e8a1e22 100644
--- a/src/modules/open-ai/helpers.ts
+++ b/src/modules/open-ai/helpers.ts
@@ -235,8 +235,8 @@ export const hasPrefix = (prompt: string): string => {
 
 export const getPromptPrice = (completion: string, data: ChatPayload): { price: number, promptTokens: number, completionTokens: number } => {
   const { conversation, ctx, model } = data
-  const prompt = conversation[conversation.length - 1].content
-  const promptTokens = getTokenNumber(prompt)
+  const prompt = data.prompt ? data.prompt : conversation[conversation.length - 1].content
+  const promptTokens = getTokenNumber(prompt as string)
   const completionTokens = getTokenNumber(completion)
   const modelPrice = getChatModel(model)
   const price =
diff --git a/src/modules/open-ai/index.ts b/src/modules/open-ai/index.ts
index c96ba904..bf76cc0b 100644
--- a/src/modules/open-ai/index.ts
+++ b/src/modules/open-ai/index.ts
@@ -14,12 +14,12 @@
 } from '../types'
 import {
   alterGeneratedImg,
+  chatCompletion,
   getChatModel,
   getDalleModel,
   getDalleModelPrice,
   postGenerateImg,
-  streamChatCompletion,
-  streamChatVisionCompletion
+  streamChatCompletion
 } from './api/openAi'
 import { appText } from './utils/text'
 import { chatService } from '../../database/services'
@@ -91,7 +91,7 @@ export class OpenAIBot implements PayableBot {
     try {
       const priceAdjustment = config.openAi.chatGpt.priceAdjustment
       const prompts = ctx.match
-      if (this.isSupportedImageReply(ctx)) {
+      if (this.isSupportedImageReply(ctx) && !isNaN(+prompts)) {
         const imageNumber = ctx.message?.caption || ctx.message?.text
         const imageSize = ctx.session.openAi.imageGen.imgSize
         const model = getDalleModel(imageSize)
@@ -609,18 +609,6 @@ export class OpenAIBot implements PayableBot {
     }
   }
 
-  // imgInquiryWithVision = async (
-  //   img: string,
-  //   prompt: string,
-  //   ctx: OnMessageContext | OnCallBackQueryData
-  // ): Promise<string> => {
-  //   console.log(img, prompt)
-  //   console.log('HELLO')
-  //   const response = await openai.chat.completions.create(payLoad as unknown as ChatCompletionCreateParamsNonStreaming)
-  //   console.log(response.choices[0].message?.content)
-  //   return 'hi'
-  // }
-
   onInquiryImage = async (photo: PhotoSize[] | undefined, prompt: string | undefined, ctx: OnMessageContext | OnCallBackQueryData): Promise<void> => {
     try {
       if (ctx.session.openAi.imageGen.isEnabled) {
@@ -639,30 +627,50 @@ export class OpenAIBot implements PayableBot {
               ctx.message?.reply_to_message?.message_thread_id
           })
         ).message_id
-        const completion = await streamChatVisionCompletion([], ctx, 'gpt-4-vision-preview', prompt ?? '', filePath, msgId, true)
-        console.log(completion)
-        // const inquiry = await imgInquiryWithVision(filePath, prompt ?? '', ctx)
-        // console.log(inquiry)
-        // const imgSize = ctx.session.openAi.imageGen.imgSize
-        // ctx.chatAction = 'upload_photo'
-        // const imgs = await alterGeneratedImg(prompt ?? '', filePath, ctx, imgSize)
-        // if (imgs) {
-        //   imgs.map(async (img: any) => {
-        //     if (img?.url) {
-        //       await ctx
-        //         .replyWithPhoto(img.url, { message_thread_id: ctx.message?.message_thread_id })
-        //         .catch(async (e) => {
-        //           await this.onError(
-        //             ctx,
-        //             e,
-        //             MAX_TRIES,
-        //             'There was an error while generating the image'
-        //           )
-        //         })
-        //     }
-        //   })
-        // }
-        // ctx.chatAction = null
+        const messages = [
+          {
+            role: 'user',
+            content: [
+              { type: 'text', text: prompt },
+              {
+                type: 'image_url',
+                image_url: { url: filePath }
+              }
+            ]
+          }
+        ]
+        const model = ChatGPTModelsEnum.GPT_4_VISION_PREVIEW
+        const completion = await chatCompletion(messages as any, model, true)
+        if (completion) {
+          await ctx.api
+            .editMessageText(`${ctx.chat?.id}`, msgId, completion.completion)
+            .catch(async (e: any) => {
+              await this.onError(
+                ctx,
+                e,
+                MAX_TRIES,
+                'An error occurred while generating the AI edit'
+              )
+            })
+          ctx.transient.analytics.sessionState = RequestState.Success
+          ctx.transient.analytics.actualResponseTime = now()
+          const price = getPromptPrice(completion.completion, {
+            conversation: [],
+            prompt,
+            model,
+            ctx
+          })
+          this.logger.info(
+            `streamChatCompletion result = tokens: ${
+              price.promptTokens + price.completionTokens
+            } | ${model} | price: ${price.price}¢`
+          )
+          if (
+            !(await this.payments.pay(ctx as OnMessageContext, price.price))
+          ) {
+            await this.onNotBalanceMessage(ctx)
+          }
+        }
       }
     } catch (e: any) {
       await this.onError(
diff --git a/src/modules/open-ai/types.ts b/src/modules/open-ai/types.ts
index 46a60c9e..3feb117e 100644
--- a/src/modules/open-ai/types.ts
+++ b/src/modules/open-ai/types.ts
@@ -16,6 +16,7 @@ export enum ChatGPTModelsEnum {
   GPT_4_32K = 'gpt-4-32k',
   GPT_35_TURBO = 'gpt-3.5-turbo',
   GPT_35_TURBO_16K = 'gpt-3.5-turbo-16k',
+  GPT_4_VISION_PREVIEW = 'gpt-4-vision-preview'
 }
 
 export const ChatGPTModels: Record<string, ChatModel> = {
@@ -46,6 +47,13 @@ export const ChatGPTModels: Record<string, ChatModel> = {
     outputPrice: 0.004,
     maxContextTokens: 16000,
     chargeType: 'TOKEN'
+  },
+  'gpt-4-vision-preview': {
+    name: 'gpt-4-vision-preview',
+    inputPrice: 0.03,
+    outputPrice: 0.06,
+    maxContextTokens: 16000,
+    chargeType: 'TOKEN'
   }
 }
 
diff --git a/src/modules/payment/index.ts b/src/modules/payment/index.ts
index 0b8ba854..5f855685 100644
--- a/src/modules/payment/index.ts
+++ b/src/modules/payment/index.ts
@@ -388,7 +388,6 @@ export class BotPayments {
   public async pay (ctx: OnMessageContext, amountUSD: number): Promise<boolean> {
     // eslint-disable-next-line @typescript-eslint/naming-convention
     const { from, message_id, chat } = ctx.update.message
-
     const accountId = this.getAccountId(ctx)
     const userAccount = this.getUserAccount(accountId)
     if (!userAccount) {
diff --git a/src/modules/types.ts b/src/modules/types.ts
index 928c12b8..c50cc608 100644
--- a/src/modules/types.ts
+++ b/src/modules/types.ts
@@ -36,13 +36,20 @@ export interface ChatCompletion {
 }
 export interface ChatPayload {
   conversation: ChatConversation[]
+  prompt?: string
   model: string
   ctx: OnMessageContext | OnCallBackQueryData
 }
+
+export interface VisionContent {
+  type: string
+  text?: string
+  image_url?: { url: string }
+}
 export interface ChatConversation {
   role?: string
   author?: string
-  content: string
+  content: string | VisionContent[]
   model?: string
 }
 
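
Note: the @@ -639 hunk above swaps the old streamChatVisionCompletion call for the shared chatCompletion helper, sending one user message whose content mixes a text part with an image_url part, which is the shape the OpenAI vision API expects. The sketch below is a minimal standalone illustration of that request against the openai v4 SDK, not code from this commit: the describeImage helper, the public image URL, and the client setup are assumptions for the example, while the bot itself passes a resolved Telegram file link as the URL and casts its messages as any.

    // Editor's sketch (illustrative, not part of the patch): the text-plus-image
    // message shape the patched onInquiryImage sends to the Chat Completions API.
    import OpenAI from 'openai'
    import { type ChatCompletionMessageParam } from 'openai/resources/chat/completions'

    const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

    async function describeImage (imageUrl: string, prompt: string): Promise<string | null> {
      // One user message whose content is an array of parts: a text part for
      // the caption and an image_url part for the photo.
      const messages: ChatCompletionMessageParam[] = [{
        role: 'user',
        content: [
          { type: 'text', text: prompt },
          { type: 'image_url', image_url: { url: imageUrl } }
        ]
      }]
      const response = await openai.chat.completions.create({
        model: 'gpt-4-vision-preview',
        // the vision preview model answers in very few tokens unless max_tokens is raised
        max_tokens: 300,
        messages
      })
      return response.choices[0].message.content
    }

Routing vision requests through chatCompletion is also what motivates the smaller hunks: because the vision path passes an empty conversation array, the new optional prompt field on ChatPayload is what lets getPromptPrice bill for the caption text.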