Skip to content

Commit

Permalink
Add embedding provider for VoyageAI.
Browse files Browse the repository at this point in the history
Fixes: #152
  • Loading branch information
palash25 committed Oct 31, 2024
1 parent 3d194a9 commit 8d563d2
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/extension_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ jobs:
CO_API_KEY: ${{ secrets.CO_API_KEY }}
PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }}
PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }}
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
run: |
cd ../core && cargo test
- name: Restore cached binaries
Expand All @@ -132,6 +133,7 @@ jobs:
CO_API_KEY: ${{ secrets.CO_API_KEY }}
PORTKEY_API_KEY: ${{ secrets.PORTKEY_API_KEY }}
PORTKEY_VIRTUAL_KEY_OPENAI: ${{ secrets.PORTKEY_VIRTUAL_KEY_OPENAI }}
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
run: |
echo "\q" | make run
make test-integration
Expand Down
4 changes: 4 additions & 0 deletions core/src/transformers/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod ollama;
pub mod openai;
pub mod portkey;
pub mod vector_serve;
pub mod voyage;

use anyhow::Result;
use async_trait::async_trait;
Expand Down Expand Up @@ -66,6 +67,9 @@ pub fn get_provider(
api_key,
virtual_key,
))),
ModelSource::Voyage => Ok(Box::new(providers::voyage::VoyageProvider::new(
url, api_key,
))),
ModelSource::SentenceTransformers => Ok(Box::new(
providers::vector_serve::VectorServeProvider::new(url, api_key),
)),
Expand Down
152 changes: 152 additions & 0 deletions core/src/transformers/providers/voyage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{EmbeddingProvider, GenericEmbeddingRequest, GenericEmbeddingResponse};
use crate::errors::VectorizeError;
use crate::transformers::http_handler::handle_response;
use async_trait::async_trait;
use std::env;

pub const VOYAGE_BASE_URL: &str = "https://api.voyageai.com/v1";

pub struct VoyageProvider {
pub url: String,
pub api_key: String,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VoyageEmbeddingBody {
pub input: Vec<String>,
pub model: String,
pub input_type: String,
}

impl From<GenericEmbeddingRequest> for VoyageEmbeddingBody {
fn from(request: GenericEmbeddingRequest) -> Self {
VoyageEmbeddingBody {
input: request.input,
model: request.model,
input_type: "document".to_string(),
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VoyageEmbeddingResponse {
pub data: Vec<EmbeddingObject>,
}

#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct EmbeddingObject {
pub embedding: Vec<f64>,
}

impl From<VoyageEmbeddingResponse> for GenericEmbeddingResponse {
fn from(response: VoyageEmbeddingResponse) -> Self {
GenericEmbeddingResponse {
embeddings: response.data.iter().map(|x| x.embedding.clone()).collect(),
}
}
}

impl VoyageProvider {
pub fn new(url: Option<String>, api_key: Option<String>) -> Self {
let final_url = match url {
Some(url) => url,
None => VOYAGE_BASE_URL.to_string(),
};
let final_api_key = match api_key {
Some(api_key) => api_key,
None => env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set"),
};
VoyageProvider {
url: final_url,
api_key: final_api_key,
}
}
}

#[async_trait]
impl EmbeddingProvider for VoyageProvider {
async fn generate_embedding<'a>(
&self,
request: &'a GenericEmbeddingRequest,
) -> Result<GenericEmbeddingResponse, VectorizeError> {
let client = Client::new();

let req_body = VoyageEmbeddingBody::from(request.clone());
let embedding_url = format!("{}/embeddings", self.url);

let response = client
.post(&embedding_url)
.timeout(std::time::Duration::from_secs(120_u64))
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&req_body)
.send()
.await?;

let embeddings = handle_response::<VoyageEmbeddingResponse>(response, "embeddings").await?;
Ok(GenericEmbeddingResponse {
embeddings: embeddings
.data
.iter()
.map(|x| x.embedding.clone())
.collect(),
})
}

async fn model_dim(&self, model_name: &str) -> Result<u32, VectorizeError> {
Ok(voyager_embedding_dim(model_name) as u32)
}
}

pub fn voyager_embedding_dim(model_name: &str) -> i32 {
match model_name {
"voyage-3-lite" => 512,
"voyage-3" | "voyage-finance-2" | "voyage-multilingual-2" | "voyage-law-2" => 1024,
"voyage-code-2" => 1536,
// older models
"voyage-large-2" => 1536,
"voyage-large-2-instruct"
| "voyage-2"
| "voyage-lite-02-instruct"
| "voyage-02"
| "voyage-01"
| "voyage-lite-01"
| "voyage-lite-01-instruct" => 1024,
_ => 1536,
}
}

#[cfg(test)]
mod integration_tests {
use super::*;
use std::env;

#[tokio::test]
async fn test_voyage_ai_embedding() {
let api_key = Some(env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set"));
let provider = VoyageProvider::new(Some(VOYAGE_BASE_URL.to_string()), api_key);

let request = GenericEmbeddingRequest {
input: vec!["hello world".to_string()],
model: "voyage-3-lite".to_string(),
};

let embeddings = provider.generate_embedding(&request).await.unwrap();
println!("{:?}", embeddings);
assert!(
!embeddings.embeddings.is_empty(),
"Embeddings should not be empty"
);
assert!(
embeddings.embeddings.len() == 1,
"Embeddings should have length 1"
);
assert!(
embeddings.embeddings[0].len() == 512,
"Embeddings should have dimension 512"
);
}
}
14 changes: 14 additions & 0 deletions core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ impl Model {
ModelSource::Tembo => self.name.clone(),
ModelSource::Cohere => self.name.clone(),
ModelSource::Portkey => self.name.clone(),
ModelSource::Voyage => self.name.clone(),
}
}
}
Expand Down Expand Up @@ -237,6 +238,7 @@ pub enum ModelSource {
Tembo,
Cohere,
Portkey,
Voyage,
}

impl FromStr for ModelSource {
Expand All @@ -250,6 +252,7 @@ impl FromStr for ModelSource {
"tembo" => Ok(ModelSource::Tembo),
"cohere" => Ok(ModelSource::Cohere),
"portkey" => Ok(ModelSource::Portkey),
"voyage" => Ok(ModelSource::Voyage),
_ => Ok(ModelSource::SentenceTransformers),
}
}
Expand All @@ -264,6 +267,7 @@ impl Display for ModelSource {
ModelSource::Tembo => write!(f, "tembo"),
ModelSource::Cohere => write!(f, "cohere"),
ModelSource::Portkey => write!(f, "portkey"),
ModelSource::Voyage => write!(f, "voyage"),
}
}
}
Expand All @@ -277,6 +281,7 @@ impl From<String> for ModelSource {
"tembo" => ModelSource::Tembo,
"cohere" => ModelSource::Cohere,
"portkey" => ModelSource::Portkey,
"voyage" => ModelSource::Voyage,
// other cases are assumed to be private sentence-transformer compatible model
// and can be hot-loaded
_ => ModelSource::SentenceTransformers,
Expand All @@ -298,6 +303,15 @@ mod model_tests {
assert_eq!(model.api_name(), "text-embedding-ada-002");
}

#[test]
fn test_voyage_parsing() {
let model = Model::new("voyage/voyage-3-lite").unwrap();
assert_eq!(model.source, ModelSource::Voyage);
assert_eq!(model.fullname, "voyage/voyage-3-lite");
assert_eq!(model.name, "voyage-3-lite");
assert_eq!(model.api_name(), "voyage-3-lite");
}

#[test]
fn test_tembo_parsing() {
let model = Model::new("tembo/meta-llama/Meta-Llama-3-8B-Instruct").unwrap();
Expand Down
5 changes: 4 additions & 1 deletion extension/src/chat/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ pub fn call_chat(
ModelSource::Portkey => {
get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model")
}
ModelSource::Voyage => {
get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model")
}
};

// can only be 1 column in a chat job, for now, so safe to grab first element
Expand Down Expand Up @@ -190,7 +193,7 @@ pub fn call_chat_completions(
.generate_response(model.api_name(), &messages)
.await
}
ModelSource::SentenceTransformers | ModelSource::Cohere => {
ModelSource::SentenceTransformers | ModelSource::Cohere | ModelSource::Voyage => {
error!("SentenceTransformers and Cohere not yet supported for chat completions")
}
}
Expand Down
29 changes: 29 additions & 0 deletions extension/src/guc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub static COHERE_API_KEY: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr
pub static PORTKEY_API_KEY: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);
pub static PORTKEY_VIRTUAL_KEY: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);
pub static PORTKEY_SERVICE_URL: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);
pub static VOYAGE_API_KEY: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);
pub static VOYAGE_SERVICE_URL: GucSetting<Option<&CStr>> = GucSetting::<Option<&CStr>>::new(None);

// initialize GUCs
pub fn init_guc() {
Expand Down Expand Up @@ -175,6 +177,24 @@ pub fn init_guc() {
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_string_guc(
"vectorize.voyage_service_url",
"Base url for the Voyage AI platform",
"Base url for the Voyage AI platform",
&VOYAGE_SERVICE_URL,
GucContext::Suset,
GucFlags::default(),
);

GucRegistry::define_string_guc(
"vectorize.voyage_api_key",
"API Key for the Voyage AI platform",
"API Key for the Voyage AI platform",
&VOYAGE_API_KEY,
GucContext::Suset,
GucFlags::default(),
);
}

// for handling of GUCs that can be error prone
Expand All @@ -193,6 +213,8 @@ pub enum VectorizeGuc {
PortkeyApiKey,
PortkeyVirtualKey,
PortkeyServiceUrl,
VoyageApiKey,
VoyageServiceUrl,
}

/// a convenience function to get this project's GUCs
Expand All @@ -211,6 +233,8 @@ pub fn get_guc(guc: VectorizeGuc) -> Option<String> {
VectorizeGuc::PortkeyApiKey => PORTKEY_API_KEY.get(),
VectorizeGuc::PortkeyVirtualKey => PORTKEY_VIRTUAL_KEY.get(),
VectorizeGuc::PortkeyServiceUrl => PORTKEY_SERVICE_URL.get(),
VectorizeGuc::VoyageApiKey => VOYAGE_API_KEY.get(),
VectorizeGuc::VoyageServiceUrl => VOYAGE_SERVICE_URL.get(),
};
if let Some(cstr) = val {
if let Ok(s) = handle_cstr(cstr) {
Expand Down Expand Up @@ -273,5 +297,10 @@ pub fn get_guc_configs(model_source: &ModelSource) -> ModelGucConfig {
service_url: get_guc(VectorizeGuc::PortkeyServiceUrl),
virtual_key: get_guc(VectorizeGuc::PortkeyVirtualKey),
},
ModelSource::Voyage => ModelGucConfig {
api_key: get_guc(VectorizeGuc::VoyageApiKey),
service_url: get_guc(VectorizeGuc::VoyageServiceUrl),
virtual_key: None,
},
}
}
4 changes: 1 addition & 3 deletions extension/src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ use crate::util;
use pgrx::prelude::*;
use tiktoken_rs::cl100k_base;
use vectorize_core::transformers::types::Inputs;
use vectorize_core::types::{
IndexDist, JobMessage, JobParams, Model, SimilarityAlg, TableMethod, VectorizeMeta,
};
use vectorize_core::types::{IndexDist, JobMessage, JobParams, Model, TableMethod, VectorizeMeta};

/// called by the trigger function when a table is updated
/// handles enqueueing the embedding transform jobs
Expand Down

0 comments on commit 8d563d2

Please sign in to comment.