Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chunking support to vectorize.table() #162

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions extension/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,40 @@ fn env_interpolate_guc(guc_name: &str) -> Result<String> {
.unwrap_or_else(|| panic!("no value set for guc: {guc_name}"));
env_interpolate_string(&g)
}

/// Recursive split text based on separators
#[pg_extern]
fn chunk_text(text: &str, chunk_size: i32, chunk_overlap: i32) -> Vec<String> {
let separators = vec!["\n\n", "\n", " ", ""];
let chunk_size = chunk_size as usize;
let chunk_overlap = chunk_overlap as usize;

let mut chunks = Vec::new();
let mut start = 0;

while start < text.len() {
let mut end = std::cmp::min(start + chunk_size, text.len());
let mut found_separator = false;

// Try to split the text based on the separators
for sep in separators {
if let Some(pos) = text[start..end].rfind(sep) {
if pos > 0 {
end = start + pos;
chunks.push(chunk.clone());
found_separator = true;
break;
}
}
}

// Fallback if no suitable separator is found, chunk by size
if !chunk_found {
end = std::cmp::min(start + chunk_size, text.len());
}
chunks.push(text[start..end].to_string());
start = end.saturating_sub(chunk_overlap);
}

chunks
}
17 changes: 16 additions & 1 deletion extension/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use anyhow::Result;
use pgrx::spi::SpiTupleTable;
use pgrx::*;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use sqlx::{Pool, Postgres};
use sqlx::{Pool, Postgres, Row};
use sqlx::{Pool, Postgres, Row};
use std::env;
use url::{ParseError, Url};

Expand Down Expand Up @@ -212,3 +213,17 @@ pub async fn ready(conn: &Pool<Postgres>) -> bool {
.await
.expect("failed")
}

/// Fetch rows from a given table and schema
pub async fn fetch_table_rows(
conn: &Pool<Postgres>,
table: &str,
columns: Vec<String>,
schema: &str,
) -> Result<Vec<sqlx::postgres::PgRow>> {
let query = format!("SELECT {} FROM {}.{}", columns.join(", "), schema, table);

// Execute the query using sqlx and fetch the rows
let rows = sqlx::query(&query).fetch_all(conn).await?;
Ok(rows)
}
87 changes: 87 additions & 0 deletions extension/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,93 @@ async fn test_scheduled_job() {
assert_eq!(result.rows_affected(), 3);
}

#[ignore]
#[tokio::test]
async fn test_chunk_text() {
let conn = common::init_database().await;

// Test case 1: Simple text chunking
let query = r#"
SELECT chunk_text('This is a test for chunking.', 10, 5);
"#;
let result: Vec<String> = sqlx::query_scalar(query)
.fetch_all(&conn)
.await
.expect("failed to execute query");
assert_eq!(
result,
vec![
"This is a ".to_string(),
"is a test ".to_string(),
"test for c".to_string(),
"for chunki".to_string(),
"chunking.".to_string(),
]
);

// Test case 2: Empty text
let query = r#"
SELECT chunk_text('', 10, 5);
"#;
let result: Vec<String> = sqlx::query_scalar(query)
.fetch_all(&conn)
.await
.expect("failed to execute query");
assert_eq!(result, vec![]);

// Test case 3: Single short input
let query = r#"
SELECT chunk_text('Short', 10, 5);
"#;
let result: Vec<String> = sqlx::query_scalar(query)
.fetch_all(&conn)
.await
.expect("failed to execute query");
assert_eq!(result, vec!["Short".to_string()]);

// Test case 4: Text with separators
let query = r#"
SELECT chunk_text('This\nis\na\ntest.', 5, 2);
"#;
let result: Vec<String> = sqlx::query_scalar(query)
.fetch_all(&conn)
.await
.expect("failed to execute query");
assert_eq!(
result,
vec![
"This".to_string(),
"is".to_string(),
"a".to_string(),
"test.".to_string(),
]
);

// Test case 5: Large input with overlap
let query = r#"
SELECT chunk_text('Lorem ipsum dolor sit amet, consectetur adipiscing elit.', 15, 5);
"#;
let result: Vec<String> = sqlx::query_scalar(query)
.fetch_all(&conn)
.await
.expect("failed to execute query");
assert_eq!(
result,
vec![
"Lorem ipsum do".to_string(),
"ipsum dolor si".to_string(),
"dolor sit amet".to_string(),
"sit amet, cons".to_string(),
"amet, consecte".to_string(),
"consectetur ad".to_string(),
"adipiscing eli".to_string(),
"elit.".to_string(),
]
);

println!("All chunk_text test cases passed!");
}

#[ignore]
#[tokio::test]
async fn test_scheduled_single_table() {
Expand Down