diff --git a/Cargo.lock b/Cargo.lock index d941672..3f3e531 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aho-corasick" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8f9420f797f2d9e935edf629310eb938a0d839f984e25327f3c7eed22300c" +dependencies = [ + "memchr", +] + [[package]] name = "async-trait" version = "0.1.72" @@ -100,17 +109,6 @@ dependencies = [ "libc", ] -[[package]] -name = "ccserver" -version = "0.1.0" -dependencies = [ - "home", - "reqwest", - "serde", - "tokio", - "tower-lsp", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -133,6 +131,25 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + [[package]] name = "dashmap" version = "5.5.0" @@ -146,6 +163,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "deranged" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7684a49fb1af197853ef7b2ee694bc1f5b4179556f1e5710e1760c5db6f5e929" + [[package]] name = "encoding_rs" version = "0.8.32" @@ -471,6 +494,21 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +[[package]] +name = "llm-ls" +version = "0.1.0" +dependencies = [ + "home", + "reqwest", + "ropey", + "serde", + "tokio", + "tower-lsp", + "tracing", + "tracing-appender", + "tracing-subscriber", +] + [[package]] name = "lock_api" version = "0.4.10" @@ -500,6 +538,15 @@ dependencies = [ "url", ] +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "memchr" version = "2.5.0" @@ -550,6 +597,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -619,6 +676,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot_core" version = "0.9.8" @@ -727,6 +790,50 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "regex" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.3.6", + "regex-syntax 0.7.4", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + +[[package]] +name = "regex-automata" +version = "0.3.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.7.4", +] + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" + [[package]] name = "reqwest" version = "0.11.18" @@ -764,6 +871,16 @@ dependencies = [ "winreg", ] +[[package]] +name = "ropey" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53ce7a2c43a32e50d666e33c5a80251b31147bb4b49024bcab11fb6f20c671ed" +dependencies = [ + "smallvec", + "str_indices", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -881,6 +998,15 @@ dependencies = [ "serde", ] +[[package]] +name = "sharded-slab" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +dependencies = [ + "lazy_static", +] + [[package]] name = "slab" version = "0.4.8" @@ -906,6 +1032,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "str_indices" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f026164926842ec52deb1938fae44f83dfdb82d0a5b0270c5bd5935ab74d6dd" + [[package]] name = "syn" version = "1.0.109" @@ -941,6 +1073,44 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "thread_local" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "time" +version = 
"0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fdd63d58b18d663fbdf70e049f00a22c8e42be082203be7f26589213cd75ea" +dependencies = [ + "deranged", + "itoa", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" + +[[package]] +name = "time-macros" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb71511c991639bb078fd5bf97757e03914361c48100d52878b8e52b46fb92cd" +dependencies = [ + "time-core", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -1031,9 +1201,9 @@ checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" [[package]] name = "tower-lsp" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b38fb0e6ce037835174256518aace3ca621c4f96383c56bb846cfc11b341910" +checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508" dependencies = [ "async-trait", "auto_impl", @@ -1054,13 +1224,13 @@ dependencies = [ [[package]] name = "tower-lsp-macros" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34723c06344244474fdde365b76aebef8050bf6be61a935b91ee9ff7c4e91157" +checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.28", ] [[package]] @@ -1081,6 +1251,17 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e" +dependencies = [ + "crossbeam-channel", + "time", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.26" @@ 
-1099,6 +1280,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", + "tracing-serde", ] [[package]] @@ -1140,6 +1364,12 @@ dependencies = [ "serde", ] +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index ae2eaeb..ceeb119 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "ccserver" +name = "llm-ls" version = "0.1.0" edition = "2021" @@ -7,8 +7,12 @@ edition = "2021" [dependencies] home = "0.5" -serde = { version="1", features = ["derive"] } +ropey = "1.6" +serde = { version = "1", features = ["derive"] } reqwest = { version = "0.11", features = ["json"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] } -tower-lsp = 
"0.19" +tower-lsp = "0.20" +tracing = "0.1" +tracing-appender = "0.2" +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } diff --git a/README.md b/README.md index 7071508..2c73da3 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,13 @@ -# ccserver +# llm-ls > [!IMPORTANT] -> This is currently a work in progress. +> This is currently a work in progress, expect things to be broken! -**ccserver** is a LSP server for ML code completion (and more?). +**llm-ls** is a LSP server leveraging LLMs for code completion (and more?). -## Developing +## Compatible extensions -Clone/fork this repo and run `cargo build [--release]`. +- [x] [llm.nvim](https://github.com/huggingface/llm.nvim) +- [ ] [huggingface-vscode](https://github.com/huggingface/huggingface-vscode) +- [ ] [jupytercoder](https://github.com/bigcode-project/jupytercoder) -Then add the following code to your lua config: - -```lua -local client_id = vim.lsp.start({ - name = "ccserver", - cmd = { "/path/to/ccserver/target/{debug|release}/ccserver" }, - root_dir = vim.fs.dirname(vim.fs.find({ ".git" }, { upward = true })[1]), -}) - -if client_id == nil then - vim.notify("[ccserver] Error starting server", vim.log.levels.ERROR) -else - local augroup = "ccserver" - - vim.api.nvim_create_augroup(augroup, { clear = true }) - - vim.api.nvim_create_autocmd("BufEnter", { - pattern = "*", - callback = function(ev) - if not vim.lsp.buf_is_attached(ev.buf, client_id) then - vim.lsp.buf_attach_client(ev.buf, client_id) - end - end, - }) -end -``` diff --git a/src/language_comments.rs b/src/language_comments.rs new file mode 100644 index 0000000..41c521a --- /dev/null +++ b/src/language_comments.rs @@ -0,0 +1,413 @@ +use crate::LanguageComment; +use std::collections::HashMap; + +pub fn build_language_comments() -> HashMap { + HashMap::from([ + ( + "abap".to_owned(), + LanguageComment { + open: "*".to_owned(), + close: None, + }, + ), + ( + "bat".to_owned(), + LanguageComment { + open: "REM".to_owned(), + 
close: None, + }, + ), + ( + "bibtex".to_owned(), + LanguageComment { + open: "%".to_owned(), + close: None, + }, + ), + ( + "clojure".to_owned(), + LanguageComment { + open: ";;".to_owned(), + close: None, + }, + ), + ( + "coffeescript".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "c".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "cpp".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "csharp".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "css".to_owned(), + LanguageComment { + open: "/*".to_owned(), + close: Some("*/".to_owned()), + }, + ), + ( + "diff".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "dart".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "dockerfile".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "elixir".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "erlang".to_owned(), + LanguageComment { + open: "%".to_owned(), + close: None, + }, + ), + ( + "fsharp".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "git-commit".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "git-rebase".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "go".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "groovy".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "handlebars".to_owned(), + LanguageComment { + open: "{{!--".to_owned(), + close: Some("--}}".to_owned()), + }, + ), + ( + "html".to_owned(), + LanguageComment { + open: "<!--".to_owned(), + close: Some("-->".to_owned()), + }, + ), + ( + "ini".to_owned(), + LanguageComment { + open: ";".to_owned(), + close: None, + }, + ), + ( + 
"java".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "javascript".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "javascriptreact".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "json".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "latex".to_owned(), + LanguageComment { + open: "%".to_owned(), + close: None, + }, + ), + ( + "less".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "lua".to_owned(), + LanguageComment { + open: "--".to_owned(), + close: None, + }, + ), + ( + "makefile".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "markdown".to_owned(), + LanguageComment { + open: "<!--".to_owned(), + close: Some("-->".to_owned()), + }, + ), + ( + "objective-c".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "objective-cpp".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "perl".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "perl6".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "php".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "powershell".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "jade".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "python".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "r".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "razor".to_owned(), + LanguageComment { + open: "@*".to_owned(), + close: Some("*@".to_owned()), + }, + ), + ( + "ruby".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "rust".to_owned(), + LanguageComment { 
+ open: 
"//".to_owned(), + close: None, + }, + ), + ( + "scss".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "sass".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "scala".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "shaderlab".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "shellscript".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "sql".to_owned(), + LanguageComment { + open: "--".to_owned(), + close: None, + }, + ), + ( + "swift".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "toml".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ( + "typescript".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "typescriptreact".to_owned(), + LanguageComment { + open: "//".to_owned(), + close: None, + }, + ), + ( + "tex".to_owned(), + LanguageComment { + open: "%".to_owned(), + close: None, + }, + ), + ( + "vb".to_owned(), + LanguageComment { + open: "'".to_owned(), + close: None, + }, + ), + ( + "xml".to_owned(), + LanguageComment { + open: "<!--".to_owned(), + close: Some("-->".to_owned()), + }, + ), + ( + "xsl".to_owned(), + LanguageComment { + open: "<!--".to_owned(), + close: Some("-->".to_owned()), + }, + ), + ( + "yaml".to_owned(), + LanguageComment { + open: "#".to_owned(), + close: None, + }, + ), + ]) +} diff --git a/src/main.rs b/src/main.rs index 2de6821..3841ab0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,27 @@ +mod language_comments; + +use language_comments::build_language_comments; +use reqwest::header::AUTHORIZATION; +use ropey::Rope; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fmt::Display; -use std::path::PathBuf; -use tokio::io::AsyncWriteExt; +use std::sync::Arc; +use tokio::sync::RwLock; use 
tower_lsp::jsonrpc::{Error, Result}; use tower_lsp::lsp_types::*; use 
tower_lsp::{Client, LanguageServer, LspService, Server}; +use tracing::{error, info}; +use tracing_appender::rolling; +use tracing_subscriber::EnvFilter; -#[derive(Serialize)] +#[derive(Clone, Debug, Deserialize)] +pub struct LanguageComment { + open: String, + close: Option<String>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] struct RequestParams { max_new_tokens: u32, temperature: f32, @@ -15,76 +30,236 @@ struct RequestParams { stop_token: String, } +#[derive(Debug, Deserialize, Serialize)] +struct FimParams { + enabled: bool, + prefix: String, + middle: String, + suffix: String, +} + #[derive(Serialize)] struct APIRequest { inputs: String, parameters: RequestParams, } -#[derive(Deserialize)] -struct APIResponse { +#[derive(Debug, Deserialize)] +struct Generation { generated_text: String, } +#[derive(Debug, Deserialize)] +struct APIError { + error: String, +} + +impl Display for APIError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.error) + } +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum APIResponse { + Generations(Vec<Generation>), + Error(APIError), +} + +#[derive(Debug)] +struct Document { + language_id: String, + text: Rope, +} + +impl Document { + fn new(language_id: String, text: Rope) -> Self { + Self { language_id, text } + } +} + #[derive(Debug)] struct Backend { client: Client, + document_map: Arc<RwLock<HashMap<String, Document>>>, http_client: reqwest::Client, + workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>, + language_comments: HashMap<String, LanguageComment>, +} + +#[derive(Debug, Deserialize, Serialize)] +struct Completion { + generated_text: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct CompletionParams { + #[serde(flatten)] + text_document_position: TextDocumentPositionParams, + request_params: RequestParams, + fim: FimParams, + api_token: Option<String>, + model: String, } fn internal_error<E: Display>(err: E) -> Error { + let err_msg = 
err.to_string(); + error!(err_msg); Error { code: tower_lsp::jsonrpc::ErrorCode::InternalError, - message: err.to_string(), + 
message: err_msg.into(), data: None, } } -fn get_cache_dir_path() -> Result<PathBuf> { - let home_dir = home::home_dir().ok_or(internal_error("Failed to find home dir"))?; - Ok(home_dir.join(".cache/ccserver")) +fn file_path_comment( + file_url: Url, + file_language_id: String, + workspace_folders: Option<&Vec<WorkspaceFolder>>, + language_comments: &HashMap<String, LanguageComment>, +) -> String { + let mut file_path = file_url.path().to_owned(); + let path_in_workspace = if let Some(workspace_folders) = workspace_folders { + for workspace_folder in workspace_folders { + let workspace_folder_path = workspace_folder.uri.path(); + if file_path.starts_with(workspace_folder_path) { + file_path = file_path.replace(workspace_folder_path, ""); + break; + } + } + file_path + } else { + file_path + }; + let lc = match language_comments.get(&file_language_id) { + Some(id) => id.clone(), + None => LanguageComment { + open: "//".to_owned(), + close: None, + }, + }; + let mut commented_path = lc.open; + commented_path.push(' '); + commented_path.push_str(&path_in_workspace); + if let Some(close) = lc.close { + commented_path.push(' '); + commented_path.push_str(&close); + } + commented_path.push('\n'); + commented_path } -async fn request_completion(http_client: &reqwest::Client) -> Result<Vec<APIResponse>> { - http_client - .post("https://api-inference.huggingface.co/models/bigcode/starcoder") - .json(&APIRequest { - inputs: "Hello my name is ".to_owned(), - parameters: RequestParams { - max_new_tokens: 60, - temperature: 0.2, - do_sample: true, - top_p: 0.95, - stop_token: "\n".to_owned(), - }, +fn build_prompt(pos: Position, text: &Rope, fim: &FimParams, file_path: String) -> Result<String> { + let mut prompt = file_path; + let cursor_offset = text + .try_line_to_char(pos.line as usize) + .map_err(internal_error)? 
+ + pos.character as usize; + let text_len = text.len_chars(); + // XXX: not sure this is useful, rather be safe than sorry + let cursor_offset = if cursor_offset > text_len { + text_len + } else { + cursor_offset + }; + if fim.enabled { + prompt.push_str(&fim.prefix); + prompt.push_str(&text.slice(0..cursor_offset).to_string()); + prompt.push_str(&fim.suffix); + prompt.push_str(&text.slice(cursor_offset..text_len).to_string()); + prompt.push_str(&fim.middle); + Ok(prompt) + } else { + prompt.push_str(&text.slice(0..cursor_offset).to_string()); + Ok(prompt) + } +} + +async fn request_completion( + http_client: &reqwest::Client, + model: &str, + request_params: RequestParams, + api_token: Option<String>, + prompt: String, +) -> Result<Vec<Generation>> { + let mut req = http_client.post(model).json(&APIRequest { + inputs: prompt, + parameters: request_params, + }); + + if let Some(api_token) = api_token.clone() { + req = req.header(AUTHORIZATION, format!("Bearer {api_token}")) + } + + let res = req.send().await.map_err(internal_error)?; + + match res.json().await.map_err(internal_error)? { + APIResponse::Generations(gens) => Ok(gens), + APIResponse::Error(err) => Err(internal_error(err)), + } +} + +fn parse_generations( + generations: Vec<Generation>, + prompt: &str, + stop_token: &str, +) -> Vec<Completion> { + generations + .into_iter() + .map(|g| Completion { + generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""), }) - .send() - .await - .map_err(internal_error)? - .json() - .await - .map_err(internal_error)? 
+ .collect() +} + +impl Backend { + async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> { + info!("get_completions {params:?}"); + let document_map = self.document_map.read().await; + + let document = document_map + .get(params.text_document_position.text_document.uri.as_str()) + .ok_or_else(|| internal_error("failed to find document"))?; + let file_path = file_path_comment( + params.text_document_position.text_document.uri, + document.language_id.clone(), + self.workspace_folders.read().await.as_ref(), + &self.language_comments, + ); + let prompt = build_prompt( + params.text_document_position.position, + &document.text, + &params.fim, + file_path, + )?; + let stop_token = params.request_params.stop_token.clone(); + let result = request_completion( + &self.http_client, + &params.model, + params.request_params, + params.api_token, + prompt.clone(), + ) + .await?; + + Ok(parse_generations(result, &prompt, &stop_token)) + } } #[tower_lsp::async_trait] impl LanguageServer for Backend { - async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> { - tokio::fs::create_dir_all(get_cache_dir_path()?) 
- .await - .map_err(internal_error)?; + async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> { + *self.workspace_folders.write().await = params.workspace_folders; Ok(InitializeResult { + server_info: Some(ServerInfo { + name: "llm-ls".to_owned(), + version: Some("0.1.0".to_owned()), + }), capabilities: ServerCapabilities { - completion_provider: Some(CompletionOptions { - resolve_provider: Some(false), - trigger_characters: Some(vec![ - ".".to_owned(), - "(".to_owned(), - "{".to_owned(), - ":".to_owned(), - ":".to_owned(), - ]), - ..Default::default() - }), + text_document_sync: Some(TextDocumentSyncCapability::Kind( + TextDocumentSyncKind::FULL, + )), ..Default::default() }, ..Default::default() @@ -93,59 +268,62 @@ impl LanguageServer for Backend { async fn initialized(&self, _: InitializedParams) { self.client - .log_message(MessageType::INFO, "{ccserver} initialized") + .log_message(MessageType::INFO, "{llm-ls} initialized") .await; - if let Ok(cache_dir) = get_cache_dir_path() { - tokio::fs::OpenOptions::new() - .create(true) - .append(true) - .open(cache_dir.join("ccserver.log")) - .await - .unwrap() - .write_all(b"initialized\n") - .await - .unwrap(); - } + let _ = info!("initialized"); } - // XXX: tbd if we use code action or completion - async fn completion(&self, _: CompletionParams) -> Result<Option<CompletionResponse>> { - let result = request_completion(&self.http_client).await?; - if result.len() > 0 { - let generated_text = result[0].generated_text.clone(); + // TODO: + // textDocument/didClose - tokio::fs::OpenOptions::new() - .create(true) - .append(true) - .open(get_cache_dir_path()?.join("ccserver.log")) - .await - .unwrap() - .write_all(format!("completion request: {generated_text}\n").as_bytes()) - .await - .unwrap(); + async fn did_open(&self, params: DidOpenTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{llm-ls} file opened") + .await; + let rope = 
ropey::Rope::from_str(&params.text_document.text); + let uri = 
params.text_document.uri.to_string(); + *self + .document_map + .write() + .await + .entry(uri.clone()) + .or_insert(Document::new("unknown".to_owned(), Rope::new())) = + Document::new(params.text_document.language_id, rope); + info!("{uri} opened"); + } - Ok(Some(CompletionResponse::Array(vec![CompletionItem { - label: "ccserver completion".to_owned(), - insert_text: Some(generated_text.clone()), - kind: Some(CompletionItemKind::TEXT), - detail: Some(generated_text), - ..Default::default() - }]))) - } else { - Ok(None) - } + async fn did_change(&self, params: DidChangeTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{llm-ls} file changed") + .await; + let rope = ropey::Rope::from_str(&params.content_changes[0].text); + let uri = params.text_document.uri.to_string(); + let mut document_map = self.document_map.write().await; + let doc = document_map + .entry(uri.clone()) + .or_insert(Document::new("unknown".to_owned(), Rope::new())); + doc.text = rope; + info!("{uri} changed"); + } + + async fn did_save(&self, params: DidSaveTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{llm-ls} file saved") + .await; + let uri = params.text_document.uri.to_string(); + info!("{uri} saved"); + } + + async fn did_close(&self, params: DidCloseTextDocumentParams) { + self.client + .log_message(MessageType::INFO, "{llm-ls} file closed") + .await; + let uri = params.text_document.uri.to_string(); + info!("{uri} closed"); } async fn shutdown(&self) -> Result<()> { - tokio::fs::OpenOptions::new() - .create(true) - .append(true) - .open(get_cache_dir_path()?.join("ccserver.log")) - .await - .unwrap() - .write_all(b"shutdown\n") - .await - .unwrap(); + let _ = info!("shutdown"); Ok(()) } } @@ -155,11 +333,39 @@ async fn main() { let stdin = tokio::io::stdin(); let stdout = tokio::io::stdout(); + let home_dir = home::home_dir().ok_or(()).expect("failed to find home dir"); + let cache_dir = home_dir.join(".cache/llm_ls"); + 
tokio::fs::create_dir_all(&cache_dir) + .await + .expect("failed to create cache dir"); + + let log_file = rolling::never(cache_dir, "llm-ls.log"); + let builder = tracing_subscriber::fmt() + .with_writer(log_file) + .with_target(true) + .with_line_number(true) + .with_env_filter( + EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")), + ); + + builder + .json() + .flatten_event(true) + .with_current_span(false) + .with_span_list(true) + .init(); + let http_client = reqwest::Client::new(); - let (service, socket) = LspService::new(|client| Backend { + let (service, socket) = LspService::build(|client| Backend { client, + document_map: Arc::new(RwLock::new(HashMap::new())), http_client, - }); + workspace_folders: Arc::new(RwLock::new(None)), + language_comments: build_language_comments(), + }) + .custom_method("llm-ls/getCompletions", Backend::get_completions) + .finish(); + Server::new(stdin, stdout, socket).serve(service).await; }