delete instead of archive, batch OpenAI requests (#129)
* delete instead of archive

* vectorscale optional

* break up large openai requests

* Delete core/src/transformers/debug-test.py
ChuckHend authored Jul 16, 2024
1 parent 1e194c7 commit 1e47832
Showing 5 changed files with 50 additions and 21 deletions.
58 changes: 41 additions & 17 deletions core/src/transformers/http_handler.rs
@@ -1,8 +1,7 @@
use anyhow::Result;

use crate::transformers::types::{
EmbeddingPayload, EmbeddingRequest, EmbeddingResponse, Inputs, PairedEmbeddings,
};
use anyhow::Result;
pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
resp: reqwest::Response,
method: &'static str,
@@ -26,22 +25,47 @@ pub async fn openai_embedding_request(
timeout: i32,
) -> Result<Vec<Vec<f64>>> {
let client = reqwest::Client::new();
let mut req = client
.post(request.url)
.timeout(std::time::Duration::from_secs(timeout as u64))
.json::<EmbeddingPayload>(&request.payload)
.header("Content-Type", "application/json");
if let Some(key) = request.api_key {
req = req.header("Authorization", format!("Bearer {}", key));

// openai request size limit is 2048 inputs
let number_inputs = request.payload.input.len();
let todo_requests: Vec<EmbeddingPayload> = if number_inputs > 2048 {
split_vector(request.payload.input, 2048)
.iter()
.map(|chunk| EmbeddingPayload {
input: chunk.clone(),
model: request.payload.model.clone(),
})
.collect()
} else {
vec![request.payload]
};

let mut all_embeddings: Vec<Vec<f64>> = Vec::with_capacity(number_inputs);

for request_payload in todo_requests.iter() {
let mut req = client
.post(&request.url)
.timeout(std::time::Duration::from_secs(timeout as u64))
.json::<EmbeddingPayload>(request_payload)
.header("Content-Type", "application/json");
if let Some(key) = request.api_key.as_ref() {
req = req.header("Authorization", format!("Bearer {}", key));
}
let resp = req.send().await?;
let embedding_resp: EmbeddingResponse =
handle_response::<EmbeddingResponse>(resp, "embeddings").await?;
let embeddings: Vec<Vec<f64>> = embedding_resp
.data
.iter()
.map(|d| d.embedding.clone())
.collect();
all_embeddings.extend(embeddings);
}
let resp = req.send().await?;
let embedding_resp = handle_response::<EmbeddingResponse>(resp, "embeddings").await?;
let embeddings = embedding_resp
.data
.iter()
.map(|d| d.embedding.clone())
.collect();
Ok(embeddings)
Ok(all_embeddings)
}

fn split_vector(vec: Vec<String>, chunk_size: usize) -> Vec<Vec<String>> {
vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect()
}

// merges the vec of inputs with the embedding responses
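The batching logic above can be read in isolation. The following is a minimal, self-contained sketch of how the new code splits an oversized input list into chunks of at most 2048 and concatenates the per-chunk embeddings; split_vector is taken from the diff, while embed_batch is a hypothetical stand-in for the real per-chunk OpenAI HTTP call.

// Stand-alone sketch of the request batching added above. split_vector is the
// helper from the diff; embed_batch replaces the real HTTP request for illustration.
fn split_vector(vec: Vec<String>, chunk_size: usize) -> Vec<Vec<String>> {
    vec.chunks(chunk_size).map(|chunk| chunk.to_vec()).collect()
}

fn embed_batch(inputs: &[String]) -> Vec<Vec<f64>> {
    // placeholder: one dummy 3-dimensional embedding per input
    inputs.iter().map(|_| vec![0.0_f64; 3]).collect()
}

fn embed_all(inputs: Vec<String>, max_batch: usize) -> Vec<Vec<f64>> {
    let total = inputs.len();
    // Per the comment in the diff, the OpenAI embeddings endpoint caps a request
    // at 2048 inputs, so larger input lists are sent as several smaller requests.
    let batches = if total > max_batch {
        split_vector(inputs, max_batch)
    } else {
        vec![inputs]
    };
    let mut all_embeddings = Vec::with_capacity(total);
    for batch in &batches {
        all_embeddings.extend(embed_batch(batch));
    }
    all_embeddings
}

fn main() {
    let inputs: Vec<String> = (0..5000).map(|i| format!("doc {i}")).collect();
    let embeddings = embed_all(inputs, 2048);
    assert_eq!(embeddings.len(), 5000); // served by three requests: 2048 + 2048 + 904
    println!("embedded {} inputs", embeddings.len());
}

Keeping the single-request fast path for inputs at or under the limit, as the committed code does, avoids re-chunking in the common case.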
2 changes: 1 addition & 1 deletion extension/Trunk.toml
@@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
homepage = "https://github.com/tembo-io/pg_vectorize"
documentation = "https://github.com/tembo-io/pg_vectorize"
categories = ["orchestration", "machine_learning"]
version = "0.16.0"
version = "0.17.0"

[build]
postgres_version = "15"
7 changes: 6 additions & 1 deletion extension/src/transformers/mod.rs
@@ -43,8 +43,13 @@ pub fn transform(input: &str, transformer: &Model, api_key: Option<String>) -> V
input: vec![input.to_string()],
model: transformer.name.to_string(),
};

let url = match guc::get_guc(guc::VectorizeGuc::OpenAIServiceUrl) {
Some(k) => k,
None => OPENAI_BASE_URL.to_string(),
};
EmbeddingRequest {
url: format!("{OPENAI_BASE_URL}/embeddings"),
url: format!("{url}/embeddings"),
payload: embedding_request,
api_key: Some(api_key.to_string()),
}
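The transformer change reads the new OpenAIServiceUrl GUC and falls back to the hard-coded OPENAI_BASE_URL when it is unset, so an alternative OpenAI-compatible endpoint can be configured without recompiling. A minimal sketch of that fallback pattern follows; guc_openai_service_url() is a hypothetical stand-in for guc::get_guc, and the constant's value is an assumption rather than something shown in this diff.

// Sketch only: the extension reads a Postgres GUC, not an environment variable,
// and the default URL below is assumed.
const OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; // assumed default

fn guc_openai_service_url() -> Option<String> {
    // hypothetical stand-in for guc::get_guc(guc::VectorizeGuc::OpenAIServiceUrl)
    std::env::var("OPENAI_SERVICE_URL").ok()
}

fn embeddings_endpoint() -> String {
    // Use the configured service URL if present, otherwise the default base URL,
    // mirroring the match/None fallback added in the diff.
    let base = guc_openai_service_url().unwrap_or_else(|| OPENAI_BASE_URL.to_string());
    format!("{base}/embeddings")
}

fn main() {
    println!("POSTing embeddings to {}", embeddings_endpoint());
}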
2 changes: 1 addition & 1 deletion extension/src/workers/mod.rs
@@ -51,7 +51,7 @@ pub async fn run_worker(

// delete message from queue
if delete_it {
match queue.archive(queue_name, msg_id).await {
match queue.delete(queue_name, msg_id).await {
Ok(_) => {
info!("pg-vectorize: deleted message: {}", msg_id);
}
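The worker now deletes a message outright once it has been handled, rather than archiving it, so completed jobs are removed instead of being copied into pgmq's archive table. A minimal sketch of the intent, using a hypothetical JobQueue trait in place of the concrete pgmq client from the worker, is:

use anyhow::Result;

// Hypothetical minimal interface standing in for the pgmq client used by the
// worker; only the two calls relevant to this commit are modeled.
#[allow(dead_code)] // archive is kept to show the pre-commit call
trait JobQueue {
    fn archive(&self, queue_name: &str, msg_id: i64) -> Result<()>;
    fn delete(&self, queue_name: &str, msg_id: i64) -> Result<()>;
}

// Acknowledge a successfully processed job. Before this commit the worker
// archived the message (retaining a copy in pgmq's archive table); it now
// deletes it so completed messages do not accumulate.
fn ack_job<Q: JobQueue>(queue: &Q, queue_name: &str, msg_id: i64) -> Result<()> {
    queue.delete(queue_name, msg_id)
}

struct NoopQueue; // trivial implementation so the sketch runs end to end
impl JobQueue for NoopQueue {
    fn archive(&self, _queue_name: &str, _msg_id: i64) -> Result<()> {
        Ok(())
    }
    fn delete(&self, _queue_name: &str, _msg_id: i64) -> Result<()> {
        Ok(())
    }
}

fn main() -> Result<()> {
    ack_job(&NoopQueue, "jobs", 42) // queue name and message id are illustrative
}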
2 changes: 1 addition & 1 deletion extension/vectorize.control
@@ -4,4 +4,4 @@ module_pathname = '$libdir/vectorize'
relocatable = false
superuser = true
schema = 'vectorize'
requires = 'pg_cron,pgmq,vector,vectorscale'
requires = 'pg_cron,pgmq,vector'
