diff --git a/extension/sql/vectorize--0.15.1--0.16.0.sql b/extension/sql/vectorize--0.15.1--0.16.0.sql
index 9dead79..0e822b8 100644
--- a/extension/sql/vectorize--0.15.1--0.16.0.sql
+++ b/extension/sql/vectorize--0.15.1--0.16.0.sql
@@ -1,10 +1,10 @@
--- src/api.rs:157
+-- src/api.rs:158
 -- vectorize::api::generate
 CREATE FUNCTION vectorize."generate"(
 	"input" TEXT, /* &str */
 	"model" TEXT DEFAULT 'openai/gpt-3.5-turbo', /* alloc::string::String */
 	"api_key" TEXT DEFAULT NULL /* core::option::Option<alloc::string::String> */
-) RETURNS double precision[] /* core::result::Result<alloc::vec::Vec<f64>, anyhow::Error> */
+) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
 LANGUAGE c /* Rust */
 AS 'MODULE_PATHNAME', 'generate_wrapper';
@@ -17,11 +17,11 @@
 STRICT
 LANGUAGE c /* Rust */
 AS 'MODULE_PATHNAME', 'env_interpolate_guc_wrapper';
--- src/api.rs:78
+-- src/api.rs:79
 -- vectorize::api::encode
 CREATE FUNCTION vectorize."encode"(
 	"input" TEXT, /* &str */
-	"model_name" TEXT DEFAULT 'openai/text-embedding-ada-002', /* alloc::string::String */
+	"model" TEXT DEFAULT 'openai/text-embedding-ada-002', /* alloc::string::String */
 	"api_key" TEXT DEFAULT NULL /* core::option::Option<alloc::string::String> */
 ) RETURNS double precision[] /* core::result::Result<alloc::vec::Vec<f64>, anyhow::Error> */
 LANGUAGE c /* Rust */
diff --git a/extension/src/api.rs b/extension/src/api.rs
index f62baa4..3f840f9 100644
--- a/extension/src/api.rs
+++ b/extension/src/api.rs
@@ -1,4 +1,5 @@
-use crate::chat::ops::call_chat;
+use crate::chat::ops::{call_chat, get_chat_response};
+use crate::chat::types::RenderedPrompt;
 use crate::search::{self, init_table};
 use crate::transformers::generic::env_interpolate_string;
 use crate::transformers::transform;
@@ -77,10 +78,10 @@ fn transform_embeddings(
 #[pg_extern]
 fn encode(
     input: &str,
-    model_name: default!(String, "'openai/text-embedding-ada-002'"),
+    model: default!(String, "'openai/text-embedding-ada-002'"),
     api_key: default!(Option<String>, "NULL"),
 ) -> Result<Vec<f64>> {
-    let model = Model::new(&model_name)?;
+    let model = Model::new(&model)?;
     Ok(transform(input, &model, api_key).remove(0))
 }
 
@@ -158,9 +159,13 @@ fn generate(
     input: &str,
     model: default!(String, "'openai/gpt-3.5-turbo'"),
     api_key: default!(Option<String>, "NULL"),
-) -> Result<Vec<f64>> {
+) -> Result<String> {
     let model = Model::new(&model)?;
-    Ok(transform(input, &model, api_key).remove(0))
+    let prompt = RenderedPrompt {
+        sys_rendered: "".to_string(),
+        user_rendered: input.to_string(),
+    };
+    get_chat_response(prompt, &model, api_key)
 }
 
 #[pg_extern]
@@ -169,6 +174,6 @@ fn env_interpolate_guc(guc_name: &str) -> Result<String> {
         "SELECT current_setting($1)",
         vec![(PgBuiltInOids::TEXTOID.oid(), guc_name.into_datum())],
     )?
-    .expect(&format!("no value set for guc: {guc_name}"));
-    Ok(env_interpolate_string(&g)?)
+    .unwrap_or_else(|| panic!("no value set for guc: {guc_name}"));
+    env_interpolate_string(&g)
 }
diff --git a/extension/src/chat/ops.rs b/extension/src/chat/ops.rs
index ffc1554..b579cc8 100644
--- a/extension/src/chat/ops.rs
+++ b/extension/src/chat/ops.rs
@@ -123,15 +123,7 @@ pub fn call_chat(
     )?;
 
     // http request to chat completions
-    let chat_response = match chat_model.source {
-        ModelSource::OpenAI | ModelSource::Tembo => {
-            call_chat_completions(rendered_prompt, chat_model, api_key)?
-        }
-        ModelSource::SentenceTransformers => {
-            error!("SentenceTransformers not supported for chat completions");
-        }
-        ModelSource::Ollama => call_ollama_chat_completions(rendered_prompt, &chat_model.name)?,
-    };
+    let chat_response = get_chat_response(rendered_prompt, chat_model, api_key)?;
 
     Ok(ChatResponse {
         context: search_results,
@@ -139,6 +131,20 @@ pub fn call_chat(
     })
 }
 
+pub fn get_chat_response(
+    prompt: RenderedPrompt,
+    model: &Model,
+    api_key: Option<String>,
+) -> Result<String> {
+    match model.source {
+        ModelSource::OpenAI | ModelSource::Tembo => call_chat_completions(prompt, model, api_key),
+        ModelSource::SentenceTransformers => {
+            error!("SentenceTransformers not supported for chat completions");
+        }
+        ModelSource::Ollama => call_ollama_chat_completions(prompt, &model.name),
+    }
+}
+
 fn render_user_message(user_prompt_template: &str, context: &str, query: &str) -> Result<String> {
     let handlebars = Handlebars::new();
     let render_vals = serde_json::json!({
diff --git a/extension/src/transformers/generic.rs b/extension/src/transformers/generic.rs
index 6fd5963..bf3cf28 100644
--- a/extension/src/transformers/generic.rs
+++ b/extension/src/transformers/generic.rs
@@ -9,12 +9,12 @@ pub fn get_env_interpolated_guc(requested: guc::VectorizeGuc) -> Result<String>
     } else {
         match requested {
            guc::VectorizeGuc::EmbeddingServiceUrl => {
-                return Err(anyhow::anyhow!("vectorize.embedding_service_url not set"))
+                Err(anyhow::anyhow!("vectorize.embedding_service_url not set"))
            }
            guc::VectorizeGuc::OpenAIServiceUrl => {
-                return Err(anyhow::anyhow!("vectorize.openai_service_url not set"))
+                Err(anyhow::anyhow!("vectorize.openai_service_url not set"))
            }
-            _ => return Err(anyhow::anyhow!("GUC not found")),
+            _ => Err(anyhow::anyhow!("GUC not found")),
         }
     }
 }