diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/backend.rs
similarity index 57%
rename from crates/llm-ls/src/adaptors.rs
rename to crates/llm-ls/src/backend.rs
index 553fc87..a139870 100644
--- a/crates/llm-ls/src/adaptors.rs
+++ b/crates/llm-ls/src/backend.rs
@@ -1,26 +1,12 @@
 use super::{
-    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, RequestParams, NAME,
-    VERSION,
+    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
 };
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{Map, Value};
 use std::fmt::Display;
 use tower_lsp::jsonrpc;
 
-fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
-    serde_json::json!({
-        "inputs": prompt,
-        "parameters": {
-            "max_new_tokens": params.max_new_tokens,
-            "temperature": params.temperature,
-            "do_sample": params.do_sample,
-            "top_p": params.top_p,
-            "stop_tokens": params.stop_tokens.clone()
-        },
-    })
-}
-
 fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
     let mut headers = HeaderMap::new();
     let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
@@ -45,7 +31,7 @@ fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
         APIResponse::Generation(gen) => vec![gen],
         APIResponse::Generations(_) => {
             return Err(internal_error(
-                "You are attempting to parse a result in the API inference format when using the `tgi` adaptor",
+                "You are attempting to parse a result in the API inference format when using the `tgi` backend",
             ))
         }
         APIResponse::Error(err) => return Err(internal_error(err)),
@@ -53,10 +39,6 @@ fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     Ok(generations)
 }
 
-fn build_api_body(prompt: String, params: &RequestParams) -> Value {
-    build_tgi_body(prompt, params)
-}
-
 fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
     build_tgi_headers(api_token, ide)
 }
@@ -70,20 +52,6 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     Ok(generations)
 }
 
-fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
-    serde_json::json!({
-        "prompt": prompt,
-        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for ollama").get("model"),
-        "stream": false,
-        // As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
-        "options": {
-            "num_predict": params.request_params.max_new_tokens,
-            "temperature": params.request_params.temperature,
-            "top_p": params.request_params.top_p,
-            "stop": params.request_params.stop_tokens.clone(),
-        }
-    })
-}
 fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
     Ok(HeaderMap::new())
 }
@@ -116,17 +84,6 @@ fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     Ok(generations)
 }
 
-fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
-    serde_json::json!({
-        "prompt": prompt,
-        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for openai").get("model"),
-        "max_tokens": params.request_params.max_new_tokens,
-        "temperature": params.request_params.temperature,
-        "top_p": params.request_params.top_p,
-        "stop": params.request_params.stop_tokens.clone(),
-    })
-}
-
 fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
     build_api_headers(api_token, ide)
 }
@@ -206,51 +163,47 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     }
 }
 
-pub(crate) const TGI: &str = "tgi";
-pub(crate) const HUGGING_FACE: &str = "huggingface";
-pub(crate) const OLLAMA: &str = "ollama";
-pub(crate) const OPENAI: &str = "openai";
-pub(crate) const DEFAULT_ADAPTOR: &str = HUGGING_FACE;
-
-fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error {
-    internal_error(format!("Unknown adaptor {:?}", adaptor))
+#[derive(Debug, Default, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase")]
+pub(crate) enum Backend {
+    #[default]
+    HuggingFace,
+    Ollama,
+    OpenAi,
+    Tgi,
 }
 
-pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, jsonrpc::Error> {
-    match params
-        .adaptor
-        .as_ref()
-        .unwrap_or(&DEFAULT_ADAPTOR.to_string())
-        .as_str()
-    {
-        TGI => Ok(build_tgi_body(prompt, &params.request_params)),
-        HUGGING_FACE => Ok(build_api_body(prompt, &params.request_params)),
-        OLLAMA => Ok(build_ollama_body(prompt, params)),
-        OPENAI => Ok(build_openai_body(prompt, params)),
-        _ => Err(unknown_adaptor_error(params.adaptor.as_ref())),
-    }
+pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Value> {
+    let mut body = params.request_body.clone();
+    match params.backend {
+        Backend::HuggingFace | Backend::Tgi => {
+            body.insert("inputs".to_string(), Value::String(prompt))
+        }
+        Backend::Ollama | Backend::OpenAi => {
+            body.insert("prompt".to_string(), Value::String(prompt))
+        }
+    };
+    body
 }
 
-pub fn adapt_headers(
-    adaptor: Option<&String>,
+pub fn build_headers(
+    backend: &Backend,
     api_token: Option<&String>,
     ide: Ide,
 ) -> Result<HeaderMap, jsonrpc::Error> {
-    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
-        TGI => build_tgi_headers(api_token, ide),
-        HUGGING_FACE => build_api_headers(api_token, ide),
-        OLLAMA => build_ollama_headers(),
-        OPENAI => build_openai_headers(api_token, ide),
-        _ => Err(unknown_adaptor_error(adaptor)),
+    match backend {
+        Backend::HuggingFace => build_api_headers(api_token, ide),
+        Backend::Ollama => build_ollama_headers(),
+        Backend::OpenAi => build_openai_headers(api_token, ide),
+        Backend::Tgi => build_tgi_headers(api_token, ide),
     }
 }
 
-pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result<Vec<Generation>> {
-    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
-        TGI => parse_tgi_text(text),
-        HUGGING_FACE => parse_api_text(text),
-        OLLAMA => parse_ollama_text(text),
-        OPENAI => parse_openai_text(text),
-        _ => Err(unknown_adaptor_error(adaptor)),
+pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result<Vec<Generation>> {
+    match backend {
+        Backend::HuggingFace => parse_api_text(text),
+        Backend::Ollama => parse_ollama_text(text),
+        Backend::OpenAi => parse_openai_text(text),
+        Backend::Tgi => parse_tgi_text(text),
     }
 }
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 318be3b..fcd7d3c 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -1,8 +1,8 @@
-use adaptors::{adapt_body, adapt_headers, parse_generations};
+use backend::{build_body, build_headers, parse_generations, Backend};
 use document::Document;
-use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use ropey::Rope;
 use serde::{Deserialize, Deserializer, Serialize};
+use serde_json::{Map, Value};
 use std::collections::HashMap;
 use std::fmt::Display;
 use std::path::{Path, PathBuf};
@@ -19,7 +19,7 @@
 use tracing_appender::rolling;
 use tracing_subscriber::EnvFilter;
 use uuid::Uuid;
 
-mod adaptors;
+mod backend;
 mod document;
 mod language_id;
@@ -117,10 +117,7 @@ fn should_complete(document: &Document, position: Position) -> Result
+    },
+    Download {
+        url: String,
+        to: PathBuf,
+    },
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
@@ -209,7 +214,7 @@ pub enum APIResponse {
     Error(APIError),
 }
 
-struct Backend {
+struct LlmService {
     cache_dir: PathBuf,
     client: Client,
     document_map: Arc<RwLock<HashMap<String, Document>>>,
@@ -272,19 +277,18 @@ struct RejectedCompletion {
 pub struct CompletionParams {
     #[serde(flatten)]
     text_document_position: TextDocumentPositionParams,
-    request_params: RequestParams,
     #[serde(default)]
     #[serde(deserialize_with = "parse_ide")]
     ide: Ide,
     fim: FimParams,
     api_token: Option<String>,
     model: String,
-    adaptor: Option<String>,
+    backend: Backend,
     tokens_to_clear: Vec<String>,
     tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
-    request_body: Option<Map<String, Value>>,
+    request_body: Map<String, Value>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
@@ -413,12 +417,8 @@ async fn request_completion(
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
 
-    let json = adapt_body(prompt, params).map_err(internal_error)?;
-    let headers = adapt_headers(
-        params.adaptor.as_ref(),
-        params.api_token.as_ref(),
-        params.ide,
-    )?;
+    let json = build_body(prompt, params);
+    let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
     let res = http_client
         .post(build_url(&params.model))
         .json(&json)
@@ -429,7 +429,7 @@ async fn request_completion(
 
     let model = &params.model;
     let generations = parse_generations(
-        params.adaptor.as_ref(),
+        &params.backend,
         res.text().await.map_err(internal_error)?.as_str(),
     );
     let time = t.elapsed().as_millis();
@@ -489,7 +489,7 @@ async fn download_tokenizer_file(
     )
     .await
     .map_err(internal_error)?;
-    let headers = build_headers(api_token, ide)?;
+    let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
@@ -538,7 +538,6 @@ async fn get_tokenizer(
     tokenizer_config: Option<&TokenizerConfig>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
-    api_token: Option<&String>,
     ide: Ide,
 ) -> Result<Option<Arc<Tokenizer>>> {
     if let Some(tokenizer) = tokenizer_map.get(model) {
@@ -553,11 +552,14 @@ async fn get_tokenizer(
                 None
             }
         },
-        TokenizerConfig::HuggingFace { repository } => {
+        TokenizerConfig::HuggingFace {
+            repository,
+            api_token,
+        } => {
             let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
             let url =
                 format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
-            download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
+            download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?;
             match Tokenizer::from_file(path) {
                 Ok(tokenizer) => Some(Arc::new(tokenizer)),
                 Err(err) => {
@@ -567,7 +569,7 @@ async fn get_tokenizer(
             }
         }
         TokenizerConfig::Download { url, to } => {
-            download_tokenizer_file(http_client, url, api_token, &to, ide).await?;
+            download_tokenizer_file(http_client, url, None, &to, ide).await?;
             match Tokenizer::from_file(to) {
                 Ok(tokenizer) => Some(Arc::new(tokenizer)),
                 Err(err) => {
@@ -594,7 +596,7 @@ fn build_url(model: &str) -> String {
     }
 }
 
-impl Backend {
+impl LlmService {
     async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
         let request_id = Uuid::new_v4();
         let span = info_span!("completion_request", %request_id);
@@ -611,15 +613,11 @@ impl Backend {
             language_id = %document.language_id,
             model = params.model,
             ide = %params.ide,
-            max_new_tokens = params.request_params.max_new_tokens,
-            temperature = params.request_params.temperature,
-            do_sample = params.request_params.do_sample,
-            top_p = params.request_params.top_p,
-            stop_tokens = ?params.request_params.stop_tokens,
+            request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
             "received completion request for {}",
             params.text_document_position.text_document.uri
         );
-        let is_using_inference_api = params.adaptor.as_ref().unwrap_or(&adaptors::DEFAULT_ADAPTOR.to_owned()).as_str() == adaptors::HUGGING_FACE;
+        let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
         if params.api_token.is_none() && is_using_inference_api {
             let now = Instant::now();
             let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
@@ -642,7 +640,6 @@ impl Backend {
             params.tokenizer_config.as_ref(),
             &self.http_client,
             &self.cache_dir,
-            params.api_token.as_ref(),
             params.ide,
         )
         .await?;
@@ -693,7 +690,7 @@ impl Backend {
 }
 
 #[tower_lsp::async_trait]
-impl LanguageServer for Backend {
+impl LanguageServer for LlmService {
     async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
         *self.workspace_folders.write().await = params.workspace_folders;
         Ok(InitializeResult {
@@ -795,24 +792,6 @@ impl LanguageServer for Backend {
     }
 }
 
-fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
-    let mut headers = HeaderMap::new();
-    let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
-    headers.insert(
-        USER_AGENT,
-        HeaderValue::from_str(&user_agent).map_err(internal_error)?,
-    );
-
-    if let Some(api_token) = api_token {
-        headers.insert(
-            AUTHORIZATION,
-            HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
-        );
-    }
-
-    Ok(headers)
-}
-
 #[tokio::main]
 async fn main() {
     let stdin = tokio::io::stdin();
@@ -846,7 +825,7 @@ async fn main() {
         .build()
         .expect("failed to build reqwest unsafe client");
 
-    let (service, socket) = LspService::build(|client| Backend {
+    let (service, socket) = LspService::build(|client| LlmService {
        cache_dir,
        client,
        document_map: Arc::new(RwLock::new(HashMap::new())),
@@ -860,9 +839,9 @@ async fn main() {
                .expect("instant to be in bounds"),
        )),
    })
-    .custom_method("llm-ls/getCompletions", Backend::get_completions)
-    .custom_method("llm-ls/acceptCompletion", Backend::accept_completion)
-    .custom_method("llm-ls/rejectCompletion", Backend::reject_completion)
+    .custom_method("llm-ls/getCompletions", LlmService::get_completions)
+    .custom_method("llm-ls/acceptCompletion", LlmService::accept_completion)
+    .custom_method("llm-ls/rejectCompletion", LlmService::reject_completion)
    .finish();

    Server::new(stdin, stdout, socket).serve(service).await;
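Note on the new request flow (not part of the patch): the server no longer assembles per-backend sampling parameters itself. The old RequestParams and the build_*_body helpers are gone; clients put whatever fields their backend expects directly into request_body, and build_body only merges the prompt into that map under "inputs" (Hugging Face, TGI) or "prompt" (Ollama, OpenAI). The sketch below mirrors that merge in a self-contained way; merge_prompt, the local Backend stand-in, and the example request_body values are illustrative assumptions, not code from the repository.

use serde_json::{json, Map, Value};

// Local stand-in for crate::backend::Backend, kept here only so the sketch compiles on its own.
#[allow(dead_code)]
enum Backend {
    HuggingFace,
    Ollama,
    OpenAi,
    Tgi,
}

// Mirrors build_body: clone the caller-provided body, then add the prompt under
// the field name the selected backend expects.
fn merge_prompt(prompt: &str, backend: &Backend, request_body: &Map<String, Value>) -> Map<String, Value> {
    let mut body = request_body.clone();
    let key = match backend {
        Backend::HuggingFace | Backend::Tgi => "inputs",
        Backend::Ollama | Backend::OpenAi => "prompt",
    };
    body.insert(key.to_owned(), Value::String(prompt.to_owned()));
    body
}

fn main() {
    // Sampling options now travel verbatim in request_body; these values are only examples.
    let request_body = json!({
        "model": "example-model",
        "options": { "num_predict": 60, "temperature": 0.2 }
    });
    let merged = merge_prompt("fn main() {", &Backend::Ollama, request_body.as_object().unwrap());
    println!("{}", Value::Object(merged));
}

A client targeting TGI or the Hugging Face Inference API would typically carry a "parameters" object in the same map, which is what the removed build_tgi_body used to construct on the server side.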