diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/adaptors.rs
new file mode 100644
index 0000000..b0861ba
--- /dev/null
+++ b/crates/llm-ls/src/adaptors.rs
@@ -0,0 +1,250 @@
+use super::{
+    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, RequestParams, NAME,
+    VERSION,
+};
+use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::fmt::Display;
+use tower_lsp::jsonrpc;
+
+fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
+    serde_json::json!({
+        "inputs": prompt,
+        "parameters": params,
+    })
+}
+
+fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+    let mut headers = HeaderMap::new();
+    let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
+    headers.insert(
+        USER_AGENT,
+        HeaderValue::from_str(&user_agent).map_err(internal_error)?,
+    );
+
+    if let Some(api_token) = api_token {
+        headers.insert(
+            AUTHORIZATION,
+            HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
+        );
+    }
+
+    Ok(headers)
+}
+
+fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
+    let generations =
+        match serde_json::from_str(text).map_err(internal_error)? {
+            APIResponse::Generation(gen) => vec![gen],
+            APIResponse::Generations(_) => {
+                return Err(internal_error(
+                    "You are attempting to parse a result in the API inference format when using the `tgi` adaptor",
+                ))
+            }
+            APIResponse::Error(err) => return Err(internal_error(err)),
+        };
+    Ok(generations)
+}
+
+fn build_api_body(prompt: String, params: &RequestParams) -> Value {
+    build_tgi_body(prompt, params)
+}
+
+fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+    build_tgi_headers(api_token, ide)
+}
+
+fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
+    let generations = match serde_json::from_str(text).map_err(internal_error)? {
+        APIResponse::Generation(gen) => vec![gen],
+        APIResponse::Generations(gens) => gens,
+        APIResponse::Error(err) => return Err(internal_error(err)),
+    };
+    Ok(generations)
+}
+
+fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
+    serde_json::json!({
+        "prompt": prompt,
+        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for ollama").get("model"),
+        "stream": false,
+        // As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
+        "options": {
+            "num_predict": params.request_params.max_new_tokens,
+            "temperature": params.request_params.temperature,
+            "top_p": params.request_params.top_p,
+            "stop": params.request_params.stop_tokens.clone(),
+        }
+    })
+}
+fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
+    Ok(HeaderMap::new())
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct OllamaGeneration {
+    response: String,
+}
+
+impl From<OllamaGeneration> for Generation {
+    fn from(value: OllamaGeneration) -> Self {
+        Generation {
+            generated_text: value.response,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum OllamaAPIResponse {
+    Generation(OllamaGeneration),
+    Error(APIError),
+}
+
+fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
+    let generations = match serde_json::from_str(text).map_err(internal_error)? {
+        OllamaAPIResponse::Generation(gen) => vec![gen.into()],
+        OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
+    };
+    Ok(generations)
+}
+
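+// For reference, a non-streaming Ollama `/api/generate` reply is a single
+// JSON object of the form `{"response": "<text>", "done": true, ...}`;
+// `OllamaGeneration` above keeps only the `response` field.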
+fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
+    serde_json::json!({
+        "prompt": prompt,
+        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for openai").get("model"),
+        "max_tokens": params.request_params.max_new_tokens,
+        "temperature": params.request_params.temperature,
+        "top_p": params.request_params.top_p,
+        "stop": params.request_params.stop_tokens.clone(),
+    })
+}
+
+fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+    build_api_headers(api_token, ide)
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIGenerationChoice {
+    text: String,
+}
+
+impl From<OpenAIGenerationChoice> for Generation {
+    fn from(value: OpenAIGenerationChoice) -> Self {
+        Generation {
+            generated_text: value.text,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIGeneration {
+    choices: Vec<OpenAIGenerationChoice>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum OpenAIErrorLoc {
+    String(String),
+    Int(u32),
+}
+
+impl Display for OpenAIErrorLoc {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            OpenAIErrorLoc::String(s) => s.fmt(f),
+            OpenAIErrorLoc::Int(i) => i.fmt(f),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIErrorDetail {
+    loc: OpenAIErrorLoc,
+    msg: String,
+    r#type: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIError {
+    detail: Vec<OpenAIErrorDetail>,
+}
+
+impl Display for OpenAIError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        for (i, item) in self.detail.iter().enumerate() {
+            if i != 0 {
+                writeln!(f)?;
+            }
+            write!(f, "{}: {} ({})", item.loc, item.msg, item.r#type)?;
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum OpenAIAPIResponse {
+    Generation(OpenAIGeneration),
+    Error(OpenAIError),
+}
+
+fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
+    match serde_json::from_str(text).map_err(internal_error) {
+        Ok(OpenAIAPIResponse::Generation(completion)) => {
+            Ok(completion.choices.into_iter().map(|x| x.into()).collect())
+        }
+        Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
+        Err(err) => Err(internal_error(err)),
+    }
+}
+
+const TGI: &str = "tgi";
+const HUGGING_FACE: &str = "huggingface";
+const OLLAMA: &str = "ollama";
+const OPENAI: &str = "openai";
+const DEFAULT_ADAPTOR: &str = HUGGING_FACE;
+
+fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error {
+    internal_error(format!("Unknown adaptor {:?}", adaptor))
+}
+
+pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, jsonrpc::Error> {
+    match params
+        .adaptor
+        .as_ref()
+        .unwrap_or(&DEFAULT_ADAPTOR.to_string())
+        .as_str()
+    {
+        TGI => Ok(build_tgi_body(prompt, &params.request_params)),
+        HUGGING_FACE => Ok(build_api_body(prompt, &params.request_params)),
+        OLLAMA => Ok(build_ollama_body(prompt, params)),
+        OPENAI => Ok(build_openai_body(prompt, params)),
+        _ => Err(unknown_adaptor_error(params.adaptor.as_ref())),
+    }
+}
+
+pub fn adapt_headers(
+    adaptor: Option<&String>,
+    api_token: Option<&String>,
+    ide: Ide,
+) -> Result<HeaderMap, jsonrpc::Error> {
+    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
+        TGI => build_tgi_headers(api_token, ide),
+        HUGGING_FACE => build_api_headers(api_token, ide),
+        OLLAMA => build_ollama_headers(),
+        OPENAI => build_openai_headers(api_token, ide),
+        _ => Err(unknown_adaptor_error(adaptor)),
+    }
+}
+
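+// Response parsing is dispatched per adaptor below; a missing `adaptor`
+// falls back to DEFAULT_ADAPTOR, so existing Hugging Face clients keep
+// working without any configuration change.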
+pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result<Vec<Generation>> {
+    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
+        TGI => parse_tgi_text(text),
+        HUGGING_FACE => parse_api_text(text),
+        OLLAMA => parse_ollama_text(text),
+        OPENAI => parse_openai_text(text),
+        _ => Err(unknown_adaptor_error(adaptor)),
+    }
+}
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 9e96c29..054aada 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -1,3 +1,4 @@
+use adaptors::{adapt_body, adapt_headers, parse_generations};
 use document::Document;
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use ropey::Rope;
@@ -18,20 +19,21 @@ use tracing_appender::rolling;
 use tracing_subscriber::EnvFilter;
 use uuid::Uuid;
 
+mod adaptors;
 mod document;
 mod language_id;
 
 const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
-const NAME: &str = "llm-ls";
-const VERSION: &str = env!("CARGO_PKG_VERSION");
+pub const NAME: &str = "llm-ls";
+pub const VERSION: &str = env!("CARGO_PKG_VERSION");
 
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
     Ok(rope.try_line_to_char(row).map_err(internal_error)?
         + col.min(
-            rope.get_line(row.min(rope.len_lines() - 1))
+            rope.get_line(row.min(rope.len_lines().saturating_sub(1)))
                 .ok_or_else(|| internal_error(format!("failed to find line at {row}")))?
                 .len_chars()
-                - 1,
+                .saturating_sub(1),
         ))
 }
 
@@ -130,7 +132,7 @@ enum TokenizerConfig {
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
-struct RequestParams {
+pub struct RequestParams {
     max_new_tokens: u32,
     temperature: f32,
     do_sample: bool,
@@ -178,12 +180,12 @@ struct APIRequest {
 }
 
 #[derive(Debug, Serialize, Deserialize)]
-struct Generation {
+pub struct Generation {
     generated_text: String,
 }
 
 #[derive(Debug, Deserialize)]
-struct APIError {
+pub struct APIError {
     error: String,
 }
 
@@ -195,7 +197,7 @@ impl Display for APIError {
 #[derive(Debug, Deserialize)]
 #[serde(untagged)]
-enum APIResponse {
+pub enum APIResponse {
     Generation(Generation),
     Generations(Vec<Generation>),
     Error(APIError),
 }
@@ -219,7 +221,7 @@ struct Completion {
 
 #[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
 #[serde(rename_all = "lowercase")]
-enum Ide {
+pub enum Ide {
     Neovim,
     VSCode,
     JetBrains,
@@ -261,7 +263,7 @@ struct RejectedCompletion {
 
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
-struct CompletionParams {
+pub struct CompletionParams {
     #[serde(flatten)]
     text_document_position: TextDocumentPositionParams,
     request_params: RequestParams,
@@ -271,10 +273,12 @@ struct CompletionParams {
     fim: FimParams,
     api_token: Option<String>,
     model: String,
+    adaptor: Option<String>,
     tokens_to_clear: Vec<String>,
     tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
+    request_body: Option<Map<String, Value>>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
@@ -283,7 +287,7 @@ struct CompletionResult {
     completions: Vec<Completion>,
 }
 
-fn internal_error<E: Display>(err: E) -> Error {
+pub fn internal_error<E: Display>(err: E) -> Error {
     let err_msg = err.to_string();
     error!(err_msg);
     Error {
@@ -398,29 +402,30 @@ fn build_prompt(
 async fn request_completion(
     http_client: &reqwest::Client,
-    ide: Ide,
-    model: &str,
-    request_params: RequestParams,
-    api_token: Option<&String>,
     prompt: String,
+    params: &CompletionParams,
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
+
+    let json = adapt_body(prompt, params).map_err(internal_error)?;
+    let headers = adapt_headers(
+        params.adaptor.as_ref(),
+        params.api_token.as_ref(),
+        params.ide,
+    )?;
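+    // The request body, headers, and (below) response parsing now come from
+    // the adaptor layer instead of the hard-coded Inference API format.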
     let res = http_client
-        .post(build_url(model))
-        .json(&APIRequest {
-            inputs: prompt,
-            parameters: request_params.into(),
-        })
-        .headers(build_headers(api_token, ide)?)
+        .post(build_url(&params.model))
+        .json(&json)
+        .headers(headers)
         .send()
         .await
         .map_err(internal_error)?;
 
-    let generations = match res.json().await.map_err(internal_error)? {
-        APIResponse::Generation(gen) => vec![gen],
-        APIResponse::Generations(gens) => gens,
-        APIResponse::Error(err) => return Err(internal_error(err)),
-    };
+    let model = &params.model;
+    let generations = parse_generations(
+        params.adaptor.as_ref(),
+        res.text().await.map_err(internal_error)?.as_str(),
+    );
     let time = t.elapsed().as_millis();
     info!(
         model,
@@ -428,10 +433,10 @@ async fn request_completion(
         generations = serde_json::to_string(&generations).map_err(internal_error)?,
         "{model} computed generations in {time} ms"
     );
-    Ok(generations)
+    generations
 }
 
-fn parse_generations(
+fn format_generations(
     generations: Vec<Generation>,
     tokens_to_clear: &[String],
     completion_type: CompletionType,
@@ -524,7 +529,7 @@ async fn download_tokenizer_file(
 async fn get_tokenizer(
     model: &str,
     tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
-    tokenizer_config: Option<TokenizerConfig>,
+    tokenizer_config: Option<&TokenizerConfig>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
     api_token: Option<&String>,
@@ -543,7 +548,7 @@ async fn get_tokenizer(
             }
         },
         TokenizerConfig::HuggingFace { repository } => {
-            let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+            let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
             let url =
                 format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
             download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
@@ -556,7 +561,7 @@ async fn get_tokenizer(
             }
         }
         TokenizerConfig::Download { url, to } => {
-            download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
+            download_tokenizer_file(http_client, url, api_token, &to, ide).await?;
             match Tokenizer::from_file(to) {
                 Ok(tokenizer) => Some(Arc::new(tokenizer)),
                 Err(err) => {
@@ -627,7 +632,7 @@ impl Backend {
         let tokenizer = get_tokenizer(
             &params.model,
             &mut *self.tokenizer_map.write().await,
-            params.tokenizer_config,
+            params.tokenizer_config.as_ref(),
             &self.http_client,
             &self.cache_dir,
             params.api_token.as_ref(),
@@ -650,15 +655,12 @@ impl Backend {
         };
         let result = request_completion(
             http_client,
-            params.ide,
-            &params.model,
-            params.request_params,
-            params.api_token.as_ref(),
             prompt,
+            &params,
         )
         .await?;
 
-        let completions = parse_generations(result, &params.tokens_to_clear, completion_type);
+        let completions = format_generations(result, &params.tokens_to_clear, completion_type);
 
         Ok(CompletionResult { request_id, completions })
     }.instrument(span).await
 }
@@ -849,3 +851,4 @@ async fn main() {
 
     Server::new(stdin, stdout, socket).serve(service).await;
 }
+
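For reference, a client selecting the new Ollama backend would send completion params along these lines (illustrative, hypothetical values; field names follow the struct's camelCase serde renaming, and `model` is assumed to be a full URL that the existing `build_url` forwards unchanged):

    {
        "model": "http://localhost:11434/api/generate",
        "adaptor": "ollama",
        "requestBody": { "model": "codellama:7b" }
    }

Omitting `adaptor` keeps the previous Hugging Face Inference API behavior.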