From 8d563d21e09b9f0ea8c7b657ded8452004693430 Mon Sep 17 00:00:00 2001 From: Palash Nigam Date: Sun, 27 Oct 2024 19:50:53 +0530 Subject: [PATCH] Add embedding provider for VoyageAI. Fixes: #152 --- .github/workflows/extension_ci.yml | 2 + core/src/transformers/providers/mod.rs | 4 + core/src/transformers/providers/voyage.rs | 152 ++++++++++++++++++++++ core/src/types.rs | 14 ++ extension/src/chat/ops.rs | 5 +- extension/src/guc.rs | 29 +++++ extension/src/job.rs | 4 +- 7 files changed, 206 insertions(+), 4 deletions(-) create mode 100644 core/src/transformers/providers/voyage.rs diff --git a/.github/workflows/extension_ci.yml b/.github/workflows/extension_ci.yml index 4ba7613..97100c7 100644 --- a/.github/workflows/extension_ci.yml +++ b/.github/workflows/extension_ci.yml @@ -109,6 +109,7 @@ jobs: CO_API_KEY: ${{ secrets.CO_API_KEY }} PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }} + VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} run: | cd ../core && cargo test - name: Restore cached binaries @@ -132,6 +133,7 @@ jobs: CO_API_KEY: ${{ secrets.CO_API_KEY }} PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }} PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }} + VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} run: | echo "\q" | make run make test-integration diff --git a/core/src/transformers/providers/mod.rs b/core/src/transformers/providers/mod.rs index b743c2a..700e3cc 100644 --- a/core/src/transformers/providers/mod.rs +++ b/core/src/transformers/providers/mod.rs @@ -3,6 +3,7 @@ pub mod ollama; pub mod openai; pub mod portkey; pub mod vector_serve; +pub mod voyage; use anyhow::Result; use async_trait::async_trait; @@ -66,6 +67,9 @@ pub fn get_provider( api_key, virtual_key, ))), + ModelSource::Voyage => Ok(Box::new(providers::voyage::VoyageProvider::new( + url, api_key, + ))), ModelSource::SentenceTransformers => Ok(Box::new( providers::vector_serve::VectorServeProvider::new(url, api_key), )), diff --git a/core/src/transformers/providers/voyage.rs b/core/src/transformers/providers/voyage.rs new file mode 100644 index 0000000..5e61562 --- /dev/null +++ b/core/src/transformers/providers/voyage.rs @@ -0,0 +1,152 @@ +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse}; +use crate::errors::VectorizeError; +use crate::transformers::http_handler::handle_response; +use async_trait::async_trait; +use std::env; + +pub const VOYAGE_BASE_URL: &str = "https://api.voyageai.com/v1"; + +pub struct VoyageProvider { + pub url: String, + pub api_key: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct VoyageEmbeddingBody { + pub input: Vec, + pub model: String, + pub input_type: String, +} + +impl From for VoyageEmbeddingBody { + fn from(request: GenericEmbeddingRequest) -> Self { + VoyageEmbeddingBody { + input: request.input, + model: request.model, + input_type: "document".to_string(), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct VoyageEmbeddingResponse { + pub data: Vec, +} + +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct EmbeddingObject { + pub embedding: Vec, +} + +impl From for GenericEmbeddingResponse { + fn from(response: VoyageEmbeddingResponse) -> Self { + GenericEmbeddingResponse { + embeddings: response.data.iter().map(|x| x.embedding.clone()).collect(), + } + } +} + +impl VoyageProvider { + pub fn new(url: Option, api_key: Option) -> Self { + let final_url = match url { + Some(url) => url, + None => VOYAGE_BASE_URL.to_string(), + }; + let final_api_key = match api_key { + Some(api_key) => api_key, + None => env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set"), + }; + VoyageProvider { + url: final_url, + api_key: final_api_key, + } + } +} + +#[async_trait] +impl EmbeddingProvider for VoyageProvider { + async fn generate_embedding<'a>( + &self, + request: &'a GenericEmbeddingRequest, + ) -> Result { + let client = Client::new(); + + let req_body = VoyageEmbeddingBody::from(request.clone()); + let embedding_url = format!("{}/embeddings", self.url); + + let response = client + .post(&embedding_url) + .timeout(std::time::Duration::from_secs(120_u64)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&req_body) + .send() + .await?; + + let embeddings = handle_response::(response, "embeddings").await?; + Ok(GenericEmbeddingResponse { + embeddings: embeddings + .data + .iter() + .map(|x| x.embedding.clone()) + .collect(), + }) + } + + async fn model_dim(&self, model_name: &str) -> Result { + Ok(voyager_embedding_dim(model_name) as u32) + } +} + +pub fn voyager_embedding_dim(model_name: &str) -> i32 { + match model_name { + "voyage-3-lite" => 512, + "voyage-3" | "voyage-finance-2" | "voyage-multilingual-2" | "voyage-law-2" => 1024, + "voyage-code-2" => 1536, + // older models + "voyage-large-2" => 1536, + "voyage-large-2-instruct" + | "voyage-2" + | "voyage-lite-02-instruct" + | "voyage-02" + | "voyage-01" + | "voyage-lite-01" + | "voyage-lite-01-instruct" => 1024, + _ => 1536, + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use std::env; + + #[tokio::test] + async fn test_voyage_ai_embedding() { + let api_key = Some(env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set")); + let provider = VoyageProvider::new(Some(VOYAGE_BASE_URL.to_string()), api_key); + + let request = GenericEmbeddingRequest { + input: vec!["hello world".to_string()], + model: "voyage-3-lite".to_string(), + }; + + let embeddings = provider.generate_embedding(&request).await.unwrap(); + println!("{:?}", embeddings); + assert!( + !embeddings.embeddings.is_empty(), + "Embeddings should not be empty" + ); + assert!( + embeddings.embeddings.len() == 1, + "Embeddings should have length 1" + ); + assert!( + embeddings.embeddings[0].len() == 512, + "Embeddings should have dimension 512" + ); + } +} diff --git a/core/src/types.rs b/core/src/types.rs index 29f05eb..b9f0eb2 100644 --- a/core/src/types.rs +++ b/core/src/types.rs @@ -160,6 +160,7 @@ impl Model { ModelSource::Tembo => self.name.clone(), ModelSource::Cohere => self.name.clone(), ModelSource::Portkey => self.name.clone(), + ModelSource::Voyage => self.name.clone(), } } } @@ -237,6 +238,7 @@ pub enum ModelSource { Tembo, Cohere, Portkey, + Voyage, } impl FromStr for ModelSource { @@ -250,6 +252,7 @@ impl FromStr for ModelSource { "tembo" => Ok(ModelSource::Tembo), "cohere" => Ok(ModelSource::Cohere), "portkey" => Ok(ModelSource::Portkey), + "voyage" => Ok(ModelSource::Voyage), _ => Ok(ModelSource::SentenceTransformers), } } @@ -264,6 +267,7 @@ impl Display for ModelSource { ModelSource::Tembo => write!(f, "tembo"), ModelSource::Cohere => write!(f, "cohere"), ModelSource::Portkey => write!(f, "portkey"), + ModelSource::Voyage => write!(f, "voyage"), } } } @@ -277,6 +281,7 @@ impl From for ModelSource { "tembo" => ModelSource::Tembo, "cohere" => ModelSource::Cohere, "portkey" => ModelSource::Portkey, + "voyage" => ModelSource::Voyage, // other cases are assumed to be private sentence-transformer compatible model // and can be hot-loaded _ => ModelSource::SentenceTransformers, @@ -298,6 +303,15 @@ mod model_tests { assert_eq!(model.api_name(), "text-embedding-ada-002"); } + #[test] + fn test_voyage_parsing() { + let model = Model::new("voyage/voyage-3-lite").unwrap(); + assert_eq!(model.source, ModelSource::Voyage); + assert_eq!(model.fullname, "voyage/voyage-3-lite"); + assert_eq!(model.name, "voyage-3-lite"); + assert_eq!(model.api_name(), "voyage-3-lite"); + } + #[test] fn test_tembo_parsing() { let model = Model::new("tembo/meta-llama/Meta-Llama-3-8B-Instruct").unwrap(); diff --git a/extension/src/chat/ops.rs b/extension/src/chat/ops.rs index cb52ad2..042a416 100644 --- a/extension/src/chat/ops.rs +++ b/extension/src/chat/ops.rs @@ -50,6 +50,9 @@ pub fn call_chat( ModelSource::Portkey => { get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model") } + ModelSource::Voyage => { + get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model") + } }; // can only be 1 column in a chat job, for now, so safe to grab first element @@ -190,7 +193,7 @@ pub fn call_chat_completions( .generate_response(model.api_name(), &messages) .await } - ModelSource::SentenceTransformers | ModelSource::Cohere => { + ModelSource::SentenceTransformers | ModelSource::Cohere | ModelSource::Voyage => { error!("SentenceTransformers and Cohere not yet supported for chat completions") } } diff --git a/extension/src/guc.rs b/extension/src/guc.rs index af1ec04..23dfa90 100644 --- a/extension/src/guc.rs +++ b/extension/src/guc.rs @@ -28,6 +28,8 @@ pub static COHERE_API_KEY: GucSetting> = GucSetting::> = GucSetting::>::new(None); pub static PORTKEY_VIRTUAL_KEY: GucSetting> = GucSetting::>::new(None); pub static PORTKEY_SERVICE_URL: GucSetting> = GucSetting::>::new(None); +pub static VOYAGE_API_KEY: GucSetting> = GucSetting::>::new(None); +pub static VOYAGE_SERVICE_URL: GucSetting> = GucSetting::>::new(None); // initialize GUCs pub fn init_guc() { @@ -175,6 +177,24 @@ pub fn init_guc() { GucContext::Suset, GucFlags::default(), ); + + GucRegistry::define_string_guc( + "vectorize.voyage_service_url", + "Base url for the Voyage AI platform", + "Base url for the Voyage AI platform", + &VOYAGE_SERVICE_URL, + GucContext::Suset, + GucFlags::default(), + ); + + GucRegistry::define_string_guc( + "vectorize.voyage_api_key", + "API Key for the Voyage AI platform", + "API Key for the Voyage AI platform", + &VOYAGE_API_KEY, + GucContext::Suset, + GucFlags::default(), + ); } // for handling of GUCs that can be error prone @@ -193,6 +213,8 @@ pub enum VectorizeGuc { PortkeyApiKey, PortkeyVirtualKey, PortkeyServiceUrl, + VoyageApiKey, + VoyageServiceUrl, } /// a convenience function to get this project's GUCs @@ -211,6 +233,8 @@ pub fn get_guc(guc: VectorizeGuc) -> Option { VectorizeGuc::PortkeyApiKey => PORTKEY_API_KEY.get(), VectorizeGuc::PortkeyVirtualKey => PORTKEY_VIRTUAL_KEY.get(), VectorizeGuc::PortkeyServiceUrl => PORTKEY_SERVICE_URL.get(), + VectorizeGuc::VoyageApiKey => VOYAGE_API_KEY.get(), + VectorizeGuc::VoyageServiceUrl => VOYAGE_SERVICE_URL.get(), }; if let Some(cstr) = val { if let Ok(s) = handle_cstr(cstr) { @@ -273,5 +297,10 @@ pub fn get_guc_configs(model_source: &ModelSource) -> ModelGucConfig { service_url: get_guc(VectorizeGuc::PortkeyServiceUrl), virtual_key: get_guc(VectorizeGuc::PortkeyVirtualKey), }, + ModelSource::Voyage => ModelGucConfig { + api_key: get_guc(VectorizeGuc::VoyageApiKey), + service_url: get_guc(VectorizeGuc::VoyageServiceUrl), + virtual_key: None, + }, } } diff --git a/extension/src/job.rs b/extension/src/job.rs index efe4980..cf312c7 100644 --- a/extension/src/job.rs +++ b/extension/src/job.rs @@ -8,9 +8,7 @@ use crate::util; use pgrx::prelude::*; use tiktoken_rs::cl100k_base; use vectorize_core::transformers::types::Inputs; -use vectorize_core::types::{ - IndexDist, JobMessage, JobParams, Model, SimilarityAlg, TableMethod, VectorizeMeta, -}; +use vectorize_core::types::{IndexDist, JobMessage, JobParams, Model, TableMethod, VectorizeMeta}; /// called by the trigger function when a table is updated /// handles enqueueing the embedding transform jobs