feat: update backend & model parameter (#74)

* feat: update backend & model parameter

* fix: add `stream: false` in request body for ollama & openai
Luc Georges authored on 2024-02-09 18:42:41 +01:00 (committed by GitHub)
parent 92fc885503
commit 4891468c1a
4 changed files with 82 additions and 43 deletions


@@ -19,12 +19,16 @@ It also makes sure that you are within the context window of the model by tokeni
 Gathers information about requests and completions that can enable retraining.
-Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file if you set the log level to `info`.
+Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file (`~/.cache/llm_ls/llm-ls.log`) if you set the log level to `info`.
 ### Completion
 **llm-ls** parses the AST of the code to determine if completions should be multi line, single line or empty (no completion).
+### Multiple backends
+**llm-ls** is compatible with Hugging Face's [Inference API](https://huggingface.co/docs/api-inference/en/index), Hugging Face's [text-generation-inference](https://github.com/huggingface/text-generation-inference), [ollama](https://github.com/ollama/ollama) and OpenAI compatible APIs, like [llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server).
 ## Compatible extensions
 - [x] [llm.nvim](https://github.com/huggingface/llm.nvim)
@@ -38,6 +42,4 @@ Note that **llm-ls** does not export any data anywhere (other than setting a use
 - add `suffix_percent` setting that determines the ratio of # of tokens for the prefix vs the suffix in the prompt
 - add context window fill percent or change context_window to `max_tokens`
 - filter bad suggestions (repetitive, same as below, etc)
-- support for ollama
-- support for llama.cpp
 - oltp traces ?


@@ -5,6 +5,8 @@ use serde::{Deserialize, Deserializer, Serialize};
 use serde_json::{Map, Value};
 use uuid::Uuid;
 
+const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
+
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct AcceptCompletionParams {
@@ -47,19 +49,42 @@ where
     Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 
-#[derive(Clone, Debug, Default, Deserialize, Serialize)]
-#[serde(rename_all = "lowercase")]
-pub enum Backend {
-    #[default]
-    HuggingFace,
-    Ollama,
-    OpenAi,
-    Tgi,
+fn hf_default_url() -> String {
+    format!("https://{HF_INFERENCE_API_HOSTNAME}")
 }
 
-impl Display for Backend {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.serialize(f)
+#[derive(Clone, Debug, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase", tag = "backend")]
+pub enum Backend {
+    HuggingFace {
+        #[serde(default = "hf_default_url")]
+        url: String,
+    },
+    Ollama {
+        url: String,
+    },
+    OpenAi {
+        url: String,
+    },
+    Tgi {
+        url: String,
+    },
+}
+
+impl Default for Backend {
+    fn default() -> Self {
+        Self::HuggingFace {
+            url: hf_default_url(),
+        }
+    }
+}
+
+impl Backend {
+    pub fn is_using_inference_api(&self) -> bool {
+        match self {
+            Self::HuggingFace { url } => url.contains(HF_INFERENCE_API_HOSTNAME),
+            _ => false,
+        }
     }
 }
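With `tag = "backend"`, the variant is now picked by a `"backend"` field in the incoming JSON rather than inferred from the value's shape, and the Hugging Face variant falls back to the Inference API URL when `url` is omitted. A minimal sketch of that behaviour, not part of the diff, assuming the `custom_types` crate above is available as a dependency (the ollama URL is just a placeholder):

use custom_types::llm_ls::Backend;
use serde_json::json;

fn main() -> serde_json::Result<()> {
    // Variant selected by the "backend" tag; `url` is required for ollama/openai/tgi.
    let ollama: Backend = serde_json::from_value(json!({
        "backend": "ollama",
        "url": "http://localhost:11434/api/generate"
    }))?;

    // For the Hugging Face variant, `url` falls back to `hf_default_url()`
    // thanks to `#[serde(default = "hf_default_url")]`.
    let hf: Backend = serde_json::from_value(json!({ "backend": "huggingface" }))?;

    println!("{ollama:?}");
    println!("{hf:?}");
    assert!(hf.is_using_inference_api());
    Ok(())
}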
@@ -98,11 +123,13 @@ pub struct GetCompletionsParams {
     pub fim: FimParams,
     pub api_token: Option<String>,
     pub model: String,
+    #[serde(flatten)]
     pub backend: Backend,
     pub tokens_to_clear: Vec<String>,
     pub tokenizer_config: Option<TokenizerConfig>,
     pub context_window: usize,
     pub tls_skip_verify_insecure: bool,
+    #[serde(default)]
     pub request_body: Map<String, Value>,
 }
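`#[serde(flatten)]` lifts the backend tag and its `url` to the top level of the completion params, and `#[serde(default)]` lets clients omit `request_body` entirely. A self-contained sketch of that serde pattern using trimmed-down stand-ins rather than the real types above (field values are placeholders):

use serde::Deserialize;
use serde_json::{json, Map, Value};

// Illustrative stand-ins: only the fields relevant to the two new attributes.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "lowercase", tag = "backend")]
enum Backend {
    Ollama { url: String },
}

#[derive(Debug, Deserialize)]
struct Params {
    model: String,
    #[serde(flatten)]
    backend: Backend,
    #[serde(default)]
    request_body: Map<String, Value>,
}

fn main() -> serde_json::Result<()> {
    // "backend" and "url" sit next to "model"; "request_body" may be omitted.
    let params: Params = serde_json::from_value(json!({
        "model": "codellama:7b",
        "backend": "ollama",
        "url": "http://localhost:11434/api/generate"
    }))?;
    println!("{params:?}");
    Ok(())
}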


@@ -1,5 +1,5 @@
 use super::{APIError, APIResponse, Generation, NAME, VERSION};
-use custom_types::llm_ls::{Backend, GetCompletionsParams, Ide};
+use custom_types::llm_ls::{Backend, Ide};
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
@@ -151,33 +151,39 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
     }
 }
 
-pub fn build_body(prompt: String, params: &GetCompletionsParams) -> Map<String, Value> {
-    let mut body = params.request_body.clone();
-    match params.backend {
-        Backend::HuggingFace | Backend::Tgi => {
-            body.insert("inputs".to_string(), Value::String(prompt))
+pub fn build_body(
+    backend: &Backend,
+    model: String,
+    prompt: String,
+    mut request_body: Map<String, Value>,
+) -> Map<String, Value> {
+    match backend {
+        Backend::HuggingFace { .. } | Backend::Tgi { .. } => {
+            request_body.insert("inputs".to_owned(), Value::String(prompt));
         }
-        Backend::Ollama | Backend::OpenAi => {
-            body.insert("prompt".to_string(), Value::String(prompt))
+        Backend::Ollama { .. } | Backend::OpenAi { .. } => {
+            request_body.insert("prompt".to_owned(), Value::String(prompt));
+            request_body.insert("model".to_owned(), Value::String(model));
+            request_body.insert("stream".to_owned(), Value::Bool(false));
         }
     };
-    body
+    request_body
 }
 
 pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     match backend {
-        Backend::HuggingFace => build_api_headers(api_token, ide),
-        Backend::Ollama => Ok(build_ollama_headers()),
-        Backend::OpenAi => build_openai_headers(api_token, ide),
-        Backend::Tgi => build_tgi_headers(api_token, ide),
+        Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
+        Backend::Ollama { .. } => Ok(build_ollama_headers()),
+        Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
+        Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
     }
 }
 
 pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
-        Backend::HuggingFace => parse_api_text(text),
-        Backend::Ollama => parse_ollama_text(text),
-        Backend::OpenAi => parse_openai_text(text),
-        Backend::Tgi => parse_tgi_text(text),
+        Backend::HuggingFace { .. } => parse_api_text(text),
+        Backend::Ollama { .. } => parse_ollama_text(text),
+        Backend::OpenAi { .. } => parse_openai_text(text),
+        Backend::Tgi { .. } => parse_tgi_text(text),
     }
 }
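For the ollama and OpenAI paths the adapter now injects `model` and `stream: false` alongside the prompt, while keeping whatever the client already put in `request_body`. A hedged sketch of that behaviour, written as a unit test that could sit next to `build_body` (model name, prompt and temperature are placeholder values):

#[cfg(test)]
mod build_body_tests {
    use super::build_body;
    use custom_types::llm_ls::Backend;
    use serde_json::{json, Map, Value};

    #[test]
    fn ollama_body_has_model_and_stream_false() {
        // Client-provided extras are kept as-is...
        let mut request_body = Map::new();
        request_body.insert("temperature".to_owned(), json!(0.2));

        let backend = Backend::Ollama {
            url: "http://localhost:11434/api/generate".to_owned(),
        };
        let body = build_body(
            &backend,
            "codellama:7b".to_owned(),
            "def fibonacci(".to_owned(),
            request_body,
        );

        // ...and the adapter adds prompt, model and `stream: false` so the
        // backend returns a single JSON response instead of a stream.
        assert_eq!(body.get("prompt"), Some(&Value::String("def fibonacci(".to_owned())));
        assert_eq!(body.get("model"), Some(&Value::String("codellama:7b".to_owned())));
        assert_eq!(body.get("stream"), Some(&Value::Bool(false)));
        assert_eq!(body.get("temperature"), Some(&json!(0.2)));
    }
}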


@@ -34,7 +34,6 @@ mod language_id;
 const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
 pub const NAME: &str = "llm-ls";
 pub const VERSION: &str = env!("CARGO_PKG_VERSION");
-const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
 
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
     Ok(rope.try_line_to_char(row)?
@@ -305,10 +304,15 @@ async fn request_completion(
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
 
-    let json = build_body(prompt, params);
+    let json = build_body(
+        &params.backend,
+        params.model.clone(),
+        prompt,
+        params.request_body.clone(),
+    );
     let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
     let res = http_client
-        .post(build_url(&params.model))
+        .post(build_url(params.backend.clone(), &params.model))
         .json(&json)
         .headers(headers)
         .send()
@@ -367,7 +371,7 @@ async fn download_tokenizer_file(
         return Ok(());
     }
     tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
-    let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
+    let headers = build_headers(&Backend::default(), api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
@@ -475,11 +479,12 @@ async fn get_tokenizer(
     }
 }
 
-fn build_url(model: &str) -> String {
-    if model.starts_with("http://") || model.starts_with("https://") {
-        model.to_owned()
-    } else {
-        format!("https://{HF_INFERENCE_API_HOSTNAME}/models/{model}")
+fn build_url(backend: Backend, model: &str) -> String {
+    match backend {
+        Backend::HuggingFace { url } => format!("{url}/models/{model}"),
+        Backend::Ollama { url } => url,
+        Backend::OpenAi { url } => url,
+        Backend::Tgi { url } => url,
     }
 }
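The model string no longer doubles as a URL: the Hugging Face variant appends `/models/<model>` to its base URL, while ollama, OpenAI and TGI use the configured URL verbatim. A small illustrative test that could live next to `build_url` (URLs and model names are placeholders):

#[cfg(test)]
mod build_url_tests {
    use super::build_url;
    use custom_types::llm_ls::Backend;

    #[test]
    fn huggingface_appends_model_path_others_pass_through() {
        let hf = Backend::HuggingFace {
            url: "https://api-inference.huggingface.co".to_owned(),
        };
        assert_eq!(
            build_url(hf, "bigcode/starcoder"),
            "https://api-inference.huggingface.co/models/bigcode/starcoder"
        );

        // ollama, openai and tgi return the configured url unchanged.
        let ollama = Backend::Ollama {
            url: "http://localhost:11434/api/generate".to_owned(),
        };
        assert_eq!(build_url(ollama, "codellama:7b"), "http://localhost:11434/api/generate");
    }
}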
@@ -502,14 +507,13 @@ impl LlmService {
             cursor_character = ?params.text_document_position.position.character,
             language_id = %document.language_id,
             model = params.model,
-            backend = %params.backend,
+            backend = ?params.backend,
             ide = %params.ide,
             request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
             "received completion request for {}",
             params.text_document_position.text_document.uri
         );
-        let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
-        if params.api_token.is_none() && is_using_inference_api {
+        if params.api_token.is_none() && params.backend.is_using_inference_api() {
             let now = Instant::now();
             let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
             if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {