Skip to content

Commit

Permalink
Merge pull request #123 from dongri/refactoring-new-client
Browse files Browse the repository at this point in the history
Refactoring new client
  • Loading branch information
dongri authored Nov 10, 2024
2 parents b02db53 + b48f420 commit 9568281
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 57 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ $ export OPENAI_API_KEY=sk-xxxxxxx

### Create client
```rust
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;
```

### Create request
Expand Down Expand Up @@ -57,7 +58,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = ChatCompletionRequest::new(
GPT4_O.to_string(),
Expand Down
3 changes: 2 additions & 1 deletion examples/assistant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let mut tools = HashMap::new();
tools.insert("type".to_string(), "code_interpreter".to_string());
Expand Down
3 changes: 2 additions & 1 deletion examples/audio_speech.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = AudioSpeechRequest::new(
TTS_1.to_string(),
Expand Down
3 changes: 2 additions & 1 deletion examples/audio_transcriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = AudioTranscriptionRequest::new(
"examples/data/problem.mp3".to_string(),
Expand Down
3 changes: 2 additions & 1 deletion examples/audio_translations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = AudioTranslationRequest::new(
"examples/data/problem_cn.mp3".to_string(),
Expand Down
3 changes: 2 additions & 1 deletion examples/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::str;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = FileUploadRequest::new(
"examples/data/batch_request.json".to_string(),
Expand Down
3 changes: 2 additions & 1 deletion examples/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = ChatCompletionRequest::new(
GPT4_O_MINI.to_string(),
Expand Down
3 changes: 2 additions & 1 deletion examples/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = CompletionRequest::new(
completion::GPT3_TEXT_DAVINCI_003.to_string(),
Expand Down
Binary file modified examples/data/problem.mp3
Binary file not shown.
3 changes: 2 additions & 1 deletion examples/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let mut req = EmbeddingRequest::new(
TEXT_EMBEDDING_3_SMALL.to_string(),
Expand Down
3 changes: 2 additions & 1 deletion examples/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ fn get_coin_price(coin: &str) -> f64 {

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let mut properties = HashMap::new();
properties.insert(
Expand Down
3 changes: 2 additions & 1 deletion examples/function_call_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ fn get_coin_price(coin: &str) -> f64 {

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let mut properties = HashMap::new();
properties.insert(
Expand Down
3 changes: 2 additions & 1 deletion examples/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;

let req = ChatCompletionRequest::new(
GPT4_O.to_string(),
Expand Down
118 changes: 74 additions & 44 deletions src/v1/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,73 +38,98 @@ use crate::v1::run::{
use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};

use bytes::Bytes;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::multipart::{Form, Part};
use reqwest::{Client, Method, Response};
use serde::Serialize;
use serde_json::Value;

use std::error::Error;
use std::fs::{create_dir_all, File};
use std::io::Read;
use std::io::Write;
use std::path::Path;

const API_URL_V1: &str = "https://api.openai.com/v1";

#[derive(Default)]
pub struct OpenAIClientBuilder {
api_endpoint: Option<String>,
api_key: Option<String>,
organization: Option<String>,
proxy: Option<String>,
timeout: Option<u64>,
headers: Option<HeaderMap>,
}

pub struct OpenAIClient {
pub api_endpoint: String,
pub api_key: String,
pub organization: Option<String>,
pub proxy: Option<String>,
pub timeout: Option<u64>,
api_endpoint: String,
api_key: String,
organization: Option<String>,
proxy: Option<String>,
timeout: Option<u64>,
headers: Option<HeaderMap>,
}

impl OpenAIClient {
pub fn new(api_key: String) -> Self {
let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned());
Self::new_with_endpoint(endpoint, api_key)
impl OpenAIClientBuilder {
pub fn new() -> Self {
Self::default()
}

pub fn new_with_endpoint(api_endpoint: String, api_key: String) -> Self {
Self {
api_endpoint,
api_key,
organization: None,
proxy: None,
timeout: None,
}
pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}

pub fn new_with_organization(api_key: String, organization: String) -> Self {
let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned());
Self {
api_endpoint: endpoint,
api_key,
organization: Some(organization),
proxy: None,
timeout: None,
}
pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
self.api_endpoint = Some(endpoint.into());
self
}

pub fn new_with_proxy(api_key: String, proxy: String) -> Self {
let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned());
Self {
api_endpoint: endpoint,
api_key,
organization: None,
proxy: Some(proxy),
timeout: None,
}
pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
self.organization = Some(organization.into());
self
}

pub fn with_proxy(mut self, proxy: impl Into<String>) -> Self {
self.proxy = Some(proxy.into());
self
}

pub fn with_timeout(mut self, timeout: u64) -> Self {
self.timeout = Some(timeout);
self
}

pub fn new_with_timeout(api_key: String, timeout: u64) -> Self {
let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned());
Self {
api_endpoint: endpoint,
pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
let headers = self.headers.get_or_insert_with(HeaderMap::new);
headers.insert(
HeaderName::from_bytes(key.into().as_bytes()).expect("Invalid header name"),
HeaderValue::from_str(&value.into()).expect("Invalid header value"),
);
self
}

pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
let api_key = self.api_key.ok_or("API key is required")?;
let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
});

Ok(OpenAIClient {
api_endpoint,
api_key,
organization: None,
proxy: None,
timeout: Some(timeout),
}
organization: self.organization,
proxy: self.proxy,
timeout: self.timeout,
headers: self.headers,
})
}
}

impl OpenAIClient {
pub fn builder() -> OpenAIClientBuilder {
OpenAIClientBuilder::new()
}

async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
Expand All @@ -127,13 +152,18 @@ impl OpenAIClient {

let mut request = client
.request(method, url)
// .header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key));

if let Some(organization) = &self.organization {
request = request.header("openai-organization", organization);
}

if let Some(headers) = &self.headers {
for (key, value) in headers {
request = request.header(key, value);
}
}

if Self::is_beta(path) {
request = request.header("OpenAI-Beta", "assistants=v2");
}
Expand Down

0 comments on commit 9568281

Please sign in to comment.