feat: update backend & model parameter (#74)
* feat: update backend & model parameter
* fix: add `stream: false` in request body for ollama & openai

parent 92fc885503
commit 4891468c1a
@@ -19,12 +19,16 @@ It also makes sure that you are within the context window of the model by tokenizing the prompt.
 
 Gathers information about requests and completions that can enable retraining.
 
-Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file if you set the log level to `info`.
+Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file (`~/.cache/llm_ls/llm-ls.log`) if you set the log level to `info`.
 
 ### Completion
 
 **llm-ls** parses the AST of the code to determine if completions should be multi line, single line or empty (no completion).
 
+### Multiple backends
+
+**llm-ls** is compatible with Hugging Face's [Inference API](https://huggingface.co/docs/api-inference/en/index), Hugging Face's [text-generation-inference](https://github.com/huggingface/text-generation-inference), [ollama](https://github.com/ollama/ollama) and OpenAI compatible APIs, like [llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server).
+
 ## Compatible extensions
 
 - [x] [llm.nvim](https://github.com/huggingface/llm.nvim)
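
With multiple backends, the backend choice travels with each completion request. A minimal sketch of the two request shapes under the flattened settings this commit introduces (endpoints and model names are illustrative; the ollama URL is its usual local default):

```rust
use serde_json::json;

fn main() {
    // Hosted default: Hugging Face Inference API ("url" may be omitted,
    // see hf_default_url in the custom-types hunk below).
    let hf = json!({ "backend": "huggingface", "model": "bigcode/starcoder" });

    // Self-hosted: an ollama server on its usual local endpoint.
    let ollama = json!({
        "backend": "ollama",
        "url": "http://localhost:11434/api/generate",
        "model": "codellama:7b"
    });

    println!("{hf}\n{ollama}");
}
```
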
@@ -38,6 +42,4 @@ Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file if you set the log level to `info`.
 - add `suffix_percent` setting that determines the ratio of # of tokens for the prefix vs the suffix in the prompt
 - add context window fill percent or change context_window to `max_tokens`
 - filter bad suggestions (repetitive, same as below, etc)
-- support for ollama
-- support for llama.cpp
 - oltp traces ?

@@ -5,6 +5,8 @@ use serde::{Deserialize, Deserializer, Serialize};
 use serde_json::{Map, Value};
 use uuid::Uuid;
 
+const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
+
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct AcceptCompletionParams {

@@ -47,19 +49,42 @@ where
     Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 
-#[derive(Clone, Debug, Default, Deserialize, Serialize)]
-#[serde(rename_all = "lowercase")]
-pub enum Backend {
-    #[default]
-    HuggingFace,
-    Ollama,
-    OpenAi,
-    Tgi,
+fn hf_default_url() -> String {
+    format!("https://{HF_INFERENCE_API_HOSTNAME}")
 }
 
-impl Display for Backend {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.serialize(f)
+#[derive(Clone, Debug, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase", tag = "backend")]
+pub enum Backend {
+    HuggingFace {
+        #[serde(default = "hf_default_url")]
+        url: String,
+    },
+    Ollama {
+        url: String,
+    },
+    OpenAi {
+        url: String,
+    },
+    Tgi {
+        url: String,
+    },
+}
+
+impl Default for Backend {
+    fn default() -> Self {
+        Self::HuggingFace {
+            url: hf_default_url(),
+        }
+    }
+}
+
+impl Backend {
+    pub fn is_using_inference_api(&self) -> bool {
+        match self {
+            Self::HuggingFace { url } => url.contains(HF_INFERENCE_API_HOSTNAME),
+            _ => false,
+        }
     }
 }
 
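Because the enum is now internally tagged with `tag = "backend"`, the backend name and its `url` sit side by side in the incoming JSON, and the Hugging Face variant falls back to the Inference API hostname when `url` is omitted. A self-contained sketch of that behavior (the ollama endpoint is illustrative):

```rust
use serde::{Deserialize, Serialize};

const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";

fn hf_default_url() -> String {
    format!("https://{HF_INFERENCE_API_HOSTNAME}")
}

#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "backend")]
enum Backend {
    HuggingFace {
        #[serde(default = "hf_default_url")]
        url: String,
    },
    Ollama { url: String },
    OpenAi { url: String },
    Tgi { url: String },
}

impl Backend {
    fn is_using_inference_api(&self) -> bool {
        match self {
            Self::HuggingFace { url } => url.contains(HF_INFERENCE_API_HOSTNAME),
            _ => false,
        }
    }
}

fn main() {
    // The "backend" tag selects the variant; "url" belongs to that variant.
    let ollama: Backend = serde_json::from_str(
        r#"{"backend":"ollama","url":"http://localhost:11434/api/generate"}"#,
    )
    .unwrap();
    assert!(!ollama.is_using_inference_api());

    // Omitting "url" for huggingface falls back to the Inference API default,
    // which is exactly what is_using_inference_api detects.
    let hf: Backend = serde_json::from_str(r#"{"backend":"huggingface"}"#).unwrap();
    assert!(hf.is_using_inference_api());
}
```

This also explains why the `Display` impl could be dropped: the struct variants are logged with `{:?}` instead (see the `backend = ?params.backend` change further down).
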
@@ -98,11 +123,13 @@ pub struct GetCompletionsParams {
     pub fim: FimParams,
     pub api_token: Option<String>,
     pub model: String,
+    #[serde(flatten)]
     pub backend: Backend,
     pub tokens_to_clear: Vec<String>,
     pub tokenizer_config: Option<TokenizerConfig>,
     pub context_window: usize,
     pub tls_skip_verify_insecure: bool,
     #[serde(default)]
     pub request_body: Map<String, Value>,
 }
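
With `#[serde(flatten)]`, the backend tag and `url` surface at the top level of `GetCompletionsParams` instead of under a nested object. A pared-down stand-in for the params struct, keeping only the fields this hunk touches (values are illustrative):

```rust
use serde::Deserialize;
use serde_json::{Map, Value};

#[derive(Debug, Deserialize)]
#[serde(rename_all = "lowercase", tag = "backend")]
enum Backend {
    HuggingFace { url: String },
    Ollama { url: String },
}

// Illustrative stand-in for GetCompletionsParams, not the full struct.
#[derive(Debug, Deserialize)]
struct Params {
    model: String,
    #[serde(flatten)]
    backend: Backend,
    #[serde(default)]
    request_body: Map<String, Value>,
}

fn main() {
    // "backend" and "url" sit next to "model", not under a nested key;
    // "request_body" defaults to an empty map when absent.
    let p: Params = serde_json::from_str(
        r#"{"model":"codellama:7b","backend":"ollama","url":"http://localhost:11434/api/generate"}"#,
    )
    .unwrap();
    println!("{p:?}");
}
```
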
@@ -1,5 +1,5 @@
 use super::{APIError, APIResponse, Generation, NAME, VERSION};
-use custom_types::llm_ls::{Backend, GetCompletionsParams, Ide};
+use custom_types::llm_ls::{Backend, Ide};
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};

@@ -151,33 +151,39 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
     }
 }
 
-pub fn build_body(prompt: String, params: &GetCompletionsParams) -> Map<String, Value> {
-    let mut body = params.request_body.clone();
-    match params.backend {
-        Backend::HuggingFace | Backend::Tgi => {
-            body.insert("inputs".to_string(), Value::String(prompt))
+pub fn build_body(
+    backend: &Backend,
+    model: String,
+    prompt: String,
+    mut request_body: Map<String, Value>,
+) -> Map<String, Value> {
+    match backend {
+        Backend::HuggingFace { .. } | Backend::Tgi { .. } => {
+            request_body.insert("inputs".to_owned(), Value::String(prompt));
         }
-        Backend::Ollama | Backend::OpenAi => {
-            body.insert("prompt".to_string(), Value::String(prompt))
+        Backend::Ollama { .. } | Backend::OpenAi { .. } => {
+            request_body.insert("prompt".to_owned(), Value::String(prompt));
+            request_body.insert("model".to_owned(), Value::String(model));
+            request_body.insert("stream".to_owned(), Value::Bool(false));
         }
     };
-    body
+    request_body
 }
 
 pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     match backend {
-        Backend::HuggingFace => build_api_headers(api_token, ide),
-        Backend::Ollama => Ok(build_ollama_headers()),
-        Backend::OpenAi => build_openai_headers(api_token, ide),
-        Backend::Tgi => build_tgi_headers(api_token, ide),
+        Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
+        Backend::Ollama { .. } => Ok(build_ollama_headers()),
+        Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
+        Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
     }
 }
 
 pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
-        Backend::HuggingFace => parse_api_text(text),
-        Backend::Ollama => parse_ollama_text(text),
-        Backend::OpenAi => parse_openai_text(text),
-        Backend::Tgi => parse_tgi_text(text),
+        Backend::HuggingFace { .. } => parse_api_text(text),
+        Backend::Ollama { .. } => parse_ollama_text(text),
+        Backend::OpenAi { .. } => parse_openai_text(text),
+        Backend::Tgi { .. } => parse_tgi_text(text),
     }
 }
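
For ollama and OpenAI-compatible servers, the body now carries the model name and `stream: false`, so the server returns one JSON response instead of a stream of chunks. A sketch of the payload the `Ollama`/`OpenAi` arm produces (the `options` entry stands in for arbitrary user-supplied `request_body` content; prompt and model values are illustrative):

```rust
use serde_json::{json, Map, Value};

fn main() {
    // User-configured request_body is passed through untouched.
    let mut body: Map<String, Value> = json!({ "options": { "temperature": 0.2 } })
        .as_object()
        .unwrap()
        .clone();

    // The three inserts from the Ollama/OpenAi arm of build_body:
    body.insert("prompt".to_owned(), Value::String("fn add(".to_owned()));
    body.insert("model".to_owned(), Value::String("codellama:7b".to_owned()));
    body.insert("stream".to_owned(), Value::Bool(false));

    // {"model":"codellama:7b","options":{"temperature":0.2},"prompt":"fn add(","stream":false}
    println!("{}", Value::Object(body));
}
```
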
@@ -34,7 +34,6 @@ mod language_id;
 const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
 pub const NAME: &str = "llm-ls";
 pub const VERSION: &str = env!("CARGO_PKG_VERSION");
-const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
 
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
     Ok(rope.try_line_to_char(row)?

@@ -305,10 +304,15 @@ async fn request_completion(
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
 
-    let json = build_body(prompt, params);
+    let json = build_body(
+        &params.backend,
+        params.model.clone(),
+        prompt,
+        params.request_body.clone(),
+    );
     let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
     let res = http_client
-        .post(build_url(&params.model))
+        .post(build_url(params.backend.clone(), &params.model))
         .json(&json)
         .headers(headers)
         .send()

@@ -367,7 +371,7 @@ async fn download_tokenizer_file(
         return Ok(());
     }
     tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
-    let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
+    let headers = build_headers(&Backend::default(), api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)

@@ -475,11 +479,12 @@ async fn get_tokenizer(
     }
 }
 
-fn build_url(model: &str) -> String {
-    if model.starts_with("http://") || model.starts_with("https://") {
-        model.to_owned()
-    } else {
-        format!("https://{HF_INFERENCE_API_HOSTNAME}/models/{model}")
+fn build_url(backend: Backend, model: &str) -> String {
+    match backend {
+        Backend::HuggingFace { url } => format!("{url}/models/{model}"),
+        Backend::Ollama { url } => url,
+        Backend::OpenAi { url } => url,
+        Backend::Tgi { url } => url,
     }
 }
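
Only the Hugging Face variant still derives the request URL from the model name; the other backends use their configured endpoint verbatim. A quick sketch of the new resolution (model names and endpoints are illustrative):

```rust
// Pared-down mirror of the new build_url, for illustration.
enum Backend {
    HuggingFace { url: String },
    Ollama { url: String },
}

fn build_url(backend: Backend, model: &str) -> String {
    match backend {
        // The Inference API routes by model name under /models/.
        Backend::HuggingFace { url } => format!("{url}/models/{model}"),
        // Self-hosted endpoints are used verbatim; for ollama the model
        // name travels in the request body instead (see build_body above).
        Backend::Ollama { url } => url,
    }
}

fn main() {
    let hf = Backend::HuggingFace {
        url: "https://api-inference.huggingface.co".to_owned(),
    };
    assert_eq!(
        build_url(hf, "bigcode/starcoder"),
        "https://api-inference.huggingface.co/models/bigcode/starcoder"
    );

    let ollama = Backend::Ollama {
        url: "http://localhost:11434/api/generate".to_owned(),
    };
    assert_eq!(build_url(ollama, "ignored"), "http://localhost:11434/api/generate");
}
```
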
@@ -502,14 +507,13 @@ impl LlmService {
             cursor_character = ?params.text_document_position.position.character,
             language_id = %document.language_id,
             model = params.model,
-            backend = %params.backend,
+            backend = ?params.backend,
             ide = %params.ide,
             request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
             "received completion request for {}",
             params.text_document_position.text_document.uri
         );
-        let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
-        if params.api_token.is_none() && is_using_inference_api {
+        if params.api_token.is_none() && params.backend.is_using_inference_api() {
             let now = Instant::now();
             let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
             if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {
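
The tail of this hunk is the unauthenticated-usage warning, throttled through a shared timestamp so it fires at most once per `MAX_WARNING_REPEAT` (an hour, per the constant earlier in the file). A self-contained sketch of the same throttle pattern; the service keeps a plain `Instant` behind an async lock, while this sketch uses `std::sync::RwLock` and an `Option` to keep the first call simple:

```rust
use std::sync::RwLock;
use std::time::{Duration, Instant};

const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);

struct Warner {
    warned_at: RwLock<Option<Instant>>,
}

impl Warner {
    fn warn_if_due(&self, msg: &str) {
        let now = Instant::now();
        let due = match *self.warned_at.read().unwrap() {
            // Never warned yet: warn immediately.
            None => true,
            // Otherwise only warn once the repeat window has elapsed.
            Some(last) => now.duration_since(last) > MAX_WARNING_REPEAT,
        };
        if due {
            eprintln!("warning: {msg}");
            // Reset the clock so the next warning waits another hour.
            *self.warned_at.write().unwrap() = Some(now);
        }
    }
}

fn main() {
    let w = Warner {
        warned_at: RwLock::new(None),
    };
    w.warn_if_due("api token missing, completion requests may fail");
    w.warn_if_due("suppressed: within MAX_WARNING_REPEAT of the last warning");
}
```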