diff --git a/Cargo.lock b/Cargo.lock
index 7148385..02c6fb8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -659,12 +659,13 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
 
 [[package]]
 name = "llm-ls"
-version = "0.3.0"
+version = "0.4.0"
 dependencies = [
  "home",
  "reqwest",
  "ropey",
  "serde",
+ "serde_json",
  "tokenizers",
  "tokio",
  "tower-lsp",
@@ -693,6 +694,7 @@ dependencies = [
  "tree-sitter-scala",
  "tree-sitter-swift",
  "tree-sitter-typescript",
+ "uuid",
 ]
 
 [[package]]
@@ -2039,6 +2041,17 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "uuid"
+version = "1.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d"
+dependencies = [
+ "getrandom",
+ "rand",
+ "serde",
+]
+
 [[package]]
 name = "valuable"
 version = "0.1.0"
diff --git a/README.md b/README.md
index c339e7f..507b50f 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,27 @@
 > [!IMPORTANT]
 > This is currently a work in progress, expect things to be broken!
 
-**llm-ls** is a LSP server leveraging LLMs for code completion (and more?).
+**llm-ls** is an LSP server leveraging LLMs to make your development experience smoother and more efficient.
+
+The goal of llm-ls is to provide a common platform for IDE extensions to be built on. llm-ls takes care of the heavy lifting with regard to interacting with LLMs so that extension code can be as lightweight as possible.
+
+## Features
+
+### Prompt
+
+Uses the current file as context to generate the prompt. Can use "fill in the middle" or not depending on your needs.
+
+It also makes sure that you are within the context window of the model by tokenizing the prompt.
+
+### Telemetry
+
+Gathers information about requests and completions that can enable retraining.
+
+Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API); everything is stored in a log file if you set the log level to `info`.
+
+### Completion
+
+**llm-ls** parses the AST of the code to determine if completions should be multi-line, single-line or empty (no completion).
 
 ## Compatible extensions
 
@@ -12,3 +32,12 @@
 - [x] [llm-intellij](https://github.com/huggingface/llm-intellij)
 - [ ] [jupytercoder](https://github.com/bigcode-project/jupytercoder)
 
+## Roadmap
+
+- support getting context from multiple files in the workspace
+- add a `suffix_percent` setting that determines the ratio of # of tokens for the prefix vs the suffix in the prompt
+- add a context window fill percent setting or change `context_window` to `max_tokens`
+- filter bad suggestions (repetitive, same as below, etc.)
+- support for ollama
+- support for llama.cpp
+- OTLP traces?
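// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the diff above: the "fill in the middle"
// prompt layout described in the README's Prompt section. The marker strings
// are assumptions (StarCoder-style tokens an extension might configure);
// llm-ls takes them from the `fim` parameters the client sends with each
// completion request, and this function name is hypothetical.
fn fim_prompt(
    fim_prefix: &str,
    fim_suffix: &str,
    fim_middle: &str,
    prefix: &str,
    suffix: &str,
) -> String {
    // prefix = text before the cursor, suffix = text after the cursor; build_prompt
    // trims both by counting tokens so the assembled prompt stays within the
    // model's context window before the request is sent
    format!("{fim_prefix}{prefix}{fim_suffix}{suffix}{fim_middle}")
}
// e.g. fim_prompt("<fim_prefix>", "<fim_suffix>", "<fim_middle>", before_cursor, after_cursor)
// ---------------------------------------------------------------------------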
diff --git a/crates/llm-ls/Cargo.toml b/crates/llm-ls/Cargo.toml
index f4683d4..bd8b823 100644
--- a/crates/llm-ls/Cargo.toml
+++ b/crates/llm-ls/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "llm-ls"
-version = "0.3.0"
+version = "0.4.0"
 edition = "2021"
 
 [[bin]]
@@ -11,6 +11,7 @@ home = "0.5"
 ropey = "1.6"
 reqwest = { version = "0.11", default-features = false, features = ["json", "rustls-tls"] }
 serde = { version = "1", features = ["derive"] }
+serde_json = "1"
 tokenizers = { version = "0.13", default-features = false, features = ["onig"] }
 tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] }
 tower-lsp = "0.20"
@@ -40,3 +41,10 @@ tree-sitter-scala = "0.20"
 tree-sitter-swift = "0.3"
 tree-sitter-typescript = "0.20"
 
+[dependencies.uuid]
+version = "1.4"
+features = [
+    "v4",
+    "fast-rng",
+    "serde",
+]
diff --git a/crates/llm-ls/src/document.rs b/crates/llm-ls/src/document.rs
index 2ee1544..012dc6f 100644
--- a/crates/llm-ls/src/document.rs
+++ b/crates/llm-ls/src/document.rs
@@ -166,8 +166,7 @@ fn get_parser(language_id: LanguageId) -> Result<Parser> {
 }
 
 pub(crate) struct Document {
-    #[allow(dead_code)]
-    language_id: LanguageId,
+    pub(crate) language_id: LanguageId,
     pub(crate) text: Rope,
     parser: Parser,
     pub(crate) tree: Option<Tree>,
diff --git a/crates/llm-ls/src/language_id.rs b/crates/llm-ls/src/language_id.rs
index 2892268..c311958 100644
--- a/crates/llm-ls/src/language_id.rs
+++ b/crates/llm-ls/src/language_id.rs
@@ -1,6 +1,7 @@
+use serde::{Deserialize, Serialize};
 use std::fmt;
 
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, Serialize, Deserialize)]
 pub(crate) enum LanguageId {
     Bash,
     C,
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 39e07ae..1b8ed40 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -13,9 +13,10 @@ use tokio::sync::RwLock;
 use tower_lsp::jsonrpc::{Error, Result};
 use tower_lsp::lsp_types::*;
 use tower_lsp::{Client, LanguageServer, LspService, Server};
-use tracing::{debug, error, info, warn};
+use tracing::{debug, error, info, info_span, warn, Instrument};
 use tracing_appender::rolling;
 use tracing_subscriber::EnvFilter;
+use uuid::Uuid;
 
 mod document;
 mod language_id;
@@ -24,13 +25,23 @@ const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
 const NAME: &str = "llm-ls";
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
 enum CompletionType {
     Empty,
     SingleLine,
     MultiLine,
 }
 
+impl Display for CompletionType {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            CompletionType::Empty => write!(f, "empty"),
+            CompletionType::SingleLine => write!(f, "single_line"),
+            CompletionType::MultiLine => write!(f, "multi_line"),
+        }
+    }
+}
+
 fn should_complete(document: &Document, position: Position) -> CompletionType {
     let row = position.line as usize;
     let column = position.character as usize;
@@ -129,7 +140,7 @@ struct APIRequest {
     parameters: APIParams,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
 struct Generation {
     generated_text: String,
 }
@@ -196,6 +207,19 @@ where
     Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 
+#[derive(Debug, Deserialize, Serialize)]
+struct AcceptedCompletion {
+    request_id: Uuid,
+    accepted_completion: u32,
+    shown_completions: Vec<String>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+struct RejectedCompletion {
+    request_id: Uuid,
+    shown_completions: Vec<String>,
+}
+
 #[derive(Debug, Deserialize, Serialize)]
 struct CompletionParams {
     #[serde(flatten)]
@@ -213,6 +237,12 @@ struct CompletionParams {
     tls_skip_verify_insecure: bool,
 }
 
+#[derive(Debug, Deserialize, Serialize)]
+struct CompletionResult {
+    request_id: Uuid,
+    completions: Vec<Completion>,
+}
+
 fn internal_error<E: Display>(err: E) -> Error {
     let err_msg = err.to_string();
     error!(err_msg);
@@ -292,7 +322,7 @@ fn build_prompt(
             fim.middle
         );
         let time = t.elapsed().as_millis();
-        info!(build_prompt_ms = time, "built prompt in {time} ms");
+        info!(prompt, build_prompt_ms = time, "built prompt in {time} ms");
         Ok(prompt)
     } else {
         let mut token_count = context_window;
@@ -321,7 +351,7 @@ fn build_prompt(
         }
         let prompt = before.into_iter().rev().collect::<Vec<_>>().join("");
         let time = t.elapsed().as_millis();
-        info!(build_prompt_ms = time, "built prompt in {time} ms");
+        info!(prompt, build_prompt_ms = time, "built prompt in {time} ms");
         Ok(prompt)
     }
 }
@@ -334,6 +364,7 @@ async fn request_completion(
     api_token: Option<&String>,
     prompt: String,
 ) -> Result<Vec<Generation>> {
+    let t = Instant::now();
     let res = http_client
         .post(build_url(model))
         .json(&APIRequest {
@@ -345,11 +376,19 @@
         .await
         .map_err(internal_error)?;
 
-    match res.json().await.map_err(internal_error)? {
-        APIResponse::Generation(gen) => Ok(vec![gen]),
-        APIResponse::Generations(gens) => Ok(gens),
-        APIResponse::Error(err) => Err(internal_error(err)),
-    }
+    let generations = match res.json().await.map_err(internal_error)? {
+        APIResponse::Generation(gen) => vec![gen],
+        APIResponse::Generations(gens) => gens,
+        APIResponse::Error(err) => return Err(internal_error(err)),
+    };
+    let time = t.elapsed().as_millis();
+    info!(
+        model,
+        compute_generations_ms = time,
+        generations = serde_json::to_string(&generations).map_err(internal_error)?,
+        "{model} computed generations in {time} ms"
+    );
+    Ok(generations)
 }
 
 fn parse_generations(
@@ -505,68 +544,102 @@ fn build_url(model: &str) -> String {
 }
 
 impl Backend {
-    async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> {
-        info!("get_completions {params:?}");
-        let document_map = self.document_map.read().await;
+    async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
+        let request_id = Uuid::new_v4();
+        let span = info_span!("completion_request", %request_id);
+        async move {
+            let document_map = self.document_map.read().await;
 
-        let document = document_map
-            .get(params.text_document_position.text_document.uri.as_str())
-            .ok_or_else(|| internal_error("failed to find document"))?;
-        if params.api_token.is_none() {
-            let now = Instant::now();
-            let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
-            if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {
-                drop(unauthenticated_warn_at);
-                self.client.show_message(MessageType::WARNING, "You are currently unauthenticated and will get rate limited. To reduce rate limiting, login with your API Token and consider subscribing to PRO: https://huggingface.co/pricing#pro").await;
-                let mut unauthenticated_warn_at = self.unauthenticated_warn_at.write().await;
-                *unauthenticated_warn_at = Instant::now();
+            let document = document_map
+                .get(params.text_document_position.text_document.uri.as_str())
+                .ok_or_else(|| internal_error("failed to find document"))?;
+            info!(
+                document_url = %params.text_document_position.text_document.uri,
+                cursor_line = ?params.text_document_position.position.line,
+                cursor_character = ?params.text_document_position.position.character,
+                language_id = %document.language_id,
+                model = params.model,
+                ide = %params.ide,
+                max_new_tokens = params.request_params.max_new_tokens,
+                temperature = params.request_params.temperature,
+                do_sample = params.request_params.do_sample,
+                top_p = params.request_params.top_p,
+                stop_tokens = ?params.request_params.stop_tokens,
+                "received completion request for {}",
+                params.text_document_position.text_document.uri
+            );
+            if params.api_token.is_none() {
+                let now = Instant::now();
+                let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
+                if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {
+                    drop(unauthenticated_warn_at);
+                    self.client.show_message(MessageType::WARNING, "You are currently unauthenticated and will get rate limited. To reduce rate limiting, login with your API Token and consider subscribing to PRO: https://huggingface.co/pricing#pro").await;
+                    let mut unauthenticated_warn_at = self.unauthenticated_warn_at.write().await;
+                    *unauthenticated_warn_at = Instant::now();
+                }
+            }
+            let completion_type = should_complete(document, params.text_document_position.position);
+            info!(%completion_type, "completion type: {completion_type:?}");
+            if completion_type == CompletionType::Empty {
+                return Ok(CompletionResult { request_id, completions: vec![]});
             }
-        }
-        let completion_type = should_complete(document, params.text_document_position.position);
-        info!("completion type: {completion_type:?}");
-        if completion_type == CompletionType::Empty {
-            return Ok(vec![]);
-        }
-        let tokenizer = get_tokenizer(
-            &params.model,
-            &mut *self.tokenizer_map.write().await,
-            params.tokenizer_config,
-            &self.http_client,
-            &self.cache_dir,
-            params.api_token.as_ref(),
-            params.ide,
-        )
-        .await?;
-        let prompt = build_prompt(
-            params.text_document_position.position,
-            &document.text,
-            &params.fim,
-            tokenizer,
-            params.context_window,
-        )?;
+            let tokenizer = get_tokenizer(
+                &params.model,
+                &mut *self.tokenizer_map.write().await,
+                params.tokenizer_config,
+                &self.http_client,
+                &self.cache_dir,
+                params.api_token.as_ref(),
+                params.ide,
+            )
+            .await?;
+            let prompt = build_prompt(
+                params.text_document_position.position,
+                &document.text,
+                &params.fim,
+                tokenizer,
+                params.context_window,
+            )?;
 
-        let http_client = if params.tls_skip_verify_insecure {
-            info!("tls verification is disabled");
-            &self.unsafe_http_client
-        } else {
-            &self.http_client
-        };
-        let result = request_completion(
-            http_client,
-            params.ide,
-            &params.model,
-            params.request_params,
-            params.api_token.as_ref(),
-            prompt,
-        )
-        .await?;
+            let http_client = if params.tls_skip_verify_insecure {
+                info!("tls verification is disabled");
+                &self.unsafe_http_client
+            } else {
+                &self.http_client
+            };
+            let result = request_completion(
+                http_client,
+                params.ide,
+                &params.model,
+                params.request_params,
+                params.api_token.as_ref(),
+                prompt,
+            )
+            .await?;
 
-        Ok(parse_generations(
-            result,
-            &params.tokens_to_clear,
-            completion_type,
-        ))
+            let completions = parse_generations(result, &params.tokens_to_clear, completion_type);
+            Ok(CompletionResult { request_id, completions })
+        }.instrument(span).await
+    }
+
+    async fn accept_completion(&self, accepted: AcceptedCompletion) -> Result<()> {
+        info!(
+            request_id = %accepted.request_id,
+            accepted_position = accepted.accepted_completion,
+            shown_completions = serde_json::to_string(&accepted.shown_completions).map_err(internal_error)?,
+            "accepted completion"
+        );
+        Ok(())
+    }
+
+    async fn reject_completion(&self, rejected: RejectedCompletion) -> Result<()> {
+        info!(
+            request_id = %rejected.request_id,
+            shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?,
+            "rejected completion"
+        );
+        Ok(())
     }
 }
@@ -724,6 +797,8 @@
             )),
         })
     .custom_method("llm-ls/getCompletions", Backend::get_completions)
+    .custom_method("llm-ls/acceptCompletion", Backend::accept_completion)
+    .custom_method("llm-ls/rejectCompletion", Backend::reject_completion)
     .finish();
 
     Server::new(stdin, stdout, socket).serve(service).await;
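// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the diff above: the JSON bodies an editor
// extension could send to the new custom methods. The method names and field
// names come from the diff (CompletionResult, AcceptedCompletion,
// RejectedCompletion); the helper functions themselves are hypothetical and
// stand in for whatever LSP client plumbing the extension uses.
use serde_json::json;

// Sent to "llm-ls/acceptCompletion" when the user inserts one of the suggestions.
fn accept_payload(request_id: &str, accepted_idx: u32, shown: &[String]) -> serde_json::Value {
    json!({
        "request_id": request_id,            // the Uuid returned in CompletionResult
        "accepted_completion": accepted_idx, // position of the completion that was accepted
        "shown_completions": shown           // the generated texts that were displayed
    })
}

// Sent to "llm-ls/rejectCompletion" when the suggestions are dismissed.
fn reject_payload(request_id: &str, shown: &[String]) -> serde_json::Value {
    json!({
        "request_id": request_id,
        "shown_completions": shown
    })
}
// On the server side both methods only log the payload at the `info` level,
// consistent with the README's note that no data is exported anywhere.
// ---------------------------------------------------------------------------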