Skip to content

Commit

Permalink
feat: add backend url route completion (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed May 24, 2024
1 parent 0e95bb3 commit 0b5e2c6
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 19 deletions.
12 changes: 12 additions & 0 deletions crates/custom-types/src/llm_ls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ impl Backend {
_ => false,
}
}

pub fn url(self) -> String {
match self {
Self::HuggingFace { url } => url,
Self::LlamaCpp { url } => url,
Self::Ollama { url } => url,
Self::OpenAi { url } => url,
Self::Tgi { url } => url,
}
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -141,6 +151,8 @@ pub struct GetCompletionsParams {
pub tls_skip_verify_insecure: bool,
#[serde(default)]
pub request_body: Map<String, Value>,
#[serde(default)]
pub disable_url_path_completion: bool,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-ls/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl Display for APIError {

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum APIResponse {
pub(crate) enum APIResponse {
Generation(Generation),
Generations(Vec<Generation>),
Error(APIError),
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-ls/src/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ fn get_parser(language_id: LanguageId) -> Result<Parser> {
#[derive(Clone, Debug, Copy)]
/// We redeclare this enum here because the `lsp_types` crate exports a Cow
/// type that is unconvenient to deal with.
pub enum PositionEncodingKind {
pub(crate) enum PositionEncodingKind {
Utf8,
Utf16,
Utf32,
Expand Down
10 changes: 0 additions & 10 deletions crates/llm-ls/src/language_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,6 @@ impl fmt::Display for LanguageId {
}
}

pub(crate) struct LanguageIdError {
language_id: String,
}

impl fmt::Display for LanguageIdError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Invalid language id: {}", self.language_id)
}
}

impl From<&str> for LanguageId {
fn from(value: &str) -> Self {
match value {
Expand Down
68 changes: 61 additions & 7 deletions crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,15 @@ async fn request_completion(
params.request_body.clone(),
);
let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
let url = build_url(
params.backend.clone(),
&params.model,
params.disable_url_path_completion,
);
info!(?headers, url, "sending request to backend");
debug!(?headers, body = ?json, url, "sending request to backend");
let res = http_client
.post(build_url(params.backend.clone(), &params.model))
.post(url)
.json(&json)
.headers(headers)
.send()
Expand Down Expand Up @@ -414,7 +421,12 @@ async fn get_tokenizer(
}
}

fn build_url(backend: Backend, model: &str) -> String {
// TODO: add configuration parameter to disable path auto-complete?
fn build_url(backend: Backend, model: &str, disable_url_path_completion: bool) -> String {
if disable_url_path_completion {
return backend.url();
}

match backend {
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
Backend::LlamaCpp { mut url } => {
Expand All @@ -428,9 +440,51 @@ fn build_url(backend: Backend, model: &str) -> String {
url
}
}
Backend::Ollama { url } => url,
Backend::OpenAi { url } => url,
Backend::Tgi { url } => url,
Backend::Ollama { mut url } => {
if url.ends_with("/api/generate") {
url
} else if url.ends_with("/api/") {
url.push_str("generate");
url
} else if url.ends_with("/api") {
url.push_str("/generate");
url
} else if url.ends_with('/') {
url.push_str("api/generate");
url
} else {
url.push_str("/api/generate");
url
}
}
Backend::OpenAi { mut url } => {
if url.ends_with("/v1/completions") {
url
} else if url.ends_with("/v1/") {
url.push_str("completions");
url
} else if url.ends_with("/v1") {
url.push_str("/completions");
url
} else if url.ends_with('/') {
url.push_str("v1/completions");
url
} else {
url.push_str("/v1/completions");
url
}
}
Backend::Tgi { mut url } => {
if url.ends_with("/generate") {
url
} else if url.ends_with('/') {
url.push_str("generate");
url
} else {
url.push_str("/generate");
url
}
}
}
}

Expand Down Expand Up @@ -466,8 +520,8 @@ impl LlmService {
backend = ?params.backend,
ide = %params.ide,
request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
"received completion request for {}",
params.text_document_position.text_document.uri
disable_url_path_completion = params.disable_url_path_completion,
"received completion request",
);
if params.api_token.is_none() && params.backend.is_using_inference_api() {
let now = Instant::now();
Expand Down
1 change: 1 addition & 0 deletions crates/testbed/repositories-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
tokenizer_config:
repository: codellama/CodeLlama-13b-hf
tokens_to_clear: ["<EOT>"]
disable_url_path_completion: false
repositories:
- source:
type: local
Expand Down
1 change: 1 addition & 0 deletions crates/testbed/repositories.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
tokenizer_config:
repository: bigcode/starcoder
tokens_to_clear: ["<|endoftext|>"]
disable_url_path_completion: false
repositories:
- source:
type: local
Expand Down
3 changes: 3 additions & 0 deletions crates/testbed/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ struct RepositoriesConfig {
tokenizer_config: Option<TokenizerConfig>,
tokens_to_clear: Vec<String>,
request_body: Map<String, Value>,
disable_url_path_completion: bool,
}

struct HoleCompletionResult {
Expand Down Expand Up @@ -490,6 +491,7 @@ async fn complete_holes(
tokenizer_config,
tokens_to_clear,
request_body,
disable_url_path_completion,
..
} = repos_config;
async move {
Expand Down Expand Up @@ -555,6 +557,7 @@ async fn complete_holes(
tokens_to_clear: tokens_to_clear.clone(),
tokenizer_config: tokenizer_config.clone(),
request_body: request_body.clone(),
disable_url_path_completion,
})
.await?;

Expand Down

0 comments on commit 0b5e2c6

Please sign in to comment.