Batch requests #15

Merged
merged 3 commits on Oct 23, 2023
76 changes: 38 additions & 38 deletions .github/workflows/extension_ci.yml
@@ -76,44 +76,44 @@ jobs:
- name: Clippy
run: cargo clippy

# test:
# name: Run tests
# needs: dependencies
# runs-on: ubuntu-22.04
# steps:
# - uses: actions/checkout@v2
# - name: Install Rust stable toolchain
# uses: actions-rs/toolchain@v1
# with:
# toolchain: stable
# - uses: Swatinem/rust-cache@v2
# with:
# prefix-key: "pg-vectorize-extension-test"
# workspaces: pg-vectorize
# # Additional directories to cache
# cache-directories: /home/runner/.pgrx
# - uses: ./.github/actions/pgx-init
# with:
# working-directory: ./
# - name: Restore cached binaries
# uses: actions/cache@v2
# with:
# path: |
# /usr/local/bin/stoml
# ~/.cargo/bin/trunk
# key: ${{ runner.os }}-bins-${{ github.sha }}
# restore-keys: |
# ${{ runner.os }}-bins-
# - name: test
# run: |
# pgrx15_config=$(/usr/local/bin/stoml ~/.pgrx/config.toml configs.pg15)
# ~/.cargo/bin/trunk install pgvector --pg-config ${pgrx15_config}
# ~/.cargo/bin/trunk install pgmq --pg-config ${pgrx15_config}
# ~/.cargo/bin/trunk install pg_cron --pg-config ${pgrx15_config}
# rm -rf ./target/pgrx-test-data-* || true
# pg_version=$(/usr/local/bin/stoml Cargo.toml features.default)
# cargo pgrx run ${pg_version} --pgcli || true
# cargo pgrx test ${pg_version}
test:
name: Run tests
needs: dependencies
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v2
- name: Install Rust stable toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
with:
prefix-key: "pg-vectorize-extension-test"
workspaces: pg-vectorize
# Additional directories to cache
cache-directories: /home/runner/.pgrx
- uses: ./.github/actions/pgx-init
with:
working-directory: ./
- name: Restore cached binaries
uses: actions/cache@v2
with:
path: |
/usr/local/bin/stoml
~/.cargo/bin/trunk
key: ${{ runner.os }}-bins-${{ github.sha }}
restore-keys: |
${{ runner.os }}-bins-
- name: test
run: |
pgrx15_config=$(/usr/local/bin/stoml ~/.pgrx/config.toml configs.pg15)
~/.cargo/bin/trunk install pgvector --pg-config ${pgrx15_config}
~/.cargo/bin/trunk install pgmq --pg-config ${pgrx15_config}
~/.cargo/bin/trunk install pg_cron --pg-config ${pgrx15_config}
rm -rf ./target/pgrx-test-data-* || true
pg_version=$(/usr/local/bin/stoml Cargo.toml features.default)
cargo pgrx run ${pg_version} --pgcli || true
cargo pgrx test ${pg_version}

publish:
if: github.event_name == 'release'
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "vectorize"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
publish = false

2 changes: 1 addition & 1 deletion Trunk.toml
@@ -6,7 +6,7 @@ description = "The simplest implementation of LLM-backed vector search on Postgr
homepage = "https://github.com/tembo-io/pg_vectorize"
documentation = "https://github.com/tembo-io/pg_vectorize"
categories = ["orchestration", "machine_learning"]
version = "0.1.0"
version = "0.1.1"

[build]
postgres_version = "15"
138 changes: 123 additions & 15 deletions src/executor.rs
@@ -67,6 +67,31 @@ pub struct ColumnJobParams {
pub table_method: TableMethod,
}

// creates batches based on total token count
// batch_size is the max token count per batch
fn create_batches(data: Vec<Inputs>, batch_size: i32) -> Vec<Vec<Inputs>> {
let mut groups: Vec<Vec<Inputs>> = Vec::new();
let mut current_group: Vec<Inputs> = Vec::new();
let mut current_token_count = 0;

for input in data {
if current_token_count + input.token_estimate > batch_size {
// Create a new group
groups.push(current_group);
current_group = Vec::new();
current_token_count = 0;
}
current_token_count += input.token_estimate;
current_group.push(input);
}

// Add any remaining inputs to the groups
if !current_group.is_empty() {
groups.push(current_group);
}
groups
}

// schema for all messages that hit pgmq
#[derive(Clone, Deserialize, Debug, Serialize)]
pub struct JobMessage {
@@ -87,10 +112,14 @@ fn job_execute(job_name: String) {
.build()
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

// TODO: move into a config
// 100k tokens per batch
let max_batch_size = 100000;

runtime.block_on(async {
let conn = get_pg_conn()
.await
.unwrap_or_else(|e| error!("pg-vectorize: failed to establsh db connection: {}", e));
.unwrap_or_else(|e| error!("pg-vectorize: failed to establish db connection: {}", e));
let queue = pgmq::PGMQueueExt::new_with_pool(conn.clone())
.await
.unwrap_or_else(|e| error!("failed to init db connection: {}", e));
@@ -106,19 +135,28 @@ fn job_execute(job_name: String) {
let new_or_updated_rows = get_new_updates_append(&conn, &job_name, job_params)
.await
.unwrap_or_else(|e| error!("failed to get new updates: {}", e));

match new_or_updated_rows {
Some(rows) => {
log!("num new records: {}", rows.len());
let msg = JobMessage {
job_name: job_name.clone(),
job_meta: meta.clone(),
inputs: rows,
};
let msg_id = queue
.send(PGMQ_QUEUE_NAME, &msg)
.await
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
log!("message sent: {}", msg_id);
let batches = create_batches(rows, max_batch_size);
log!(
"total batches: {}, max_batch_size: {}",
batches.len(),
max_batch_size
);
for b in batches {
let msg = JobMessage {
job_name: job_name.clone(),
job_meta: meta.clone(),
inputs: b,
};
let msg_id = queue
.send(PGMQ_QUEUE_NAME, &msg)
.await
.unwrap_or_else(|e| error!("failed to send message updates: {}", e));
log!("message sent: {}", msg_id);
}
}
None => {
log!("pg-vectorize: job: {}, no new records", job_name);
@@ -149,8 +187,9 @@ pub async fn get_vectorize_meta(

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Inputs {
pub record_id: String, // the value to join the record
pub inputs: String, // concatenation of input columns
pub record_id: String, // the value to join the record
pub inputs: String, // concatenation of input columns
pub token_estimate: i32, // estimated token count
}

// queries a table and returns rows that need new embeddings
@@ -188,9 +227,12 @@ pub async fn get_new_updates_append(
if !rows.is_empty() {
let mut new_inputs: Vec<Inputs> = Vec::new();
for r in rows {
let ipt: String = r.get("input_text");
let token_estimate = ipt.split_whitespace().count() as i32;
new_inputs.push(Inputs {
record_id: r.get("record_id"),
inputs: r.get("input_text"),
inputs: ipt,
token_estimate,
})
}
log!("pg-vectorize: num new inputs: {}", new_inputs.len());
@@ -239,9 +281,12 @@ pub async fn get_new_updates_shared(
match rows {
Ok(rows) => {
for r in rows {
let ipt: String = r.get("input_text");
let token_estimate = ipt.split_whitespace().count() as i32;
new_inputs.push(Inputs {
record_id: r.get("record_id"),
inputs: r.get("input_text"),
inputs: ipt,
token_estimate,
})
}
Ok(Some(new_inputs))
Expand All @@ -261,3 +306,66 @@ fn collapse_to_csv(strings: &[String]) -> String {
.collect::<Vec<_>>()
.join("|| ', ' ||")
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_create_batches_normal() {
let data = vec![
Inputs {
record_id: "1".to_string(),
inputs: "Test 1.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "2".to_string(),
inputs: "Test 2.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "3".to_string(),
inputs: "Test 3.".to_string(),
token_estimate: 3,
},
];

let batches = create_batches(data, 4);
assert_eq!(batches.len(), 2);
assert_eq!(batches[0].len(), 2);
assert_eq!(batches[1].len(), 1);
}

#[test]
fn test_create_batches_empty() {
let data: Vec<Inputs> = Vec::new();
let batches = create_batches(data, 4);
assert_eq!(batches.len(), 0);
}

#[test]
fn test_create_batches_large() {
let data = vec![
Inputs {
record_id: "1".to_string(),
inputs: "Test 1.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "2".to_string(),
inputs: "Test 2.".to_string(),
token_estimate: 2,
},
Inputs {
record_id: "3".to_string(),
inputs: "Test 3.".to_string(),
token_estimate: 100,
},
];
let batches = create_batches(data, 5);
assert_eq!(batches.len(), 2);
assert_eq!(batches[1].len(), 1);
assert_eq!(batches[1][0].token_estimate, 100);
}
}
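
For reference, a minimal sketch of how the new create_batches helper groups rows before they are enqueued (the Inputs values and the batching_sketch function below are illustrative only, not part of this diff):

// Sketch only: Inputs and create_batches are the items added in src/executor.rs above.
fn batching_sketch() {
    let rows = vec![
        Inputs {
            record_id: "1".to_string(),
            inputs: "hello world".to_string(),
            token_estimate: 2,
        },
        Inputs {
            record_id: "2".to_string(),
            inputs: "another row of text".to_string(),
            token_estimate: 4,
        },
    ];
    // With a 3-token budget, the second row overflows the first batch,
    // so two batches are produced and each becomes its own JobMessage.
    let batches = create_batches(rows, 3);
    assert_eq!(batches.len(), 2);
    assert_eq!(batches[0].len(), 1);
    assert_eq!(batches[1].len(), 1);
}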
1 change: 0 additions & 1 deletion src/openai.rs
@@ -42,7 +42,6 @@ pub async fn get_embeddings(inputs: &Vec<String>, key: &str) -> Result<Vec<Vec<f
Ok(embeddings)
}

// thanks Evan :D
pub async fn handle_response<T: for<'de> serde::Deserialize<'de>>(
resp: reqwest::Response,
method: &'static str,
Expand Down
23 changes: 11 additions & 12 deletions src/worker.rs
@@ -44,18 +44,17 @@ pub extern "C" fn background_worker_main(_arg: pg_sys::Datum) {
// on SIGHUP, you might want to reload configurations and env vars
}
let _: Result<()> = runtime.block_on(async {
let msg: Message<JobMessage> =
match queue.read::<JobMessage>(PGMQ_QUEUE_NAME, 300).await {
Ok(Some(msg)) => msg,
Ok(None) => {
log!("pg-vectorize: No messages in queue");
return Ok(());
}
Err(e) => {
warning!("pg-vectorize: Error reading message: {e}");
return Ok(());
}
};
let msg: Message<JobMessage> = match queue.pop::<JobMessage>(PGMQ_QUEUE_NAME).await {
Ok(Some(msg)) => msg,
Ok(None) => {
log!("pg-vectorize: No messages in queue");
return Ok(());
}
Err(e) => {
warning!("pg-vectorize: Error reading message: {e}");
return Ok(());
}
};

let msg_id = msg.msg_id;
let read_ct = msg.read_ct;