feat: add backend url route completion (#95)
This commit is contained in:
parent
0e95bb3589
commit
98a12630e7
|
@ -96,6 +96,16 @@ impl Backend {
|
|||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn url(self) -> String {
|
||||
match self {
|
||||
Self::HuggingFace { url } => url,
|
||||
Self::LlamaCpp { url } => url,
|
||||
Self::Ollama { url } => url,
|
||||
Self::OpenAi { url } => url,
|
||||
Self::Tgi { url } => url,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
|
@ -141,6 +151,8 @@ pub struct GetCompletionsParams {
|
|||
pub tls_skip_verify_insecure: bool,
|
||||
#[serde(default)]
|
||||
pub request_body: Map<String, Value>,
|
||||
#[serde(default)]
|
||||
pub disable_url_path_completion: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
|
|
|
@ -26,7 +26,7 @@ impl Display for APIError {
|
|||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum APIResponse {
|
||||
pub(crate) enum APIResponse {
|
||||
Generation(Generation),
|
||||
Generations(Vec<Generation>),
|
||||
Error(APIError),
|
||||
|
|
|
@ -129,7 +129,7 @@ fn get_parser(language_id: LanguageId) -> Result<Parser> {
|
|||
#[derive(Clone, Debug, Copy)]
|
||||
/// We redeclare this enum here because the `lsp_types` crate exports a Cow
|
||||
/// type that is inconvenient to deal with.
|
||||
pub enum PositionEncodingKind {
|
||||
pub(crate) enum PositionEncodingKind {
|
||||
Utf8,
|
||||
Utf16,
|
||||
Utf32,
|
||||
|
@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
|
|||
}
|
||||
|
||||
impl PositionEncodingKind {
|
||||
pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
|
||||
pub(crate) fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
|
||||
match self {
|
||||
PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
|
||||
PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
|
||||
|
|
|
@ -62,16 +62,6 @@ impl fmt::Display for LanguageId {
|
|||
}
|
||||
}
|
||||
|
||||
/// Error value wrapping a language id string.
///
/// Its `Display` output is `Invalid language id: <language_id>`.
pub(crate) struct LanguageIdError {
|
||||
// The language id text that the `Display` impl prints in its message.
language_id: String,
|
||||
}
|
||||
|
||||
impl fmt::Display for LanguageIdError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "Invalid language id: {}", self.language_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for LanguageId {
|
||||
fn from(value: &str) -> Self {
|
||||
match value {
|
||||
|
|
|
@ -246,8 +246,15 @@ async fn request_completion(
|
|||
params.request_body.clone(),
|
||||
);
|
||||
let headers = build_headers(¶ms.backend, params.api_token.as_ref(), params.ide)?;
|
||||
let url = build_url(
|
||||
params.backend.clone(),
|
||||
¶ms.model,
|
||||
params.disable_url_path_completion,
|
||||
);
|
||||
info!(?headers, url, "sending request to backend");
|
||||
debug!(?headers, body = ?json, url, "sending request to backend");
|
||||
let res = http_client
|
||||
.post(build_url(params.backend.clone(), ¶ms.model))
|
||||
.post(url)
|
||||
.json(&json)
|
||||
.headers(headers)
|
||||
.send()
|
||||
|
@ -414,7 +421,12 @@ async fn get_tokenizer(
|
|||
}
|
||||
}
|
||||
|
||||
fn build_url(backend: Backend, model: &str) -> String {
|
||||
// TODO: add configuration parameter to disable path auto-complete?
|
||||
fn build_url(backend: Backend, model: &str, disable_url_path_completion: bool) -> String {
|
||||
if disable_url_path_completion {
|
||||
return backend.url();
|
||||
}
|
||||
|
||||
match backend {
|
||||
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
|
||||
Backend::LlamaCpp { mut url } => {
|
||||
|
@ -428,9 +440,51 @@ fn build_url(backend: Backend, model: &str) -> String {
|
|||
url
|
||||
}
|
||||
}
|
||||
Backend::Ollama { url } => url,
|
||||
Backend::OpenAi { url } => url,
|
||||
Backend::Tgi { url } => url,
|
||||
Backend::Ollama { mut url } => {
|
||||
if url.ends_with("/api/generate") {
|
||||
url
|
||||
} else if url.ends_with("/api/") {
|
||||
url.push_str("generate");
|
||||
url
|
||||
} else if url.ends_with("/api") {
|
||||
url.push_str("/generate");
|
||||
url
|
||||
} else if url.ends_with('/') {
|
||||
url.push_str("api/generate");
|
||||
url
|
||||
} else {
|
||||
url.push_str("/api/generate");
|
||||
url
|
||||
}
|
||||
}
|
||||
Backend::OpenAi { mut url } => {
|
||||
if url.ends_with("/v1/completions") {
|
||||
url
|
||||
} else if url.ends_with("/v1/") {
|
||||
url.push_str("completions");
|
||||
url
|
||||
} else if url.ends_with("/v1") {
|
||||
url.push_str("/completions");
|
||||
url
|
||||
} else if url.ends_with('/') {
|
||||
url.push_str("v1/completions");
|
||||
url
|
||||
} else {
|
||||
url.push_str("/v1/completions");
|
||||
url
|
||||
}
|
||||
}
|
||||
Backend::Tgi { mut url } => {
|
||||
if url.ends_with("/generate") {
|
||||
url
|
||||
} else if url.ends_with('/') {
|
||||
url.push_str("generate");
|
||||
url
|
||||
} else {
|
||||
url.push_str("/generate");
|
||||
url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -466,8 +520,8 @@ impl LlmService {
|
|||
backend = ?params.backend,
|
||||
ide = %params.ide,
|
||||
request_body = serde_json::to_string(¶ms.request_body).map_err(internal_error)?,
|
||||
"received completion request for {}",
|
||||
params.text_document_position.text_document.uri
|
||||
disable_url_path_completion = params.disable_url_path_completion,
|
||||
"received completion request",
|
||||
);
|
||||
if params.api_token.is_none() && params.backend.is_using_inference_api() {
|
||||
let now = Instant::now();
|
||||
|
|
|
@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
|
|||
tokenizer_config:
|
||||
repository: codellama/CodeLlama-13b-hf
|
||||
tokens_to_clear: ["<EOT>"]
|
||||
disable_url_path_completion: false
|
||||
repositories:
|
||||
- source:
|
||||
type: local
|
||||
|
|
|
@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
|
|||
tokenizer_config:
|
||||
repository: bigcode/starcoder
|
||||
tokens_to_clear: ["<|endoftext|>"]
|
||||
disable_url_path_completion: false
|
||||
repositories:
|
||||
- source:
|
||||
type: local
|
||||
|
|
|
@ -209,6 +209,7 @@ struct RepositoriesConfig {
|
|||
tokenizer_config: Option<TokenizerConfig>,
|
||||
tokens_to_clear: Vec<String>,
|
||||
request_body: Map<String, Value>,
|
||||
disable_url_path_completion: bool,
|
||||
}
|
||||
|
||||
struct HoleCompletionResult {
|
||||
|
@ -490,6 +491,7 @@ async fn complete_holes(
|
|||
tokenizer_config,
|
||||
tokens_to_clear,
|
||||
request_body,
|
||||
disable_url_path_completion,
|
||||
..
|
||||
} = repos_config;
|
||||
async move {
|
||||
|
@ -555,6 +557,7 @@ async fn complete_holes(
|
|||
tokens_to_clear: tokens_to_clear.clone(),
|
||||
tokenizer_config: tokenizer_config.clone(),
|
||||
request_body: request_body.clone(),
|
||||
disable_url_path_completion,
|
||||
})
|
||||
.await?;
|
||||
|
||||
|
|
Loading…
Reference in a new issue