diff --git a/crates/custom-types/src/llm_ls.rs b/crates/custom-types/src/llm_ls.rs index cd66d5c..debf465 100644 --- a/crates/custom-types/src/llm_ls.rs +++ b/crates/custom-types/src/llm_ls.rs @@ -96,6 +96,16 @@ impl Backend { _ => false, } } + + pub fn url(self) -> String { + match self { + Self::HuggingFace { url } => url, + Self::LlamaCpp { url } => url, + Self::Ollama { url } => url, + Self::OpenAi { url } => url, + Self::Tgi { url } => url, + } + } } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -141,6 +151,8 @@ pub struct GetCompletionsParams { pub tls_skip_verify_insecure: bool, #[serde(default)] pub request_body: Map, + #[serde(default)] + pub disable_url_path_completion: bool, } #[derive(Clone, Debug, Deserialize, Serialize)] diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs index 90324c5..d1467b5 100644 --- a/crates/llm-ls/src/backend.rs +++ b/crates/llm-ls/src/backend.rs @@ -26,7 +26,7 @@ impl Display for APIError { #[derive(Debug, Deserialize)] #[serde(untagged)] -pub enum APIResponse { +pub(crate) enum APIResponse { Generation(Generation), Generations(Vec), Error(APIError), diff --git a/crates/llm-ls/src/document.rs b/crates/llm-ls/src/document.rs index 83c1072..fc7fc81 100644 --- a/crates/llm-ls/src/document.rs +++ b/crates/llm-ls/src/document.rs @@ -129,7 +129,7 @@ fn get_parser(language_id: LanguageId) -> Result { #[derive(Clone, Debug, Copy)] /// We redeclare this enum here because the `lsp_types` crate exports a Cow /// type that is unconvenient to deal with. 
-pub enum PositionEncodingKind { +pub(crate) enum PositionEncodingKind { Utf8, Utf16, Utf32, diff --git a/crates/llm-ls/src/language_id.rs b/crates/llm-ls/src/language_id.rs index 31ce3d9..26c6362 100644 --- a/crates/llm-ls/src/language_id.rs +++ b/crates/llm-ls/src/language_id.rs @@ -62,16 +62,6 @@ impl fmt::Display for LanguageId { } } -pub(crate) struct LanguageIdError { - language_id: String, -} - -impl fmt::Display for LanguageIdError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Invalid language id: {}", self.language_id) - } -} - impl From<&str> for LanguageId { fn from(value: &str) -> Self { match value { diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index ff74be3..3088268 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -246,8 +246,15 @@ async fn request_completion( params.request_body.clone(), ); let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?; + let url = build_url( + params.backend.clone(), + &params.model, + params.disable_url_path_completion, + ); + info!(?headers, url, "sending request to backend"); + debug!(?headers, body = ?json, url, "sending request to backend"); let res = http_client - .post(build_url(params.backend.clone(), &params.model)) + .post(url) .json(&json) .headers(headers) .send() @@ -414,7 +421,12 @@ async fn get_tokenizer( } } -fn build_url(backend: Backend, model: &str) -> String { +// TODO: add configuration parameter to disable path auto-complete? 
+fn build_url(backend: Backend, model: &str, disable_url_path_completion: bool) -> String { + if disable_url_path_completion { + return backend.url(); + } + match backend { Backend::HuggingFace { url } => format!("{url}/models/{model}"), Backend::LlamaCpp { mut url } => { @@ -428,9 +440,51 @@ fn build_url(backend: Backend, model: &str) -> String { url } } - Backend::Ollama { url } => url, - Backend::OpenAi { url } => url, - Backend::Tgi { url } => url, + Backend::Ollama { mut url } => { + if url.ends_with("/api/generate") { + url + } else if url.ends_with("/api/") { + url.push_str("generate"); + url + } else if url.ends_with("/api") { + url.push_str("/generate"); + url + } else if url.ends_with('/') { + url.push_str("api/generate"); + url + } else { + url.push_str("/api/generate"); + url + } + } + Backend::OpenAi { mut url } => { + if url.ends_with("/v1/completions") { + url + } else if url.ends_with("/v1/") { + url.push_str("completions"); + url + } else if url.ends_with("/v1") { + url.push_str("/completions"); + url + } else if url.ends_with('/') { + url.push_str("v1/completions"); + url + } else { + url.push_str("/v1/completions"); + url + } + } + Backend::Tgi { mut url } => { + if url.ends_with("/generate") { + url + } else if url.ends_with('/') { + url.push_str("generate"); + url + } else { + url.push_str("/generate"); + url + } + } } } @@ -466,8 +520,8 @@ impl LlmService { backend = ?params.backend, ide = %params.ide, request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?, - "received completion request for {}", - params.text_document_position.text_document.uri + disable_url_path_completion = params.disable_url_path_completion, + "received completion request", ); if params.api_token.is_none() && params.backend.is_using_inference_api() { let now = Instant::now(); diff --git a/crates/testbed/repositories-ci.yaml b/crates/testbed/repositories-ci.yaml index 1bb7d4f..5bb3d1f 100644 --- a/crates/testbed/repositories-ci.yaml +++ 
b/crates/testbed/repositories-ci.yaml @@ -16,6 +16,7 @@ tls_skip_verify_insecure: false tokenizer_config: repository: codellama/CodeLlama-13b-hf tokens_to_clear: [""] +disable_url_path_completion: false repositories: - source: type: local diff --git a/crates/testbed/repositories.yaml b/crates/testbed/repositories.yaml index 418ac3e..c565829 100644 --- a/crates/testbed/repositories.yaml +++ b/crates/testbed/repositories.yaml @@ -16,6 +16,7 @@ tls_skip_verify_insecure: false tokenizer_config: repository: bigcode/starcoder tokens_to_clear: ["<|endoftext|>"] +disable_url_path_completion: false repositories: - source: type: local diff --git a/crates/testbed/src/main.rs b/crates/testbed/src/main.rs index b1c8a1c..0eb14f9 100644 --- a/crates/testbed/src/main.rs +++ b/crates/testbed/src/main.rs @@ -209,6 +209,7 @@ struct RepositoriesConfig { tokenizer_config: Option, tokens_to_clear: Vec, request_body: Map, + disable_url_path_completion: bool, } struct HoleCompletionResult { @@ -490,6 +491,7 @@ async fn complete_holes( tokenizer_config, tokens_to_clear, request_body, + disable_url_path_completion, .. } = repos_config; async move { @@ -555,6 +557,7 @@ async fn complete_holes( tokens_to_clear: tokens_to_clear.clone(), tokenizer_config: tokenizer_config.clone(), request_body: request_body.clone(), + disable_url_path_completion, }) .await?;