feat: add backend url route completion (#95)

This commit is contained in:
Luc Georges 2024-05-24 13:15:59 +02:00 committed by Luc Georges
parent 0e95bb3589
commit 98a12630e7
No known key found for this signature in database
GPG key ID: 22924A120A2C2CE0
8 changed files with 81 additions and 20 deletions

View file

@@ -96,6 +96,16 @@ impl Backend {
_ => false,
}
}
/// Consumes the backend configuration and returns its configured base URL
/// verbatim, with no completion path appended.
///
/// Every `Backend` variant carries a `url` field; this accessor simply
/// extracts it. Path completion (e.g. appending `/v1/completions`) is
/// handled elsewhere — presumably in `build_url`; confirm against caller.
pub fn url(self) -> String {
match self {
Self::HuggingFace { url } => url,
Self::LlamaCpp { url } => url,
Self::Ollama { url } => url,
Self::OpenAi { url } => url,
Self::Tgi { url } => url,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
@@ -141,6 +151,8 @@ pub struct GetCompletionsParams {
pub tls_skip_verify_insecure: bool,
#[serde(default)]
pub request_body: Map<String, Value>,
#[serde(default)]
pub disable_url_path_completion: bool,
}
#[derive(Clone, Debug, Deserialize, Serialize)]

View file

@@ -26,7 +26,7 @@ impl Display for APIError {
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum APIResponse {
pub(crate) enum APIResponse {
Generation(Generation),
Generations(Vec<Generation>),
Error(APIError),

View file

@@ -129,7 +129,7 @@ fn get_parser(language_id: LanguageId) -> Result<Parser> {
#[derive(Clone, Debug, Copy)]
/// We redeclare this enum here because the `lsp_types` crate exports a Cow
/// type that is inconvenient to deal with.
pub enum PositionEncodingKind {
pub(crate) enum PositionEncodingKind {
Utf8,
Utf16,
Utf32,
@@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
}
impl PositionEncodingKind {
pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
pub(crate) fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
match self {
PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,

View file

@@ -62,16 +62,6 @@ impl fmt::Display for LanguageId {
}
}
/// Error carrying a language identifier string that was rejected;
/// its `Display` impl renders it as "Invalid language id: {id}".
pub(crate) struct LanguageIdError {
// The rejected language identifier, echoed back in the error message.
language_id: String,
}
// Human-readable rendering: "Invalid language id: <language_id>".
impl fmt::Display for LanguageIdError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Invalid language id: {}", self.language_id)
}
}
impl From<&str> for LanguageId {
fn from(value: &str) -> Self {
match value {

View file

@@ -246,8 +246,15 @@ async fn request_completion(
params.request_body.clone(),
);
let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
let url = build_url(
params.backend.clone(),
&params.model,
params.disable_url_path_completion,
);
info!(?headers, url, "sending request to backend");
debug!(?headers, body = ?json, url, "sending request to backend");
let res = http_client
.post(build_url(params.backend.clone(), &params.model))
.post(url)
.json(&json)
.headers(headers)
.send()
@@ -414,7 +421,12 @@ async fn get_tokenizer(
}
}
fn build_url(backend: Backend, model: &str) -> String {
// TODO: add configuration parameter to disable path auto-complete?
fn build_url(backend: Backend, model: &str, disable_url_path_completion: bool) -> String {
if disable_url_path_completion {
return backend.url();
}
match backend {
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
Backend::LlamaCpp { mut url } => {
@@ -428,9 +440,51 @@ fn build_url(backend: Backend, model: &str) -> String {
url
}
}
Backend::Ollama { url } => url,
Backend::OpenAi { url } => url,
Backend::Tgi { url } => url,
Backend::Ollama { mut url } => {
if url.ends_with("/api/generate") {
url
} else if url.ends_with("/api/") {
url.push_str("generate");
url
} else if url.ends_with("/api") {
url.push_str("/generate");
url
} else if url.ends_with('/') {
url.push_str("api/generate");
url
} else {
url.push_str("/api/generate");
url
}
}
Backend::OpenAi { mut url } => {
if url.ends_with("/v1/completions") {
url
} else if url.ends_with("/v1/") {
url.push_str("completions");
url
} else if url.ends_with("/v1") {
url.push_str("/completions");
url
} else if url.ends_with('/') {
url.push_str("v1/completions");
url
} else {
url.push_str("/v1/completions");
url
}
}
Backend::Tgi { mut url } => {
if url.ends_with("/generate") {
url
} else if url.ends_with('/') {
url.push_str("generate");
url
} else {
url.push_str("/generate");
url
}
}
}
}
@@ -466,8 +520,8 @@ impl LlmService {
backend = ?params.backend,
ide = %params.ide,
request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
"received completion request for {}",
params.text_document_position.text_document.uri
disable_url_path_completion = params.disable_url_path_completion,
"received completion request",
);
if params.api_token.is_none() && params.backend.is_using_inference_api() {
let now = Instant::now();

View file

@@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
tokenizer_config:
repository: codellama/CodeLlama-13b-hf
tokens_to_clear: ["<EOT>"]
disable_url_path_completion: false
repositories:
- source:
type: local

View file

@@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
tokenizer_config:
repository: bigcode/starcoder
tokens_to_clear: ["<|endoftext|>"]
disable_url_path_completion: false
repositories:
- source:
type: local

View file

@@ -209,6 +209,7 @@ struct RepositoriesConfig {
tokenizer_config: Option<TokenizerConfig>,
tokens_to_clear: Vec<String>,
request_body: Map<String, Value>,
disable_url_path_completion: bool,
}
struct HoleCompletionResult {
@@ -490,6 +491,7 @@ async fn complete_holes(
tokenizer_config,
tokens_to_clear,
request_body,
disable_url_path_completion,
..
} = repos_config;
async move {
@@ -555,6 +557,7 @@ async fn complete_holes(
tokens_to_clear: tokens_to_clear.clone(),
tokenizer_config: tokenizer_config.clone(),
request_body: request_body.clone(),
disable_url_path_completion,
})
.await?;