feat: code completion (#2)

This commit is contained in:
Luc Georges 2023-08-24 17:46:26 +02:00 committed by GitHub
parent d46a90c309
commit b3b7bb2b4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 968 additions and 139 deletions

262
Cargo.lock generated
View file

@ -17,6 +17,15 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b8f9420f797f2d9e935edf629310eb938a0d839f984e25327f3c7eed22300c"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.72" version = "0.1.72"
@ -100,17 +109,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "ccserver"
version = "0.1.0"
dependencies = [
"home",
"reqwest",
"serde",
"tokio",
"tower-lsp",
]
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "1.0.0" version = "1.0.0"
@ -133,6 +131,25 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
dependencies = [
"cfg-if",
]
[[package]] [[package]]
name = "dashmap" name = "dashmap"
version = "5.5.0" version = "5.5.0"
@ -146,6 +163,12 @@ dependencies = [
"parking_lot_core", "parking_lot_core",
] ]
[[package]]
name = "deranged"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7684a49fb1af197853ef7b2ee694bc1f5b4179556f1e5710e1760c5db6f5e929"
[[package]] [[package]]
name = "encoding_rs" name = "encoding_rs"
version = "0.8.32" version = "0.8.32"
@ -471,6 +494,21 @@ version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503"
[[package]]
name = "llm-ls"
version = "0.1.0"
dependencies = [
"home",
"reqwest",
"ropey",
"serde",
"tokio",
"tower-lsp",
"tracing",
"tracing-appender",
"tracing-subscriber",
]
[[package]] [[package]]
name = "lock_api" name = "lock_api"
version = "0.4.10" version = "0.4.10"
@ -500,6 +538,15 @@ dependencies = [
"url", "url",
] ]
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.5.0" version = "2.5.0"
@ -550,6 +597,16 @@ dependencies = [
"tempfile", "tempfile",
] ]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]] [[package]]
name = "num_cpus" name = "num_cpus"
version = "1.16.0" version = "1.16.0"
@ -619,6 +676,12 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]] [[package]]
name = "parking_lot_core" name = "parking_lot_core"
version = "0.9.8" version = "0.9.8"
@ -727,6 +790,50 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
] ]
[[package]]
name = "regex"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.3.6",
"regex-syntax 0.7.4",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
name = "regex-automata"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.7.4",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.11.18" version = "0.11.18"
@ -764,6 +871,16 @@ dependencies = [
"winreg", "winreg",
] ]
[[package]]
name = "ropey"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53ce7a2c43a32e50d666e33c5a80251b31147bb4b49024bcab11fb6f20c671ed"
dependencies = [
"smallvec",
"str_indices",
]
[[package]] [[package]]
name = "rustc-demangle" name = "rustc-demangle"
version = "0.1.23" version = "0.1.23"
@ -881,6 +998,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sharded-slab"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31"
dependencies = [
"lazy_static",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.8" version = "0.4.8"
@ -906,6 +1032,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "str_indices"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f026164926842ec52deb1938fae44f83dfdb82d0a5b0270c5bd5935ab74d6dd"
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.109" version = "1.0.109"
@ -941,6 +1073,44 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "thread_local"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
name = "time"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fdd63d58b18d663fbdf70e049f00a22c8e42be082203be7f26589213cd75ea"
dependencies = [
"deranged",
"itoa",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
[[package]]
name = "time-macros"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb71511c991639bb078fd5bf97757e03914361c48100d52878b8e52b46fb92cd"
dependencies = [
"time-core",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.6.0" version = "1.6.0"
@ -1031,9 +1201,9 @@ checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]] [[package]]
name = "tower-lsp" name = "tower-lsp"
version = "0.19.0" version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b38fb0e6ce037835174256518aace3ca621c4f96383c56bb846cfc11b341910" checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"auto_impl", "auto_impl",
@ -1054,13 +1224,13 @@ dependencies = [
[[package]] [[package]]
name = "tower-lsp-macros" name = "tower-lsp-macros"
version = "0.8.0" version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34723c06344244474fdde365b76aebef8050bf6be61a935b91ee9ff7c4e91157" checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 1.0.109", "syn 2.0.28",
] ]
[[package]] [[package]]
@ -1081,6 +1251,17 @@ dependencies = [
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-appender"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e"
dependencies = [
"crossbeam-channel",
"time",
"tracing-subscriber",
]
[[package]] [[package]]
name = "tracing-attributes" name = "tracing-attributes"
version = "0.1.26" version = "0.1.26"
@ -1099,6 +1280,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922"
dependencies = [
"lazy_static",
"log",
"tracing-core",
]
[[package]]
name = "tracing-serde"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1"
dependencies = [
"serde",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"serde",
"serde_json",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
"tracing-serde",
] ]
[[package]] [[package]]
@ -1140,6 +1364,12 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]] [[package]]
name = "vcpkg" name = "vcpkg"
version = "0.2.15" version = "0.2.15"

View file

@ -1,5 +1,5 @@
[package] [package]
name = "ccserver" name = "llm-ls"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@ -7,8 +7,12 @@ edition = "2021"
[dependencies] [dependencies]
home = "0.5" home = "0.5"
serde = { version="1", features = ["derive"] } ropey = "1.6"
serde = { version = "1", features = ["derive"] }
reqwest = { version = "0.11", features = ["json"] } reqwest = { version = "0.11", features = ["json"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] }
tower-lsp = "0.19" tower-lsp = "0.20"
tracing = "0.1"
tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }

View file

@ -1,37 +1,13 @@
# ccserver # llm-ls
> [!IMPORTANT] > [!IMPORTANT]
> This is currently a work in progress. > This is currently a work in progress, expect things to be broken!
**ccserver** is a LSP server for ML code completion (and more?). **llm-ls** is a LSP server leveraging LLMs for code completion (and more?).
## Developing ## Compatible extensions
Clone/fork this repo and run `cargo build [--release]`. - [x] [llm.nvim](https://github.com/huggingface/llm.nvim)
- [ ] [huggingface-vscode](https://github.com/huggingface/huggingface-vscode)
- [ ] [jupytercoder](https://github.com/bigcode-project/jupytercoder)
Then add the following code to your lua config:
```lua
local client_id = vim.lsp.start({
name = "ccserver",
cmd = { "/path/to/ccserver/target/{debug|release}/ccserver" },
root_dir = vim.fs.dirname(vim.fs.find({ ".git" }, { upward = true })[1]),
})
if client_id == nil then
vim.notify("[ccserver] Error starting server", vim.log.levels.ERROR)
else
local augroup = "ccserver"
vim.api.nvim_create_augroup(augroup, { clear = true })
vim.api.nvim_create_autocmd("BufEnter", {
pattern = "*",
callback = function(ev)
if not vim.lsp.buf_is_attached(ev.buf, client_id) then
vim.lsp.buf_attach_client(ev.buf, client_id)
end
end,
})
end
```

413
src/language_comments.rs Normal file
View file

@ -0,0 +1,413 @@
use crate::LanguageComment;
use std::collections::HashMap;
pub fn build_language_comments() -> HashMap<String, LanguageComment> {
HashMap::from([
(
"abap".to_owned(),
LanguageComment {
open: "*".to_owned(),
close: None,
},
),
(
"bat".to_owned(),
LanguageComment {
open: "REM".to_owned(),
close: None,
},
),
(
"bibtex".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"clojure".to_owned(),
LanguageComment {
open: ";;".to_owned(),
close: None,
},
),
(
"coffeescript".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"c".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"cpp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"csharp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"css".to_owned(),
LanguageComment {
open: "/*".to_owned(),
close: Some("*/".to_owned()),
},
),
(
"diff".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"dart".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"dockerfile".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"elixir".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"erlang".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"fsharp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"git-commit".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"git-rebase".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"go".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"groovy".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"handlebars".to_owned(),
LanguageComment {
open: "{{!--".to_owned(),
close: Some("--}}".to_owned()),
},
),
(
"html".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"ini".to_owned(),
LanguageComment {
open: ";".to_owned(),
close: None,
},
),
(
"java".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"javascript".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"javascriptreact".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"json".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"latex".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"less".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"lua".to_owned(),
LanguageComment {
open: "--".to_owned(),
close: None,
},
),
(
"makefile".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"markdown".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"objective-c".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"objective-cpp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"perl".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"perl6".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"php".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"powershell".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"jade".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"python".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"r".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"razor".to_owned(),
LanguageComment {
open: "@*".to_owned(),
close: Some("*@".to_owned()),
},
),
(
"ruby".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"rust".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"scss".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"sass".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"scala".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"shaderlab".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"shellscript".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"sql".to_owned(),
LanguageComment {
open: "--".to_owned(),
close: None,
},
),
(
"swift".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"toml".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"typescript".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"typescriptreact".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"tex".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"vb".to_owned(),
LanguageComment {
open: "'".to_owned(),
close: None,
},
),
(
"xml".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"xsl".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"yaml".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
])
}

View file

@ -1,12 +1,27 @@
mod language_comments;
use language_comments::build_language_comments;
use reqwest::header::AUTHORIZATION;
use ropey::Rope;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use std::path::PathBuf; use std::sync::Arc;
use tokio::io::AsyncWriteExt; use tokio::sync::RwLock;
use tower_lsp::jsonrpc::{Error, Result}; use tower_lsp::jsonrpc::{Error, Result};
use tower_lsp::lsp_types::*; use tower_lsp::lsp_types::*;
use tower_lsp::{Client, LanguageServer, LspService, Server}; use tower_lsp::{Client, LanguageServer, LspService, Server};
use tracing::{error, info};
use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
#[derive(Serialize)] #[derive(Clone, Debug, Deserialize)]
pub struct LanguageComment {
open: String,
close: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
struct RequestParams { struct RequestParams {
max_new_tokens: u32, max_new_tokens: u32,
temperature: f32, temperature: f32,
@ -15,76 +30,236 @@ struct RequestParams {
stop_token: String, stop_token: String,
} }
#[derive(Debug, Deserialize, Serialize)]
struct FimParams {
enabled: bool,
prefix: String,
middle: String,
suffix: String,
}
#[derive(Serialize)] #[derive(Serialize)]
struct APIRequest { struct APIRequest {
inputs: String, inputs: String,
parameters: RequestParams, parameters: RequestParams,
} }
#[derive(Deserialize)] #[derive(Debug, Deserialize)]
struct APIResponse { struct Generation {
generated_text: String, generated_text: String,
} }
#[derive(Debug, Deserialize)]
struct APIError {
error: String,
}
impl Display for APIError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error)
}
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum APIResponse {
Generations(Vec<Generation>),
Error(APIError),
}
#[derive(Debug)]
struct Document {
language_id: String,
text: Rope,
}
impl Document {
fn new(language_id: String, text: Rope) -> Self {
Self { language_id, text }
}
}
#[derive(Debug)] #[derive(Debug)]
struct Backend { struct Backend {
client: Client, client: Client,
document_map: Arc<RwLock<HashMap<String, Document>>>,
http_client: reqwest::Client, http_client: reqwest::Client,
workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
language_comments: HashMap<String, LanguageComment>,
}
#[derive(Debug, Deserialize, Serialize)]
struct Completion {
generated_text: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct CompletionParams {
#[serde(flatten)]
text_document_position: TextDocumentPositionParams,
request_params: RequestParams,
fim: FimParams,
api_token: Option<String>,
model: String,
} }
fn internal_error<E: Display>(err: E) -> Error { fn internal_error<E: Display>(err: E) -> Error {
let err_msg = err.to_string();
error!(err_msg);
Error { Error {
code: tower_lsp::jsonrpc::ErrorCode::InternalError, code: tower_lsp::jsonrpc::ErrorCode::InternalError,
message: err.to_string(), message: err_msg.into(),
data: None, data: None,
} }
} }
fn get_cache_dir_path() -> Result<PathBuf> { fn file_path_comment(
let home_dir = home::home_dir().ok_or(internal_error("Failed to find home dir"))?; file_url: Url,
Ok(home_dir.join(".cache/ccserver")) file_language_id: String,
workspace_folders: Option<&Vec<WorkspaceFolder>>,
language_comments: &HashMap<String, LanguageComment>,
) -> String {
let mut file_path = file_url.path().to_owned();
let path_in_workspace = if let Some(workspace_folders) = workspace_folders {
for workspace_folder in workspace_folders {
let workspace_folder_path = workspace_folder.uri.path();
if file_path.starts_with(workspace_folder_path) {
file_path = file_path.replace(workspace_folder_path, "");
break;
}
}
file_path
} else {
file_path
};
let lc = match language_comments.get(&file_language_id) {
Some(id) => id.clone(),
None => LanguageComment {
open: "//".to_owned(),
close: None,
},
};
let mut commented_path = lc.open;
commented_path.push(' ');
commented_path.push_str(&path_in_workspace);
if let Some(close) = lc.close {
commented_path.push(' ');
commented_path.push_str(&close);
}
commented_path.push('\n');
commented_path
} }
async fn request_completion(http_client: &reqwest::Client) -> Result<Vec<APIResponse>> { fn build_prompt(pos: Position, text: &Rope, fim: &FimParams, file_path: String) -> Result<String> {
http_client let mut prompt = file_path;
.post("https://api-inference.huggingface.co/models/bigcode/starcoder") let cursor_offset = text
.json(&APIRequest { .try_line_to_char(pos.line as usize)
inputs: "Hello my name is ".to_owned(), .map_err(internal_error)?
parameters: RequestParams { + pos.character as usize;
max_new_tokens: 60, let text_len = text.len_chars();
temperature: 0.2, // XXX: not sure this is useful, rather be safe than sorry
do_sample: true, let cursor_offset = if cursor_offset > text_len {
top_p: 0.95, text_len
stop_token: "\n".to_owned(), } else {
}, cursor_offset
};
if fim.enabled {
prompt.push_str(&fim.prefix);
prompt.push_str(&text.slice(0..cursor_offset).to_string());
prompt.push_str(&fim.suffix);
prompt.push_str(&text.slice(cursor_offset..text_len).to_string());
prompt.push_str(&fim.middle);
Ok(prompt)
} else {
prompt.push_str(&text.slice(0..cursor_offset).to_string());
Ok(prompt)
}
}
async fn request_completion(
http_client: &reqwest::Client,
model: &str,
request_params: RequestParams,
api_token: Option<String>,
prompt: String,
) -> Result<Vec<Generation>> {
let mut req = http_client.post(model).json(&APIRequest {
inputs: prompt,
parameters: request_params,
});
if let Some(api_token) = api_token.clone() {
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
}
let res = req.send().await.map_err(internal_error)?;
match res.json().await.map_err(internal_error)? {
APIResponse::Generations(gens) => Ok(gens),
APIResponse::Error(err) => Err(internal_error(err)),
}
}
fn parse_generations(
generations: Vec<Generation>,
prompt: &str,
stop_token: &str,
) -> Vec<Completion> {
generations
.into_iter()
.map(|g| Completion {
generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""),
}) })
.send() .collect()
.await }
.map_err(internal_error)?
.json() impl Backend {
.await async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> {
.map_err(internal_error)? info!("get_completions {params:?}");
let document_map = self.document_map.read().await;
let document = document_map
.get(params.text_document_position.text_document.uri.as_str())
.ok_or_else(|| internal_error("failed to find document"))?;
let file_path = file_path_comment(
params.text_document_position.text_document.uri,
document.language_id.clone(),
self.workspace_folders.read().await.as_ref(),
&self.language_comments,
);
let prompt = build_prompt(
params.text_document_position.position,
&document.text,
&params.fim,
file_path,
)?;
let stop_token = params.request_params.stop_token.clone();
let result = request_completion(
&self.http_client,
&params.model,
params.request_params,
params.api_token,
prompt.clone(),
)
.await?;
Ok(parse_generations(result, &prompt, &stop_token))
}
} }
#[tower_lsp::async_trait] #[tower_lsp::async_trait]
impl LanguageServer for Backend { impl LanguageServer for Backend {
async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> { async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
tokio::fs::create_dir_all(get_cache_dir_path()?) *self.workspace_folders.write().await = params.workspace_folders;
.await
.map_err(internal_error)?;
Ok(InitializeResult { Ok(InitializeResult {
server_info: Some(ServerInfo {
name: "llm-ls".to_owned(),
version: Some("0.1.0".to_owned()),
}),
capabilities: ServerCapabilities { capabilities: ServerCapabilities {
completion_provider: Some(CompletionOptions { text_document_sync: Some(TextDocumentSyncCapability::Kind(
resolve_provider: Some(false), TextDocumentSyncKind::FULL,
trigger_characters: Some(vec![ )),
".".to_owned(),
"(".to_owned(),
"{".to_owned(),
":".to_owned(),
":".to_owned(),
]),
..Default::default()
}),
..Default::default() ..Default::default()
}, },
..Default::default() ..Default::default()
@ -93,59 +268,62 @@ impl LanguageServer for Backend {
async fn initialized(&self, _: InitializedParams) { async fn initialized(&self, _: InitializedParams) {
self.client self.client
.log_message(MessageType::INFO, "{ccserver} initialized") .log_message(MessageType::INFO, "{llm-ls} initialized")
.await; .await;
if let Ok(cache_dir) = get_cache_dir_path() { let _ = info!("initialized");
tokio::fs::OpenOptions::new()
.create(true)
.append(true)
.open(cache_dir.join("ccserver.log"))
.await
.unwrap()
.write_all(b"initialized\n")
.await
.unwrap();
}
} }
// XXX: tbd if we use code action or completion // TODO:
async fn completion(&self, _: CompletionParams) -> Result<Option<CompletionResponse>> { // textDocument/didClose
let result = request_completion(&self.http_client).await?;
if result.len() > 0 {
let generated_text = result[0].generated_text.clone();
tokio::fs::OpenOptions::new() async fn did_open(&self, params: DidOpenTextDocumentParams) {
.create(true) self.client
.append(true) .log_message(MessageType::INFO, "{llm-ls} file opened")
.open(get_cache_dir_path()?.join("ccserver.log")) .await;
.await let rope = ropey::Rope::from_str(&params.text_document.text);
.unwrap() let uri = params.text_document.uri.to_string();
.write_all(format!("completion request: {generated_text}\n").as_bytes()) *self
.await .document_map
.unwrap(); .write()
.await
.entry(uri.clone())
.or_insert(Document::new("unknown".to_owned(), Rope::new())) =
Document::new(params.text_document.language_id, rope);
info!("{uri} opened");
}
Ok(Some(CompletionResponse::Array(vec![CompletionItem { async fn did_change(&self, params: DidChangeTextDocumentParams) {
label: "ccserver completion".to_owned(), self.client
insert_text: Some(generated_text.clone()), .log_message(MessageType::INFO, "{llm-ls} file changed")
kind: Some(CompletionItemKind::TEXT), .await;
detail: Some(generated_text), let rope = ropey::Rope::from_str(&params.content_changes[0].text);
..Default::default() let uri = params.text_document.uri.to_string();
}]))) let mut document_map = self.document_map.write().await;
} else { let doc = document_map
Ok(None) .entry(uri.clone())
} .or_insert(Document::new("unknown".to_owned(), Rope::new()));
doc.text = rope;
info!("{uri} changed");
}
async fn did_save(&self, params: DidSaveTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file saved")
.await;
let uri = params.text_document.uri.to_string();
info!("{uri} saved");
}
async fn did_close(&self, params: DidCloseTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file closed")
.await;
let uri = params.text_document.uri.to_string();
info!("{uri} closed");
} }
async fn shutdown(&self) -> Result<()> { async fn shutdown(&self) -> Result<()> {
tokio::fs::OpenOptions::new() let _ = info!("shutdown");
.create(true)
.append(true)
.open(get_cache_dir_path()?.join("ccserver.log"))
.await
.unwrap()
.write_all(b"shutdown\n")
.await
.unwrap();
Ok(()) Ok(())
} }
} }
@ -155,11 +333,39 @@ async fn main() {
let stdin = tokio::io::stdin(); let stdin = tokio::io::stdin();
let stdout = tokio::io::stdout(); let stdout = tokio::io::stdout();
let home_dir = home::home_dir().ok_or(()).expect("failed to find home dir");
let cache_dir = home_dir.join(".cache/llm_ls");
tokio::fs::create_dir_all(&cache_dir)
.await
.expect("failed to create cache dir");
let log_file = rolling::never(cache_dir, "llm-ls.log");
let builder = tracing_subscriber::fmt()
.with_writer(log_file)
.with_target(true)
.with_line_number(true)
.with_env_filter(
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")),
);
builder
.json()
.flatten_event(true)
.with_current_span(false)
.with_span_list(true)
.init();
let http_client = reqwest::Client::new(); let http_client = reqwest::Client::new();
let (service, socket) = LspService::new(|client| Backend { let (service, socket) = LspService::build(|client| Backend {
client, client,
document_map: Arc::new(RwLock::new(HashMap::new())),
http_client, http_client,
}); workspace_folders: Arc::new(RwLock::new(None)),
language_comments: build_language_comments(),
})
.custom_method("llm-ls/getCompletions", Backend::get_completions)
.finish();
Server::new(stdin, stdout, socket).serve(service).await; Server::new(stdin, stdout, socket).serve(service).await;
} }