From f99645c37e4049a19ea1563f34d58ccba8dffde3 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Sun, 7 Jan 2024 16:33:58 +0900 Subject: [PATCH] Add vision --- examples/chat_completion.rs | 2 +- examples/function_call.rs | 2 +- examples/function_call_role.rs | 10 ++++---- examples/vision.rs | 39 ++++++++++++++++++++++++++++++ src/v1/chat_completion.rs | 43 +++++++++++++++++++++++++++++++++- 5 files changed, 89 insertions(+), 7 deletions(-) create mode 100644 examples/vision.rs diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index b80a1b9..f5fd4a0 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -10,7 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { GPT4.to_string(), vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: String::from("What is Bitcoin?"), + content: chat_completion::Content::Text(String::from("What is bitcoin?")), name: None, }], ); diff --git a/examples/function_call.rs b/examples/function_call.rs index 12e6a2d..8d2cea3 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -31,7 +31,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { GPT3_5_TURBO_0613.to_string(), vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: String::from("What is the price of Ethereum?"), + content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), name: None, }], ) diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index 9091da2..e8839ed 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -31,7 +31,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { GPT3_5_TURBO_0613.to_string(), vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: String::from("What is the price of Ethereum?"), + content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), name: None, }], ) @@ -83,15 +83,17 @@ fn main() -> 
Result<(), Box<dyn std::error::Error>> { vec![ chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: String::from("What is the price of Ethereum?"), + content: chat_completion::Content::Text(String::from( + "What is the price of Ethereum?", + )), name: None, }, chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::function, - content: { + content: chat_completion::Content::Text({ let price = get_coin_price(&coin); format!("{{\"price\": {}}}", price) - }, + }), name: Some(String::from("get_coin_price")), }, ], diff --git a/examples/vision.rs b/examples/vision.rs new file mode 100644 index 0000000..91e9617 --- /dev/null +++ b/examples/vision.rs @@ -0,0 +1,39 @@ +use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::common::GPT4_VISION_PREVIEW; +use std::env; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + + let req = ChatCompletionRequest::new( + GPT4_VISION_PREVIEW.to_string(), + vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: chat_completion::Content::ImageUrl(vec![ + chat_completion::ImageUrl { + r#type: chat_completion::ContentType::text, + text: Some(String::from("What’s in this image?")), + image_url: None, + }, + chat_completion::ImageUrl { + r#type: chat_completion::ContentType::image_url, + text: None, + image_url: Some(chat_completion::ImageUrlType { + url: String::from( + "https://upload.wikimedia.org/wikipedia/commons/5/50/Bitcoin.png", + ), + }), + }, + ]), + name: None, + }], + ); + + let result = client.chat_completion(req)?; + println!("{:?}", result.choices[0].message.content); + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example vision diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 16715c7..fc4465c 100644 --- a/src/v1/chat_completion.rs +++ 
b/src/v1/chat_completion.rs @@ -98,10 +98,51 @@ pub enum MessageRole { function, } +#[derive(Debug, Deserialize, Clone)] +pub enum Content { + Text(String), + ImageUrl(Vec<ImageUrl>), +} + +impl serde::Serialize for Content { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + match *self { + Content::Text(ref text) => serializer.serialize_str(text), + Content::ImageUrl(ref image_url) => image_url.serialize(serializer), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[allow(non_camel_case_types)] +pub enum ContentType { + text, + image_url, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[allow(non_camel_case_types)] +pub struct ImageUrlType { + pub url: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[allow(non_camel_case_types)] +pub struct ImageUrl { + pub r#type: ContentType, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_url: Option<ImageUrlType>, +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ChatCompletionMessage { pub role: MessageRole, - pub content: String, + pub content: Content, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option<String>, }