Ollama generation integration (#103)
* Added support for Ollama chat completions

---------

Co-authored-by: Akshat Jaimini <[email protected]>
ChuckHend and destrex271 authored Apr 19, 2024
1 parent 3ac2430 commit b9fa75b
Showing 18 changed files with 263 additions and 20 deletions.
80 changes: 80 additions & 0 deletions .github/workflows/build-ollama-serve.yml
@@ -0,0 +1,80 @@
name: Build and deploy ollama server

on:
  push:
    branches:
      - main
      - ollama-integration
    paths:
      - ".github/workflows/build-ollama-serve.yml"
      - "ollama-serve/**"

  pull_request:
    branches:
      - main
    paths:
      - ".github/workflows/build-ollama-serve.yml"
      - "ollama-serve/**"

permissions:
  id-token: write
  contents: read

defaults:
  run:
    shell: bash
    working-directory: ./ollama-serve/

jobs:
  build_and_push:
    name: Build and push images
    runs-on:
      - self-hosted
      - dind
      - large-8x8
    outputs:
      short_sha: ${{ steps.versions.outputs.SHORT_SHA }}
    steps:
      - name: Check out the repo
        uses: actions/checkout@v3
      - name: Set version strings
        id: versions
        run: |
          echo "SHORT_SHA=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Quay
        uses: docker/login-action@v2
        with:
          registry: quay.io/tembo
          username: ${{ secrets.QUAY_USER_TEMBO }}
          password: ${{ secrets.QUAY_PASSWORD_TEMBO }}

      - name: Build and push -- Commit
        # push a build for every commit
        uses: docker/build-push-action@v5
        with:
          file: ./ollama-serve/Dockerfile
          context: .
          platforms: linux/amd64, linux/arm64
          push: true
          tags: |
            quay.io/tembo/ollama-serve:${{ steps.versions.outputs.SHORT_SHA }}
      - name: Build and push -- Latest
        # only push latest off main
        if: github.ref == 'refs/heads/main'
        uses: docker/build-push-action@v5
        with:
          file: ./ollama-serve/Dockerfile
          context: .
          platforms: linux/amd64, linux/arm64
          push: true
          tags: |
            quay.io/tembo/ollama-serve:latest
2 changes: 1 addition & 1 deletion .github/workflows/build-vector-serve.yml
@@ -1,4 +1,4 @@
name: Build and deploy server
name: Build and deploy embedding server

on:
  push:
18 changes: 8 additions & 10 deletions README.md
@@ -21,7 +21,6 @@ This project relies heavily on the work by [pgvector](https://github.com/pgvecto
[![PGXN version](https://badge.fury.io/pg/vectorize.svg)](https://pgxn.org/dist/vectorize/)
[![OSSRank](https://shields.io/endpoint?url=https://ossrank.com/shield/3815)](https://ossrank.com/p/3815)


pg_vectorize powers the [VectorDB Stack](https://tembo.io/docs/product/stacks/ai/vectordb) on [Tembo Cloud](https://cloud.tembo.io/) and is available in all hobby tier instances.

**API Documentation**: https://tembo-io.github.io/pg_vectorize/
@@ -34,15 +33,14 @@ pg_vectorize powers the [VectorDB Stack](https://tembo.io/docs/product/stacks/ai
- Integrations with OpenAI's [embeddings](https://platform.openai.com/docs/guides/embeddings) and [chat-completion](https://platform.openai.com/docs/guides/text-generation) endpoints and a self-hosted container for running [Hugging Face Sentence-Transformers](https://huggingface.co/sentence-transformers)
- Automated creation of Postgres triggers to keep your embeddings up to date
- High level API - one function to initialize embeddings transformations, and another function to search

## Table of Contents
- [Features](#features)
- [Table of Contents](#table-of-contents)
- [Installation](#installation)
- [Vector Search Example](#vector-search-example)
- [RAG Example](#rag-example)
- [Updating Embeddings](#updating-embeddings)
- [Try it on Tembo Cloud](#try-it-on-tembo-cloud)

## Installation

@@ -188,18 +186,18 @@ ADD COLUMN context TEXT GENERATED ALWAYS AS (product_name || ': ' || description

```sql
SELECT vectorize.init_rag(
    agent_name => 'product_chat',
    table_name => 'products',
    "column" => 'context',
    unique_record_id => 'product_id',
    transformer => 'sentence-transformers/all-MiniLM-L12-v2'
);
```

```sql
SELECT vectorize.rag(
    agent_name => 'product_chat',
    query => 'What is a pencil?'
) -> 'chat_response';
```

2 changes: 2 additions & 0 deletions core/Cargo.toml
@@ -18,6 +18,7 @@ chrono = {version = "0.4.26", features = ["serde"] }
env_logger = "0.11.3"
lazy_static = "1.4.0"
log = "0.4.21"
ollama-rs = "0.1.7"
pgmq = "0.26.1"
regex = "1.9.2"
reqwest = {version = "0.11.18", features = ["json"] }
@@ -32,3 +33,4 @@ sqlx = { version = "=0.7.3", features = [
thiserror = "1.0.44"
tiktoken-rs = "0.5.7"
tokio = {version = "1.29.1", features = ["rt-multi-thread"] }
url = "2.5.0"
1 change: 1 addition & 0 deletions core/src/transformers/mod.rs
@@ -1,4 +1,5 @@
pub mod generic;
pub mod http_handler;
pub mod ollama;
pub mod openai;
pub mod types;
48 changes: 48 additions & 0 deletions core/src/transformers/ollama.rs
@@ -0,0 +1,48 @@
use anyhow::Result;
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
use url::Url;

pub struct OllamaInstance {
    pub model_name: String,
    pub instance: Ollama,
}

pub trait LLMFunctions {
    fn new(model_name: String, url: String) -> Self;
    #[allow(async_fn_in_trait)]
    async fn generate_reponse(&self, prompt_text: String) -> Result<String, String>;
}

impl LLMFunctions for OllamaInstance {
    fn new(model_name: String, url: String) -> Self {
        let parsed_url = Url::parse(&url).unwrap_or_else(|_| panic!("invalid url: {}", url));
        let instance = Ollama::new(
            format!(
                "{}://{}",
                parsed_url.scheme(),
                parsed_url.host_str().expect("parsed url missing")
            ),
            parsed_url.port().expect("parsed port missing"),
        );
        OllamaInstance {
            model_name,
            instance,
        }
    }
    async fn generate_reponse(&self, prompt_text: String) -> Result<String, String> {
        let req = GenerationRequest::new(self.model_name.clone(), prompt_text);
        println!("ollama instance: {:?}", self.instance);
        let res = self.instance.generate(req).await;
        match res {
            Ok(res) => Ok(res.response),
            Err(e) => Err(e.to_string()),
        }
    }
}

pub fn ollama_embedding_dim(model_name: &str) -> i32 {
    match model_name {
        "llama2" => 5192,
        _ => 1536,
    }
}
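
A minimal sketch (not part of this commit) of how the new `OllamaInstance` and `LLMFunctions` trait might be driven from a plain binary, mirroring the blocking-runtime pattern used in `extension/src/chat/ops.rs` later in this diff; the model name and URL are placeholders and assume a reachable Ollama server with that model already pulled:

```rust
use vectorize_core::transformers::ollama::{LLMFunctions, OllamaInstance};

fn main() {
    // Single-threaded runtime, the same approach call_ollama_chat_completions takes.
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_io()
        .enable_time()
        .build()
        .expect("failed to build tokio runtime");

    // Placeholder model and URL; the port matches the ollama-serve service added to
    // docker-compose.yml in this commit.
    let instance = OllamaInstance::new("llama2".to_string(), "http://localhost:3001".to_string());

    // generate_reponse returns the completion text, or the Ollama error as a String.
    match runtime.block_on(instance.generate_reponse("Why is the sky blue?".to_string())) {
        Ok(text) => println!("{text}"),
        Err(e) => eprintln!("generation failed: {e}"),
    }
}
```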
11 changes: 11 additions & 0 deletions core/src/types.rs
@@ -234,13 +234,15 @@ impl fmt::Display for Model {
pub enum ModelSource {
    OpenAI,
    SentenceTransformers,
    Ollama,
}

impl FromStr for ModelSource {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "ollama" => Ok(ModelSource::Ollama),
            "openai" => Ok(ModelSource::OpenAI),
            "sentence-transformers" => Ok(ModelSource::SentenceTransformers),
            _ => Ok(ModelSource::SentenceTransformers),
@@ -251,6 +253,7 @@ impl FromStr for ModelSource {
impl Display for ModelSource {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
        match self {
            ModelSource::Ollama => write!(f, "ollama"),
            ModelSource::OpenAI => write!(f, "openai"),
            ModelSource::SentenceTransformers => write!(f, "sentence-transformers"),
        }
@@ -260,6 +263,7 @@ impl Display for ModelSource {
impl From<String> for ModelSource {
    fn from(s: String) -> Self {
        match s.as_str() {
            "ollama" => ModelSource::Ollama,
            "openai" => ModelSource::OpenAI,
            "sentence-transformers" => ModelSource::SentenceTransformers,
            // other cases are assumed to be private sentence-transformer compatible model
@@ -274,6 +278,13 @@ impl From<String> for ModelSource {
mod model_tests {
    use super::*;

    #[test]
    fn test_ollama_parsing() {
        let model = Model::new("ollama/wizardlm2:7b").unwrap();
        assert_eq!(model.source, ModelSource::Ollama);
        assert_eq!(model.name, "wizardlm2:7b");
    }

    #[test]
    fn test_legacy_fullname() {
        let model = Model::new("text-embedding-ada-002").unwrap();
3 changes: 3 additions & 0 deletions core/src/worker/base.rs
@@ -43,6 +43,7 @@ pub struct Config {
    pub queue_name: String,
    pub embedding_svc_url: String,
    pub openai_api_key: Option<String>,
    pub ollama_svc_url: String,
    pub embedding_request_timeout: i32,
    pub poll_interval: u64,
    pub poll_interval_error: u64,
@@ -62,6 +63,7 @@ impl Config {
                "http://localhost:3000/v1/embeddings",
            ),
            openai_api_key: env::var("OPENAI_API_KEY").ok(),
            ollama_svc_url: from_env_default("OLLAMA_SVC_URL", "http://localhost:3001"),
            embedding_request_timeout: from_env_default("EMBEDDING_REQUEST_TIMEOUT", "6")
                .parse()
                .unwrap(),
@@ -96,6 +98,7 @@ async fn execute_job(
            &msg.message.inputs,
            cfg.openai_api_key.clone(),
        )?,
        ModelSource::Ollama => Err(anyhow::anyhow!("Ollama transformer not implemented yet"))?,
        ModelSource::SentenceTransformers => generic::prepare_generic_embedding_request(
            job_meta.clone(),
            &msg.message.inputs,
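
The worker resolves the new service URL the same way as the existing embedding settings: an environment variable with a localhost fallback. A minimal sketch of that pattern, assuming `from_env_default` is a thin wrapper over `std::env::var` (the helper itself is not shown in this diff):

```rust
use std::env;

// Assumed shape of the from_env_default helper used in Config::from_env above:
// prefer the environment value when set, otherwise fall back to the default.
fn from_env_default(key: &str, default: &str) -> String {
    env::var(key).unwrap_or_else(|_| default.to_owned())
}

fn main() {
    // The default matches the port the ollama-serve service exposes in docker-compose.yml.
    let ollama_svc_url = from_env_default("OLLAMA_SVC_URL", "http://localhost:3001");
    println!("ollama service url: {ollama_svc_url}");
}
```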
8 changes: 6 additions & 2 deletions docker-compose.yml
@@ -1,5 +1,3 @@
version: '3.2'

services:
  postgres:
    restart: always
@@ -13,3 +11,9 @@ services:
    image: quay.io/tembo/vector-serve:latest
    ports:
      - 3000:3000
  ollama-serve:
    image: quay.io/tembo/ollama-serve:latest
    ports:
      - 3001:3001
    environment:
      - OLLAMA_HOST=0.0.0.0:3001
47 changes: 46 additions & 1 deletion extension/src/chat/ops.rs
@@ -7,6 +7,8 @@ use handlebars::Handlebars;
use openai_api_rs::v1::api::Client;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
use pgrx::prelude::*;
use vectorize_core::transformers::ollama::LLMFunctions;
use vectorize_core::transformers::ollama::OllamaInstance;
use vectorize_core::types::Model;
use vectorize_core::types::ModelSource;

@@ -30,7 +32,18 @@ pub fn call_chat(
        .unwrap_or_else(|e| error!("failed to deserialize job params: {}", e));

    // for various token count estimations
    let bpe = get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model");
    let bpe = match chat_model.source {
        ModelSource::Ollama => {
            // Using gpt-3.5-turbo tokenizer for Ollama since the library does not support llama2
            get_bpe_from_model("gpt-3.5-turbo").expect("failed to get BPE from model")
        }
        ModelSource::OpenAI => {
            get_bpe_from_model(&chat_model.name).expect("failed to get BPE from model")
        }
        ModelSource::SentenceTransformers => {
            error!("SentenceTransformers not supported for chat completions")
        }
    };

    // can only be 1 column in a chat job, for now, so safe to grab first element
    let content_column = job_params.columns[0].clone();
@@ -108,6 +121,7 @@ ModelSource::SentenceTransformers => {
        ModelSource::SentenceTransformers => {
            error!("SentenceTransformers not supported for chat completions");
        }
        ModelSource::Ollama => call_ollama_chat_completions(rendered_prompt, &chat_model.name)?,
    };

    Ok(ChatResponse {
@@ -166,6 +180,37 @@ fn call_chat_completions(
    Ok(chat_response)
}

fn call_ollama_chat_completions(prompts: RenderedPrompt, model: &str) -> Result<String> {
    // get url from guc
    let url = match guc::get_guc(guc::VectorizeGuc::OllamaServiceUrl) {
        Some(k) => k,
        None => {
            error!("failed to get Ollama url from GUC");
        }
    };

    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_io()
        .enable_time()
        .build()
        .unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));

    let instance = OllamaInstance::new(model.to_string(), url.to_string());

    let response = runtime.block_on(async {
        instance
            .generate_reponse(prompts.sys_rendered + "\n" + &prompts.user_rendered)
            .await
    });

    match response {
        Ok(k) => Ok(k),
        Err(k) => {
            error!("Unable to generate response. Error: {k}");
        }
    }
}

// Trims the context to fit within the token limit when force_trim = True
// Otherwise returns an error if the context exceeds the token limit
fn trim_context(context: &str, overage: i32, bpe: &CoreBPE) -> Result<String> {
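
One design note on the Ollama path above: unlike the OpenAI chat endpoint, which accepts separate system and user messages, `ollama-rs`'s `GenerationRequest` takes a single prompt string, so the two rendered templates are concatenated before the call. A small sketch of that join with placeholder prompt text:

```rust
fn main() {
    // Placeholder values standing in for RenderedPrompt's sys_rendered and
    // user_rendered fields produced by the Handlebars templates in call_chat.
    let sys_rendered = String::from("You are a helpful assistant. Context: ...");
    let user_rendered = String::from("What is a pencil?");

    // Same join as call_ollama_chat_completions: one prompt string for Ollama.
    let single_prompt = sys_rendered + "\n" + &user_rendered;
    println!("{single_prompt}");
}
```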