Skip to content

Commit

Permalink
fix signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHend committed Jun 5, 2024
1 parent cbb65da commit 7da32d5
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 23 deletions.
8 changes: 4 additions & 4 deletions extension/sql/vectorize--0.15.1--0.16.0.sql
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
-- src/api.rs:157
-- src/api.rs:158
-- vectorize::api::generate
CREATE FUNCTION vectorize."generate"(
"input" TEXT, /* &str */
"model" TEXT DEFAULT 'openai/gpt-3.5-turbo', /* alloc::string::String */
"api_key" TEXT DEFAULT NULL /* core::option::Option<alloc::string::String> */
) RETURNS double precision[] /* core::result::Result<alloc::vec::Vec<f64>, anyhow::Error> */
) RETURNS TEXT /* core::result::Result<alloc::string::String, anyhow::Error> */
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'generate_wrapper';

Expand All @@ -17,11 +17,11 @@ STRICT
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'env_interpolate_guc_wrapper';

-- src/api.rs:78
-- src/api.rs:79
-- vectorize::api::encode
CREATE FUNCTION vectorize."encode"(
"input" TEXT, /* &str */
"model_name" TEXT DEFAULT 'openai/text-embedding-ada-002', /* alloc::string::String */
"model" TEXT DEFAULT 'openai/text-embedding-ada-002', /* alloc::string::String */
"api_key" TEXT DEFAULT NULL /* core::option::Option<alloc::string::String> */
) RETURNS double precision[] /* core::result::Result<alloc::vec::Vec<f64>, anyhow::Error> */
LANGUAGE c /* Rust */
Expand Down
19 changes: 12 additions & 7 deletions extension/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::chat::ops::call_chat;
use crate::chat::ops::{call_chat, get_chat_response};
use crate::chat::types::RenderedPrompt;
use crate::search::{self, init_table};
use crate::transformers::generic::env_interpolate_string;
use crate::transformers::transform;
Expand Down Expand Up @@ -77,10 +78,10 @@ fn transform_embeddings(
#[pg_extern]
fn encode(
input: &str,
model_name: default!(String, "'openai/text-embedding-ada-002'"),
model: default!(String, "'openai/text-embedding-ada-002'"),
api_key: default!(Option<String>, "NULL"),
) -> Result<Vec<f64>> {
let model = Model::new(&model_name)?;
let model = Model::new(&model)?;
Ok(transform(input, &model, api_key).remove(0))
}

Expand Down Expand Up @@ -158,9 +159,13 @@ fn generate(
input: &str,
model: default!(String, "'openai/gpt-3.5-turbo'"),
api_key: default!(Option<String>, "NULL"),
) -> Result<Vec<f64>> {
) -> Result<String> {
let model = Model::new(&model)?;
Ok(transform(input, &model, api_key).remove(0))
let prompt = RenderedPrompt {
sys_rendered: "".to_string(),
user_rendered: input.to_string(),
};
get_chat_response(prompt, &model, api_key)
}

#[pg_extern]
Expand All @@ -169,6 +174,6 @@ fn env_interpolate_guc(guc_name: &str) -> Result<String> {
"SELECT current_setting($1)",
vec![(PgBuiltInOids::TEXTOID.oid(), guc_name.into_datum())],
)?
.expect(&format!("no value set for guc: {guc_name}"));
Ok(env_interpolate_string(&g)?)
.unwrap_or_else(|| panic!("no value set for guc: {guc_name}"));
env_interpolate_string(&g)
}
24 changes: 15 additions & 9 deletions extension/src/chat/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,28 @@ pub fn call_chat(
)?;

// http request to chat completions
let chat_response = match chat_model.source {
ModelSource::OpenAI | ModelSource::Tembo => {
call_chat_completions(rendered_prompt, chat_model, api_key)?
}
ModelSource::SentenceTransformers => {
error!("SentenceTransformers not supported for chat completions");
}
ModelSource::Ollama => call_ollama_chat_completions(rendered_prompt, &chat_model.name)?,
};
let chat_response = get_chat_response(rendered_prompt, chat_model, api_key)?;

Ok(ChatResponse {
context: search_results,
chat_response,
})
}

pub fn get_chat_response(
prompt: RenderedPrompt,
model: &Model,
api_key: Option<String>,
) -> Result<String> {
match model.source {
ModelSource::OpenAI | ModelSource::Tembo => call_chat_completions(prompt, model, api_key),
ModelSource::SentenceTransformers => {
error!("SentenceTransformers not supported for chat completions");
}
ModelSource::Ollama => call_ollama_chat_completions(prompt, &model.name),
}
}

fn render_user_message(user_prompt_template: &str, context: &str, query: &str) -> Result<String> {
let handlebars = Handlebars::new();
let render_vals = serde_json::json!({
Expand Down
6 changes: 3 additions & 3 deletions extension/src/transformers/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ pub fn get_env_interpolated_guc(requested: guc::VectorizeGuc) -> Result<String>
} else {
match requested {
guc::VectorizeGuc::EmbeddingServiceUrl => {
return Err(anyhow::anyhow!("vectorize.embedding_service_url not set"))
Err(anyhow::anyhow!("vectorize.embedding_service_url not set"))
}
guc::VectorizeGuc::OpenAIServiceUrl => {
return Err(anyhow::anyhow!("vectorize.openai_service_url not set"))
Err(anyhow::anyhow!("vectorize.openai_service_url not set"))
}
_ => return Err(anyhow::anyhow!("GUC not found")),
_ => Err(anyhow::anyhow!("GUC not found")),
}
}
}
Expand Down

0 comments on commit 7da32d5

Please sign in to comment.