diff --git a/core/src/transformers/http_handler.rs b/core/src/transformers/http_handler.rs
index 48a8d00..b2f2bbb 100644
--- a/core/src/transformers/http_handler.rs
+++ b/core/src/transformers/http_handler.rs
@@ -1,8 +1,7 @@
-use anyhow::Result;
-
 use crate::transformers::types::{
     EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
 };
+use anyhow::Result;
 pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
     resp: reqwest::Response,
     method: &'static str,
@@ -26,22 +25,47 @@ pub async fn openai_embedding_request(
     timeout: i32,
 ) -> Result<Vec<Vec<f64>>> {
     let client = reqwest::Client::new();
-    let mut req = client
-        .post(request.url)
-        .timeout(std::time::Duration::from_secs(timeout as u64))
-        .json::<EmbeddingPayload>(&request.payload)
-        .header("Content-Type", "application/json");
-    if let Some(key) = request.api_key {
-        req = req.header("Authorization", format!("Bearer {}", key));
+
+    // openai request size limit is 2048 inputs
+    let number_inputs = request.payload.input.len();
+    let todo_requests: Vec<EmbeddingPayload> = if number_inputs > 2048 {
+        split_vector(request.payload.input, 2048)
+            .iter()
+            .map(|chunk| EmbeddingPayload {
+                input: chunk.clone(),
+                model: request.payload.model.clone(),
+            })
+            .collect()
+    } else {
+        vec![request.payload]
+    };
+
+    let mut all_embeddings: Vec<Vec<f64>> = Vec::with_capacity(number_inputs);
+
+    for request_payload in todo_requests.iter() {
+        let mut req = client
+            .post(&request.url)
+            .timeout(std::time::Duration::from_secs(timeout as u64))
+            .json::<EmbeddingPayload>(request_payload)
+            .header("Content-Type", "application/json");
+        if let Some(key) = request.api_key.as_ref() {
+            req = req.header("Authorization", format!("Bearer {}", key));
+        }
+        let resp = req.send().await?;
+        let embedding_resp: EmbeddingResponse =
+            handle_response::<EmbeddingResponse>(resp, "embeddings").await?;
+        let embeddings: Vec<Vec<f64>> = embedding_resp
+            .data
+            .iter()
+            .map(|d| d.embedding.clone())
+            .collect();
+        all_embeddings.extend(embeddings);
     }
-    let resp = req.send().await?;
-    let embedding_resp = handle_response::<EmbeddingResponse>(resp, "embeddings").await?;
-    let embeddings = embedding_resp
-        .data
-        .iter()
-        .map(|d| d.embedding.clone())
-        .collect();
-    Ok(embeddings)
+    Ok(all_embeddings)
+}
+
+fn split_vector(vec: Vec<String>, chunk_size: usize) -> Vec<Vec<String>> {
+    vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect()
 }
 
 // merges the vec of inputs with the embedding responses
diff --git a/extension/Trunk.toml b/extension/Trunk.toml
index b946d5e..a473938 100644
--- a/extension/Trunk.toml
+++ b/extension/Trunk.toml
@@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
 homepage = "https://github.com/tembo-io/pg_vectorize"
 documentation = "https://github.com/tembo-io/pg_vectorize"
 categories = ["orchestration", "machine_learning"]
-version = "0.16.0"
+version = "0.17.0"
 
 [build]
 postgres_version = "15"
diff --git a/extension/src/transformers/mod.rs b/extension/src/transformers/mod.rs
index 76b14f7..c51a8a5 100644
--- a/extension/src/transformers/mod.rs
+++ b/extension/src/transformers/mod.rs
@@ -43,8 +43,13 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> V
                 input: vec![input.to_string()],
                 model: transformer.name.to_string(),
             };
+
+            let url = match guc::get_guc(guc::VectorizeGuc::OpenAIServiceUrl) {
+                Some(k) => k,
+                None => OPENAI_BASE_URL.to_string(),
+            };
             EmbeddingRequest {
-                url: format!("{OPENAI_BASE_URL}/embeddings"),
+                url: format!("{url}/embeddings"),
                 payload: embedding_request,
                 api_key: Some(api_key.to_string()),
             }
diff --git a/extension/src/workers/mod.rs b/extension/src/workers/mod.rs
index 96fa00b..f3dac8a 100644
--- a/extension/src/workers/mod.rs
+++ b/extension/src/workers/mod.rs
@@ -51,7 +51,7 @@ pub async fn run_worker(
 
     // delete message from queue
     if delete_it {
-        match queue.archive(queue_name, msg_id).await {
+        match queue.delete(queue_name, msg_id).await {
             Ok(_) => {
                 info!("pg-vectorize: deleted message: {}", msg_id);
             }
diff --git a/extension/vectorize.control b/extension/vectorize.control
index 41db2cf..6ff76dd 100644
--- a/extension/vectorize.control
+++ b/extension/vectorize.control
@@ -4,4 +4,4 @@ module_pathname = '$libdir/vectorize'
 relocatable = false
 superuser = true
 schema = 'vectorize'
-requires = 'pg_cron,pgmq,vector,vectorscale'
\ No newline at end of file
+requires = 'pg_cron,pgmq,vector'
\ No newline at end of file