diff --git a/extension/src/api.rs b/extension/src/api.rs index d2efa9c..a37db90 100644 --- a/extension/src/api.rs +++ b/extension/src/api.rs @@ -167,3 +167,40 @@ fn env_interpolate_guc(guc_name: &str) -> Result { .unwrap_or_else(|| panic!("no value set for guc: {guc_name}")); env_interpolate_string(&g) } + +/// Recursive split text based on separators +#[pg_extern] +fn chunk_text(text: &str, chunk_size: i32, chunk_overlap: i32) -> Vec { + let separators = vec!["\n\n", "\n", " ", ""]; + let chunk_size = chunk_size as usize; + let chunk_overlap = chunk_overlap as usize; + + let mut chunks = Vec::new(); + let mut start = 0; + + while start < text.len() { + let mut end = std::cmp::min(start + chunk_size, text.len()); + let mut found_separator = false; + + // Try to split the text based on the separators + for sep in separators { + if let Some(pos) = text[start..end].rfind(sep) { + if pos > 0 { + end = start + pos; + chunks.push(chunk.clone()); + found_separator = true; + break; + } + } + } + + // Fallback if no suitable separator is found, chunk by size + if !chunk_found { + end = std::cmp::min(start + chunk_size, text.len()); + } + chunks.push(text[start..end].to_string()); + start = end.saturating_sub(chunk_overlap); + } + + chunks +} diff --git a/extension/src/util.rs b/extension/src/util.rs index fc2e6e9..4d7aa9f 100644 --- a/extension/src/util.rs +++ b/extension/src/util.rs @@ -2,7 +2,8 @@ use anyhow::Result; use pgrx::spi::SpiTupleTable; use pgrx::*; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use sqlx::{Pool, Postgres}; +use sqlx::{Pool, Postgres, Row}; +use sqlx::{Pool, Postgres, Row}; use std::env; use url::{ParseError, Url}; @@ -212,3 +213,17 @@ pub async fn ready(conn: &Pool) -> bool { .await .expect("failed") } + +/// Fetch rows from a given table and schema +pub async fn fetch_table_rows( + conn: &Pool, + table: &str, + columns: Vec, + schema: &str, +) -> Result> { + let query = format!("SELECT {} FROM {}.{}", columns.join(", "), schema, table); + + // Execute the query using sqlx and fetch the rows + let rows = sqlx::query(&query).fetch_all(conn).await?; + Ok(rows) +} diff --git a/extension/tests/integration_tests.rs b/extension/tests/integration_tests.rs index 8828b59..de9f7bd 100644 --- a/extension/tests/integration_tests.rs +++ b/extension/tests/integration_tests.rs @@ -52,6 +52,93 @@ async fn test_scheduled_job() { assert_eq!(result.rows_affected(), 3); } +#[ignore] +#[tokio::test] +async fn test_chunk_text() { + let conn = common::init_database().await; + + // Test case 1: Simple text chunking + let query = r#" + SELECT chunk_text('This is a test for chunking.', 10, 5); + "#; + let result: Vec = sqlx::query_scalar(query) + .fetch_all(&conn) + .await + .expect("failed to execute query"); + assert_eq!( + result, + vec![ + "This is a ".to_string(), + "is a test ".to_string(), + "test for c".to_string(), + "for chunki".to_string(), + "chunking.".to_string(), + ] + ); + + // Test case 2: Empty text + let query = r#" + SELECT chunk_text('', 10, 5); + "#; + let result: Vec = sqlx::query_scalar(query) + .fetch_all(&conn) + .await + .expect("failed to execute query"); + assert_eq!(result, vec![]); + + // Test case 3: Single short input + let query = r#" + SELECT chunk_text('Short', 10, 5); + "#; + let result: Vec = sqlx::query_scalar(query) + .fetch_all(&conn) + .await + .expect("failed to execute query"); + assert_eq!(result, vec!["Short".to_string()]); + + // Test case 4: Text with separators + let query = r#" + SELECT chunk_text('This\nis\na\ntest.', 5, 2); + "#; + let result: Vec = sqlx::query_scalar(query) + .fetch_all(&conn) + .await + .expect("failed to execute query"); + assert_eq!( + result, + vec![ + "This".to_string(), + "is".to_string(), + "a".to_string(), + "test.".to_string(), + ] + ); + + // Test case 5: Large input with overlap + let query = r#" + SELECT chunk_text('Lorem ipsum dolor sit amet, consectetur adipiscing elit.', 15, 5); + "#; + let result: Vec = sqlx::query_scalar(query) + .fetch_all(&conn) + .await + .expect("failed to execute query"); + assert_eq!( + result, + vec![ + "Lorem ipsum do".to_string(), + "ipsum dolor si".to_string(), + "dolor sit amet".to_string(), + "sit amet, cons".to_string(), + "amet, consecte".to_string(), + "consectetur ad".to_string(), + "adipiscing eli".to_string(), + "elit.".to_string(), + ] + ); + + println!("All chunk_text test cases passed!"); +} + #[ignore] #[tokio::test] async fn test_scheduled_single_table() {