Skip to content

Commit

Permalink
Merge pull request #64 from dongri/add-vision
Browse files Browse the repository at this point in the history
Add vision
  • Loading branch information
Dongri Jin authored Jan 7, 2024
2 parents c8d376e + f99645c commit c7fbe2b
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
GPT4.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is Bitcoin?"),
content: chat_completion::Content::Text(String::from("What is bitcoin?")),
name: None,
}],
);
Expand Down
2 changes: 1 addition & 1 deletion examples/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
GPT3_5_TURBO_0613.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")),
name: None,
}],
)
Expand Down
10 changes: 6 additions & 4 deletions examples/function_call_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
GPT3_5_TURBO_0613.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")),
name: None,
}],
)
Expand Down Expand Up @@ -83,15 +83,17 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
vec![
chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
content: chat_completion::Content::Text(String::from(
"What is the price of Ethereum?",
)),
name: None,
},
chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::function,
content: {
content: chat_completion::Content::Text({
let price = get_coin_price(&coin);
format!("{{\"price\": {}}}", price)
},
}),
name: Some(String::from("get_coin_price")),
},
],
Expand Down
39 changes: 39 additions & 0 deletions examples/vision.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use openai_api_rs::v1::api::Client;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
use openai_api_rs::v1::common::GPT4_VISION_PREVIEW;
use std::env;

fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());

let req = ChatCompletionRequest::new(
GPT4_VISION_PREVIEW.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: chat_completion::Content::ImageUrl(vec![
chat_completion::ImageUrl {
r#type: chat_completion::ContentType::text,
text: Some(String::from("What’s in this image?")),
image_url: None,
},
chat_completion::ImageUrl {
r#type: chat_completion::ContentType::image_url,
text: None,
image_url: Some(chat_completion::ImageUrlType {
url: String::from(
"https://upload.wikimedia.org/wikipedia/commons/5/50/Bitcoin.png",
),
}),
},
]),
name: None,
}],
);

let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content);

Ok(())
}

// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example chat_completion
43 changes: 42 additions & 1 deletion src/v1/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,51 @@ pub enum MessageRole {
function,
}

#[derive(Debug, Deserialize, Clone)]
pub enum Content {
Text(String),
ImageUrl(Vec<ImageUrl>),
}

impl serde::Serialize for Content {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match *self {
Content::Text(ref text) => serializer.serialize_str(text),
Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
}
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[allow(non_camel_case_types)]
pub enum ContentType {
text,
image_url,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[allow(non_camel_case_types)]
pub struct ImageUrlType {
pub url: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[allow(non_camel_case_types)]
pub struct ImageUrl {
pub r#type: ContentType,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrlType>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionMessage {
pub role: MessageRole,
pub content: String,
pub content: Content,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
Expand Down

0 comments on commit c7fbe2b

Please sign in to comment.