From 5e1c31792d482daf3835eb02f14e6e4ccba4e05d Mon Sep 17 00:00:00 2001 From: Bhavya Jain Date: Thu, 17 Oct 2024 18:00:43 +0530 Subject: [PATCH] drop deprecated parameters from vectorize.table() --- extension/src/api.rs | 11 +- extension/src/search.rs | 310 +++++----------------------------------- 2 files changed, 36 insertions(+), 285 deletions(-) diff --git a/extension/src/api.rs b/extension/src/api.rs index 700cc6b..d2efa9c 100644 --- a/extension/src/api.rs +++ b/extension/src/api.rs @@ -21,8 +21,6 @@ fn table( update_col: default!(String, "'last_updated_at'"), index_dist_type: default!(types::IndexDist, "'pgv_hnsw_cosine'"), transformer: default!(&str, "'sentence-transformers/all-MiniLM-L6-v2'"), - // search_alg is now deprecated - search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"), table_method: default!(types::TableMethod, "'join'"), // cron-like for a cron based update model, or 'realtime' for a trigger-based schedule: default!(&str, "'* * * * *'"), @@ -37,8 +35,6 @@ fn table( Some(update_col), index_dist_type.into(), &model, - // search_alg is now deprecated - search_alg.into(), table_method.into(), schedule, ) @@ -96,9 +92,6 @@ fn init_rag( index_dist_type: default!(types::IndexDist, "'pgv_hnsw_cosine'"), // transformer model to use in vector-search transformer: default!(&str, "'sentence-transformers/all-MiniLM-L6-v2'"), - // similarity algorithm to use in vector-search - // search_alg is now deprecated - search_alg: default!(types::SimilarityAlg, "'pgv_cosine_similarity'"), table_method: default!(types::TableMethod, "'join'"), schedule: default!(&str, "'* * * * *'"), ) -> Result { @@ -114,14 +107,12 @@ fn init_rag( None, index_dist_type.into(), &transformer_model, - // search_alg is now deprecated - search_alg.into(), table_method.into(), schedule, ) } -/// creates an table indexed with embeddings for chat completion workloads +/// creates a table indexed with embeddings for chat completion workloads #[pg_extern] fn rag( agent_name: &str, diff --git a/extension/src/search.rs b/extension/src/search.rs index e5bee96..ab444e8 100644 --- a/extension/src/search.rs +++ b/extension/src/search.rs @@ -1,18 +1,3 @@ -use crate::guc; -use crate::guc::get_guc_configs; -use crate::init; -use crate::job::{create_event_trigger, create_trigger_handler, initalize_table_job}; -use crate::transformers::openai; -use crate::transformers::transform; -use crate::util; - -use anyhow::{Context, Result}; -use pgrx::prelude::*; -use vectorize_core::transformers::providers::get_provider; -use vectorize_core::transformers::providers::ollama::check_model_host; -use vectorize_core::types::{self, Model, ModelSource, TableMethod, VectorizeMeta}; - -#[allow(clippy::too_many_arguments)] pub fn init_table( job_name: &str, schema: &str, @@ -22,26 +7,17 @@ pub fn init_table( update_col: Option, index_dist_type: types::IndexDist, transformer: &Model, - // search_alg is now deprecated - search_alg: types::SimilarityAlg, table_method: types::TableMethod, - // cron-like for a cron based update model, or 'realtime' for a trigger-based - schedule: &str, + schedule: &str, // cron-like or 'realtime' for trigger-based updates ) -> Result { - // validate table method - // realtime is only compatible with the join method if schedule == "realtime" && table_method != TableMethod::join { error!("realtime schedule is only compatible with the join table method"); } - // get prim key type let pkey_type = init::get_column_datatype(schema, table, primary_key)?; init::init_pgmq()?; let guc_configs = get_guc_configs(&transformer.source); - // validate API key where necessary and collect any optional arguments - // certain embedding services require an API key, e.g. openAI - // key can be set in a GUC, so if its required but not provided in args, and not in GUC, error let optional_args = match transformer.source { ModelSource::OpenAI => { openai::validate_api_key( @@ -52,29 +28,18 @@ pub fn init_table( )?; None } - ModelSource::Tembo => { - error!("Tembo not implemented for search yet"); - } + ModelSource::Tembo => error!("Tembo not implemented for search yet"), ModelSource::Ollama => { - let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) { - Some(k) => k, - None => { - error!("failed to get Ollama url from GUC"); - } - }; - let res = check_model_host(&url); - match res { - Ok(_) => { - info!("Model host active!"); - None - } - Err(e) => { - error!("Error with model host: {:?}", e) - } - } + let url = guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) + .context("Failed to get Ollama URL from GUC")?; + check_model_host(&url).context("Error with model host")?; + None } ModelSource::Portkey => Some(serde_json::json!({ - "virtual_key": guc_configs.virtual_key.clone().expect("Portkey virtual key is required") + "virtual_key": guc_configs + .virtual_key + .clone() + .context("Portkey virtual key is required")? })), _ => None, }; @@ -86,19 +51,15 @@ pub fn init_table( guc_configs.virtual_key.clone(), )?; - // synchronous let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() .build() - .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e)); - let model_dim = - match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) { - Ok(e) => e, - Err(e) => { - error!("error getting model dim: {}", e); - } - }; + .context("Failed to initialize tokio runtime")?; + + let model_dim = runtime + .block_on(async { provider.model_dim(&transformer.api_name()).await }) + .context("Error getting model dimension")?; let valid_params = types::JobParams { schema: schema.to_string(), @@ -112,15 +73,13 @@ pub fn init_table( schedule: schedule.to_string(), args: optional_args, }; + let params = - pgrx::JsonB(serde_json::to_value(valid_params.clone()).expect("error serializing params")); + pgrx::JsonB(serde_json::to_value(&valid_params).context("Error serializing parameters")?); - // write job to table let init_job_q = init::init_job_query(); - // using SPI here because it is unlikely that this code will be run anywhere but inside the extension. - // background worker will likely be moved to an external container or service in near future - let ran: Result<_, spi::Error> = Spi::connect(|mut c| { - match c.update( + Spi::connect(|mut c| { + c.update( &init_job_q, None, Some(vec![ @@ -133,237 +92,38 @@ pub fn init_table( PgBuiltInOids::TEXTOID.oid(), transformer.to_string().into_datum(), ), - ( - PgBuiltInOids::TEXTOID.oid(), - // search_alg is now deprecated - search_alg.to_string().into_datum(), - ), (PgBuiltInOids::JSONBOID.oid(), params.into_datum()), ]), - ) { - Ok(_) => (), - Err(e) => { - error!("error creating job: {}", e); - } - } - Ok(()) - }); - ran?; + ) + .context("Error creating job") + })?; let init_embed_q = init::init_embedding_table_query(job_name, &valid_params, &index_dist_type, model_dim); - - let ran: Result<_, spi::Error> = Spi::connect(|mut c| { + Spi::connect(|mut c| { for q in init_embed_q { - let _r = c.update(&q, None, None)?; + c.update(&q, None, None) + .context("Error initializing embedding table")?; } Ok(()) - }); - if let Err(e) = ran { - error!("error creating embedding table: {}", e); - } + })?; + match schedule { "realtime" => { - // setup triggers - // create the trigger if not exists let trigger_handler = create_trigger_handler(job_name, &columns, primary_key); let insert_trigger = create_event_trigger(job_name, schema, table, "INSERT"); let update_trigger = create_event_trigger(job_name, schema, table, "UPDATE"); - let _: Result<_, spi::Error> = Spi::connect(|mut c| { - let _r = c.update(&trigger_handler, None, None)?; - let _r = c.update(&insert_trigger, None, None)?; - let _r = c.update(&update_trigger, None, None)?; + Spi::connect(|mut c| { + c.update(&trigger_handler, None, None)?; + c.update(&insert_trigger, None, None)?; + c.update(&update_trigger, None, None)?; Ok(()) - }); - } - _ => { - // initialize cron - init::init_cron(schedule, job_name)?; - log!("Initialized cron job"); + })?; } + _ => init::init_cron(schedule, job_name).context("Error initializing cron job")?, } + // start with initial batch load - initalize_table_job( - job_name, - &valid_params, - index_dist_type, - transformer, - // search_alg is now deprecated - search_alg, - )?; + initalize_table_job(job_name, &valid_params, index_dist_type, transformer)?; Ok(format!("Successfully created job: {job_name}")) } - -pub fn search( - job_name: &str, - query: &str, - api_key: Option, - return_columns: Vec, - num_results: i32, - where_clause: Option, -) -> Result> { - let project_meta: VectorizeMeta = util::get_vectorize_meta_spi(job_name)?; - let proj_params: types::JobParams = serde_json::from_value( - serde_json::to_value(project_meta.params).unwrap_or_else(|e| { - error!("failed to serialize metadata: {}", e); - }), - ) - .unwrap_or_else(|e| error!("failed to deserialize metadata: {}", e)); - - let proj_api_key = match api_key { - // if api passed in the function call, use that - Some(k) => Some(k), - // if not, use the one from the project metadata - None => proj_params.api_key.clone(), - }; - let embeddings = transform(query, &project_meta.transformer, proj_api_key); - - match project_meta.index_dist_type { - types::IndexDist::pgv_hnsw_l2 => error!("Not implemented."), - types::IndexDist::pgv_hnsw_ip => error!("Not implemented."), - types::IndexDist::pgv_hnsw_cosine | types::IndexDist::vsc_diskann_cosine => { - cosine_similarity_search( - job_name, - &proj_params, - &return_columns, - num_results, - &embeddings[0], - where_clause, - ) - } - } -} - -pub fn cosine_similarity_search( - project: &str, - job_params: &types::JobParams, - return_columns: &[String], - num_results: i32, - embeddings: &[f64], - where_clause: Option, -) -> Result> { - let schema = job_params.schema.clone(); - let table = job_params.table.clone(); - - // switch on table method - let query = match job_params.table_method { - TableMethod::append => single_table_cosine_similarity( - project, - &schema, - &table, - return_columns, - num_results, - where_clause, - ), - TableMethod::join => join_table_cosine_similarity( - project, - job_params, - return_columns, - num_results, - where_clause, - ), - }; - Spi::connect(|client| { - let mut results: Vec = Vec::new(); - let tup_table = client.select( - &query, - None, - Some(vec![( - PgBuiltInOids::FLOAT8ARRAYOID.oid(), - embeddings.into_datum(), - )]), - )?; - for row in tup_table { - match row["results"].value()? { - Some(r) => results.push(r), - None => error!("failed to get results"), - } - } - Ok(results) - }) -} - -fn join_table_cosine_similarity( - project: &str, - job_params: &types::JobParams, - return_columns: &[String], - num_results: i32, - where_clause: Option, -) -> String { - let schema = job_params.schema.clone(); - let table = job_params.table.clone(); - let join_key = &job_params.primary_key; - let cols = &return_columns - .iter() - .map(|s| format!("t0.{}", s)) - .collect::>() - .join(","); - - let where_str = if let Some(w) = where_clause { - prepare_filter(&w, join_key) - } else { - "".to_string() - }; - let inner_query = format!( - " - SELECT - {join_key}, - 1 - (embeddings <=> $1::vector) AS similarity_score - FROM vectorize._embeddings_{project} - ORDER BY similarity_score DESC - " - ); - format!( - " - SELECT to_jsonb(t) as results - FROM ( - SELECT {cols}, t1.similarity_score - FROM - ( - {inner_query} - ) t1 - INNER JOIN {schema}.{table} t0 on t0.{join_key} = t1.{join_key} - {where_str} - ) t - ORDER BY t.similarity_score DESC - LIMIT {num_results}; - " - ) -} - -fn single_table_cosine_similarity( - project: &str, - schema: &str, - table: &str, - return_columns: &[String], - num_results: i32, - where_clause: Option, -) -> String { - let where_str = if let Some(w) = where_clause { - format!("AND {}", w) - } else { - "".to_string() - }; - format!( - " - SELECT to_jsonb(t) as results - FROM ( - SELECT - 1 - ({project}_embeddings <=> $1::vector) AS similarity_score, - {cols} - FROM {schema}.{table} - WHERE {project}_updated_at is NOT NULL - {where_str} - ORDER BY similarity_score DESC - LIMIT {num_results} - ) t - ", - cols = return_columns.join(", "), - ) -} - -// transform user's where_sql into the format search query expects -fn prepare_filter(filter: &str, pkey: &str) -> String { - let wc = filter.replace(pkey, &format!("t0.{}", pkey)); - format!("AND {wc}") -}