feat: add backend url route completion (#95)
This commit is contained in:
parent
0e95bb3589
commit
98a12630e7
|
@ -96,6 +96,16 @@ impl Backend {
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn url(self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::HuggingFace { url } => url,
|
||||||
|
Self::LlamaCpp { url } => url,
|
||||||
|
Self::Ollama { url } => url,
|
||||||
|
Self::OpenAi { url } => url,
|
||||||
|
Self::Tgi { url } => url,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
@ -141,6 +151,8 @@ pub struct GetCompletionsParams {
|
||||||
pub tls_skip_verify_insecure: bool,
|
pub tls_skip_verify_insecure: bool,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub request_body: Map<String, Value>,
|
pub request_body: Map<String, Value>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub disable_url_path_completion: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
|
|
@ -26,7 +26,7 @@ impl Display for APIError {
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum APIResponse {
|
pub(crate) enum APIResponse {
|
||||||
Generation(Generation),
|
Generation(Generation),
|
||||||
Generations(Vec<Generation>),
|
Generations(Vec<Generation>),
|
||||||
Error(APIError),
|
Error(APIError),
|
||||||
|
|
|
@ -129,7 +129,7 @@ fn get_parser(language_id: LanguageId) -> Result<Parser> {
|
||||||
#[derive(Clone, Debug, Copy)]
|
#[derive(Clone, Debug, Copy)]
|
||||||
/// We redeclare this enum here because the `lsp_types` crate exports a Cow
|
/// We redeclare this enum here because the `lsp_types` crate exports a Cow
|
||||||
/// type that is unconvenient to deal with.
|
/// type that is unconvenient to deal with.
|
||||||
pub enum PositionEncodingKind {
|
pub(crate) enum PositionEncodingKind {
|
||||||
Utf8,
|
Utf8,
|
||||||
Utf16,
|
Utf16,
|
||||||
Utf32,
|
Utf32,
|
||||||
|
@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PositionEncodingKind {
|
impl PositionEncodingKind {
|
||||||
pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
|
pub(crate) fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
|
||||||
match self {
|
match self {
|
||||||
PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
|
PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
|
||||||
PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
|
PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
|
||||||
|
|
|
@ -62,16 +62,6 @@ impl fmt::Display for LanguageId {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct LanguageIdError {
|
|
||||||
language_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for LanguageIdError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
||||||
write!(f, "Invalid language id: {}", self.language_id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<&str> for LanguageId {
|
impl From<&str> for LanguageId {
|
||||||
fn from(value: &str) -> Self {
|
fn from(value: &str) -> Self {
|
||||||
match value {
|
match value {
|
||||||
|
|
|
@ -246,8 +246,15 @@ async fn request_completion(
|
||||||
params.request_body.clone(),
|
params.request_body.clone(),
|
||||||
);
|
);
|
||||||
let headers = build_headers(¶ms.backend, params.api_token.as_ref(), params.ide)?;
|
let headers = build_headers(¶ms.backend, params.api_token.as_ref(), params.ide)?;
|
||||||
|
let url = build_url(
|
||||||
|
params.backend.clone(),
|
||||||
|
¶ms.model,
|
||||||
|
params.disable_url_path_completion,
|
||||||
|
);
|
||||||
|
info!(?headers, url, "sending request to backend");
|
||||||
|
debug!(?headers, body = ?json, url, "sending request to backend");
|
||||||
let res = http_client
|
let res = http_client
|
||||||
.post(build_url(params.backend.clone(), ¶ms.model))
|
.post(url)
|
||||||
.json(&json)
|
.json(&json)
|
||||||
.headers(headers)
|
.headers(headers)
|
||||||
.send()
|
.send()
|
||||||
|
@ -414,7 +421,12 @@ async fn get_tokenizer(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_url(backend: Backend, model: &str) -> String {
|
// TODO: add configuration parameter to disable path auto-complete?
|
||||||
|
fn build_url(backend: Backend, model: &str, disable_url_path_completion: bool) -> String {
|
||||||
|
if disable_url_path_completion {
|
||||||
|
return backend.url();
|
||||||
|
}
|
||||||
|
|
||||||
match backend {
|
match backend {
|
||||||
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
|
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
|
||||||
Backend::LlamaCpp { mut url } => {
|
Backend::LlamaCpp { mut url } => {
|
||||||
|
@ -428,9 +440,51 @@ fn build_url(backend: Backend, model: &str) -> String {
|
||||||
url
|
url
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Backend::Ollama { url } => url,
|
Backend::Ollama { mut url } => {
|
||||||
Backend::OpenAi { url } => url,
|
if url.ends_with("/api/generate") {
|
||||||
Backend::Tgi { url } => url,
|
url
|
||||||
|
} else if url.ends_with("/api/") {
|
||||||
|
url.push_str("generate");
|
||||||
|
url
|
||||||
|
} else if url.ends_with("/api") {
|
||||||
|
url.push_str("/generate");
|
||||||
|
url
|
||||||
|
} else if url.ends_with('/') {
|
||||||
|
url.push_str("api/generate");
|
||||||
|
url
|
||||||
|
} else {
|
||||||
|
url.push_str("/api/generate");
|
||||||
|
url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Backend::OpenAi { mut url } => {
|
||||||
|
if url.ends_with("/v1/completions") {
|
||||||
|
url
|
||||||
|
} else if url.ends_with("/v1/") {
|
||||||
|
url.push_str("completions");
|
||||||
|
url
|
||||||
|
} else if url.ends_with("/v1") {
|
||||||
|
url.push_str("/completions");
|
||||||
|
url
|
||||||
|
} else if url.ends_with('/') {
|
||||||
|
url.push_str("v1/completions");
|
||||||
|
url
|
||||||
|
} else {
|
||||||
|
url.push_str("/v1/completions");
|
||||||
|
url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Backend::Tgi { mut url } => {
|
||||||
|
if url.ends_with("/generate") {
|
||||||
|
url
|
||||||
|
} else if url.ends_with('/') {
|
||||||
|
url.push_str("generate");
|
||||||
|
url
|
||||||
|
} else {
|
||||||
|
url.push_str("/generate");
|
||||||
|
url
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -466,8 +520,8 @@ impl LlmService {
|
||||||
backend = ?params.backend,
|
backend = ?params.backend,
|
||||||
ide = %params.ide,
|
ide = %params.ide,
|
||||||
request_body = serde_json::to_string(¶ms.request_body).map_err(internal_error)?,
|
request_body = serde_json::to_string(¶ms.request_body).map_err(internal_error)?,
|
||||||
"received completion request for {}",
|
disable_url_path_completion = params.disable_url_path_completion,
|
||||||
params.text_document_position.text_document.uri
|
"received completion request",
|
||||||
);
|
);
|
||||||
if params.api_token.is_none() && params.backend.is_using_inference_api() {
|
if params.api_token.is_none() && params.backend.is_using_inference_api() {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
|
|
|
@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
|
||||||
tokenizer_config:
|
tokenizer_config:
|
||||||
repository: codellama/CodeLlama-13b-hf
|
repository: codellama/CodeLlama-13b-hf
|
||||||
tokens_to_clear: ["<EOT>"]
|
tokens_to_clear: ["<EOT>"]
|
||||||
|
disable_url_path_completion: false
|
||||||
repositories:
|
repositories:
|
||||||
- source:
|
- source:
|
||||||
type: local
|
type: local
|
||||||
|
|
|
@ -16,6 +16,7 @@ tls_skip_verify_insecure: false
|
||||||
tokenizer_config:
|
tokenizer_config:
|
||||||
repository: bigcode/starcoder
|
repository: bigcode/starcoder
|
||||||
tokens_to_clear: ["<|endoftext|>"]
|
tokens_to_clear: ["<|endoftext|>"]
|
||||||
|
disable_url_path_completion: false
|
||||||
repositories:
|
repositories:
|
||||||
- source:
|
- source:
|
||||||
type: local
|
type: local
|
||||||
|
|
|
@ -209,6 +209,7 @@ struct RepositoriesConfig {
|
||||||
tokenizer_config: Option<TokenizerConfig>,
|
tokenizer_config: Option<TokenizerConfig>,
|
||||||
tokens_to_clear: Vec<String>,
|
tokens_to_clear: Vec<String>,
|
||||||
request_body: Map<String, Value>,
|
request_body: Map<String, Value>,
|
||||||
|
disable_url_path_completion: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct HoleCompletionResult {
|
struct HoleCompletionResult {
|
||||||
|
@ -490,6 +491,7 @@ async fn complete_holes(
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
tokens_to_clear,
|
tokens_to_clear,
|
||||||
request_body,
|
request_body,
|
||||||
|
disable_url_path_completion,
|
||||||
..
|
..
|
||||||
} = repos_config;
|
} = repos_config;
|
||||||
async move {
|
async move {
|
||||||
|
@ -555,6 +557,7 @@ async fn complete_holes(
|
||||||
tokens_to_clear: tokens_to_clear.clone(),
|
tokens_to_clear: tokens_to_clear.clone(),
|
||||||
tokenizer_config: tokenizer_config.clone(),
|
tokenizer_config: tokenizer_config.clone(),
|
||||||
request_body: request_body.clone(),
|
request_body: request_body.clone(),
|
||||||
|
disable_url_path_completion,
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue