feat: code completion (#2)

This commit is contained in:
Luc Georges 2023-08-24 17:46:26 +02:00 committed by GitHub
parent d46a90c309
commit b3b7bb2b4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 968 additions and 139 deletions

262
Cargo.lock generated
View file

@ -17,6 +17,15 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b8f9420f797f2d9e935edf629310eb938a0d839f984e25327f3c7eed22300c"
dependencies = [
"memchr",
]
[[package]]
name = "async-trait"
version = "0.1.72"
@ -100,17 +109,6 @@ dependencies = [
"libc",
]
[[package]]
name = "ccserver"
version = "0.1.0"
dependencies = [
"home",
"reqwest",
"serde",
"tokio",
"tower-lsp",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
@ -133,6 +131,25 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
dependencies = [
"cfg-if",
]
[[package]]
name = "dashmap"
version = "5.5.0"
@ -146,6 +163,12 @@ dependencies = [
"parking_lot_core",
]
[[package]]
name = "deranged"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7684a49fb1af197853ef7b2ee694bc1f5b4179556f1e5710e1760c5db6f5e929"
[[package]]
name = "encoding_rs"
version = "0.8.32"
@ -471,6 +494,21 @@ version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503"
[[package]]
name = "llm-ls"
version = "0.1.0"
dependencies = [
"home",
"reqwest",
"ropey",
"serde",
"tokio",
"tower-lsp",
"tracing",
"tracing-appender",
"tracing-subscriber",
]
[[package]]
name = "lock_api"
version = "0.4.10"
@ -500,6 +538,15 @@ dependencies = [
"url",
]
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "memchr"
version = "2.5.0"
@ -550,6 +597,16 @@ dependencies = [
"tempfile",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "num_cpus"
version = "1.16.0"
@ -619,6 +676,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot_core"
version = "0.9.8"
@ -727,6 +790,50 @@ dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "regex"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.3.6",
"regex-syntax 0.7.4",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
]
[[package]]
name = "regex-automata"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.7.4",
]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2"
[[package]]
name = "reqwest"
version = "0.11.18"
@ -764,6 +871,16 @@ dependencies = [
"winreg",
]
[[package]]
name = "ropey"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53ce7a2c43a32e50d666e33c5a80251b31147bb4b49024bcab11fb6f20c671ed"
dependencies = [
"smallvec",
"str_indices",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
@ -881,6 +998,15 @@ dependencies = [
"serde",
]
[[package]]
name = "sharded-slab"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31"
dependencies = [
"lazy_static",
]
[[package]]
name = "slab"
version = "0.4.8"
@ -906,6 +1032,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "str_indices"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f026164926842ec52deb1938fae44f83dfdb82d0a5b0270c5bd5935ab74d6dd"
[[package]]
name = "syn"
version = "1.0.109"
@ -941,6 +1073,44 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "thread_local"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
name = "time"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fdd63d58b18d663fbdf70e049f00a22c8e42be082203be7f26589213cd75ea"
dependencies = [
"deranged",
"itoa",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
[[package]]
name = "time-macros"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb71511c991639bb078fd5bf97757e03914361c48100d52878b8e52b46fb92cd"
dependencies = [
"time-core",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
@ -1031,9 +1201,9 @@ checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
[[package]]
name = "tower-lsp"
version = "0.19.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b38fb0e6ce037835174256518aace3ca621c4f96383c56bb846cfc11b341910"
checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508"
dependencies = [
"async-trait",
"auto_impl",
@ -1054,13 +1224,13 @@ dependencies = [
[[package]]
name = "tower-lsp-macros"
version = "0.8.0"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34723c06344244474fdde365b76aebef8050bf6be61a935b91ee9ff7c4e91157"
checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
"syn 2.0.28",
]
[[package]]
@ -1081,6 +1251,17 @@ dependencies = [
"tracing-core",
]
[[package]]
name = "tracing-appender"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e"
dependencies = [
"crossbeam-channel",
"time",
"tracing-subscriber",
]
[[package]]
name = "tracing-attributes"
version = "0.1.26"
@ -1099,6 +1280,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922"
dependencies = [
"lazy_static",
"log",
"tracing-core",
]
[[package]]
name = "tracing-serde"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1"
dependencies = [
"serde",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"serde",
"serde_json",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
"tracing-serde",
]
[[package]]
@ -1140,6 +1364,12 @@ dependencies = [
"serde",
]
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "vcpkg"
version = "0.2.15"

View file

@ -1,5 +1,5 @@
[package]
name = "ccserver"
name = "llm-ls"
version = "0.1.0"
edition = "2021"
@ -7,8 +7,12 @@ edition = "2021"
[dependencies]
home = "0.5"
serde = { version="1", features = ["derive"] }
ropey = "1.6"
serde = { version = "1", features = ["derive"] }
reqwest = { version = "0.11", features = ["json"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] }
tower-lsp = "0.19"
tower-lsp = "0.20"
tracing = "0.1"
tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }

View file

@ -1,37 +1,13 @@
# ccserver
# llm-ls
> [!IMPORTANT]
> This is currently a work in progress.
> This is currently a work in progress, expect things to be broken!
**ccserver** is a LSP server for ML code completion (and more?).
**llm-ls** is a LSP server leveraging LLMs for code completion (and more?).
## Developing
## Compatible extensions
Clone/fork this repo and run `cargo build [--release]`.
- [x] [llm.nvim](https://github.com/huggingface/llm.nvim)
- [ ] [huggingface-vscode](https://github.com/huggingface/huggingface-vscode)
- [ ] [jupytercoder](https://github.com/bigcode-project/jupytercoder)
Then add the following code to your lua config:
```lua
local client_id = vim.lsp.start({
name = "ccserver",
cmd = { "/path/to/ccserver/target/{debug|release}/ccserver" },
root_dir = vim.fs.dirname(vim.fs.find({ ".git" }, { upward = true })[1]),
})
if client_id == nil then
vim.notify("[ccserver] Error starting server", vim.log.levels.ERROR)
else
local augroup = "ccserver"
vim.api.nvim_create_augroup(augroup, { clear = true })
vim.api.nvim_create_autocmd("BufEnter", {
pattern = "*",
callback = function(ev)
if not vim.lsp.buf_is_attached(ev.buf, client_id) then
vim.lsp.buf_attach_client(ev.buf, client_id)
end
end,
})
end
```

413
src/language_comments.rs Normal file
View file

@ -0,0 +1,413 @@
use crate::LanguageComment;
use std::collections::HashMap;
pub fn build_language_comments() -> HashMap<String, LanguageComment> {
HashMap::from([
(
"abap".to_owned(),
LanguageComment {
open: "*".to_owned(),
close: None,
},
),
(
"bat".to_owned(),
LanguageComment {
open: "REM".to_owned(),
close: None,
},
),
(
"bibtex".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"clojure".to_owned(),
LanguageComment {
open: ";;".to_owned(),
close: None,
},
),
(
"coffeescript".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"c".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"cpp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"csharp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"css".to_owned(),
LanguageComment {
open: "/*".to_owned(),
close: Some("*/".to_owned()),
},
),
(
"diff".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"dart".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"dockerfile".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"elixir".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"erlang".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"fsharp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"git-commit".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"git-rebase".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"go".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"groovy".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"handlebars".to_owned(),
LanguageComment {
open: "{{!--".to_owned(),
close: Some("--}}".to_owned()),
},
),
(
"html".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"ini".to_owned(),
LanguageComment {
open: ";".to_owned(),
close: None,
},
),
(
"java".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"javascript".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"javascriptreact".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"json".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"latex".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"less".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"lua".to_owned(),
LanguageComment {
open: "--".to_owned(),
close: None,
},
),
(
"makefile".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"markdown".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"objective-c".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"objective-cpp".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"perl".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"perl6".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"php".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"powershell".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"jade".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"python".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"r".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"razor".to_owned(),
LanguageComment {
open: "@*".to_owned(),
close: Some("*@".to_owned()),
},
),
(
"ruby".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"rust".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"scss".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"sass".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"scala".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"shaderlab".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"shellscript".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"sql".to_owned(),
LanguageComment {
open: "--".to_owned(),
close: None,
},
),
(
"swift".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"toml".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
(
"typescript".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"typescriptreact".to_owned(),
LanguageComment {
open: "//".to_owned(),
close: None,
},
),
(
"tex".to_owned(),
LanguageComment {
open: "%".to_owned(),
close: None,
},
),
(
"vb".to_owned(),
LanguageComment {
open: "'".to_owned(),
close: None,
},
),
(
"xml".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"xsl".to_owned(),
LanguageComment {
open: "<!--".to_owned(),
close: Some("-->".to_owned()),
},
),
(
"yaml".to_owned(),
LanguageComment {
open: "#".to_owned(),
close: None,
},
),
])
}

View file

@ -1,12 +1,27 @@
mod language_comments;
use language_comments::build_language_comments;
use reqwest::header::AUTHORIZATION;
use ropey::Rope;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use std::sync::Arc;
use tokio::sync::RwLock;
use tower_lsp::jsonrpc::{Error, Result};
use tower_lsp::lsp_types::*;
use tower_lsp::{Client, LanguageServer, LspService, Server};
use tracing::{error, info};
use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
#[derive(Serialize)]
#[derive(Clone, Debug, Deserialize)]
pub struct LanguageComment {
open: String,
close: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
struct RequestParams {
max_new_tokens: u32,
temperature: f32,
@ -15,76 +30,236 @@ struct RequestParams {
stop_token: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct FimParams {
enabled: bool,
prefix: String,
middle: String,
suffix: String,
}
#[derive(Serialize)]
struct APIRequest {
inputs: String,
parameters: RequestParams,
}
#[derive(Deserialize)]
struct APIResponse {
#[derive(Debug, Deserialize)]
struct Generation {
generated_text: String,
}
#[derive(Debug, Deserialize)]
struct APIError {
error: String,
}
impl Display for APIError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error)
}
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum APIResponse {
Generations(Vec<Generation>),
Error(APIError),
}
#[derive(Debug)]
struct Document {
language_id: String,
text: Rope,
}
impl Document {
fn new(language_id: String, text: Rope) -> Self {
Self { language_id, text }
}
}
#[derive(Debug)]
struct Backend {
client: Client,
document_map: Arc<RwLock<HashMap<String, Document>>>,
http_client: reqwest::Client,
workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
language_comments: HashMap<String, LanguageComment>,
}
#[derive(Debug, Deserialize, Serialize)]
struct Completion {
generated_text: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct CompletionParams {
#[serde(flatten)]
text_document_position: TextDocumentPositionParams,
request_params: RequestParams,
fim: FimParams,
api_token: Option<String>,
model: String,
}
fn internal_error<E: Display>(err: E) -> Error {
let err_msg = err.to_string();
error!(err_msg);
Error {
code: tower_lsp::jsonrpc::ErrorCode::InternalError,
message: err.to_string(),
message: err_msg.into(),
data: None,
}
}
fn get_cache_dir_path() -> Result<PathBuf> {
let home_dir = home::home_dir().ok_or(internal_error("Failed to find home dir"))?;
Ok(home_dir.join(".cache/ccserver"))
fn file_path_comment(
file_url: Url,
file_language_id: String,
workspace_folders: Option<&Vec<WorkspaceFolder>>,
language_comments: &HashMap<String, LanguageComment>,
) -> String {
let mut file_path = file_url.path().to_owned();
let path_in_workspace = if let Some(workspace_folders) = workspace_folders {
for workspace_folder in workspace_folders {
let workspace_folder_path = workspace_folder.uri.path();
if file_path.starts_with(workspace_folder_path) {
file_path = file_path.replace(workspace_folder_path, "");
break;
}
}
file_path
} else {
file_path
};
let lc = match language_comments.get(&file_language_id) {
Some(id) => id.clone(),
None => LanguageComment {
open: "//".to_owned(),
close: None,
},
};
let mut commented_path = lc.open;
commented_path.push(' ');
commented_path.push_str(&path_in_workspace);
if let Some(close) = lc.close {
commented_path.push(' ');
commented_path.push_str(&close);
}
commented_path.push('\n');
commented_path
}
async fn request_completion(http_client: &reqwest::Client) -> Result<Vec<APIResponse>> {
http_client
.post("https://api-inference.huggingface.co/models/bigcode/starcoder")
.json(&APIRequest {
inputs: "Hello my name is ".to_owned(),
parameters: RequestParams {
max_new_tokens: 60,
temperature: 0.2,
do_sample: true,
top_p: 0.95,
stop_token: "\n".to_owned(),
},
fn build_prompt(pos: Position, text: &Rope, fim: &FimParams, file_path: String) -> Result<String> {
let mut prompt = file_path;
let cursor_offset = text
.try_line_to_char(pos.line as usize)
.map_err(internal_error)?
+ pos.character as usize;
let text_len = text.len_chars();
// XXX: not sure this is useful, rather be safe than sorry
let cursor_offset = if cursor_offset > text_len {
text_len
} else {
cursor_offset
};
if fim.enabled {
prompt.push_str(&fim.prefix);
prompt.push_str(&text.slice(0..cursor_offset).to_string());
prompt.push_str(&fim.suffix);
prompt.push_str(&text.slice(cursor_offset..text_len).to_string());
prompt.push_str(&fim.middle);
Ok(prompt)
} else {
prompt.push_str(&text.slice(0..cursor_offset).to_string());
Ok(prompt)
}
}
async fn request_completion(
http_client: &reqwest::Client,
model: &str,
request_params: RequestParams,
api_token: Option<String>,
prompt: String,
) -> Result<Vec<Generation>> {
let mut req = http_client.post(model).json(&APIRequest {
inputs: prompt,
parameters: request_params,
});
if let Some(api_token) = api_token.clone() {
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
}
let res = req.send().await.map_err(internal_error)?;
match res.json().await.map_err(internal_error)? {
APIResponse::Generations(gens) => Ok(gens),
APIResponse::Error(err) => Err(internal_error(err)),
}
}
fn parse_generations(
generations: Vec<Generation>,
prompt: &str,
stop_token: &str,
) -> Vec<Completion> {
generations
.into_iter()
.map(|g| Completion {
generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""),
})
.send()
.await
.map_err(internal_error)?
.json()
.await
.map_err(internal_error)?
.collect()
}
impl Backend {
async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> {
info!("get_completions {params:?}");
let document_map = self.document_map.read().await;
let document = document_map
.get(params.text_document_position.text_document.uri.as_str())
.ok_or_else(|| internal_error("failed to find document"))?;
let file_path = file_path_comment(
params.text_document_position.text_document.uri,
document.language_id.clone(),
self.workspace_folders.read().await.as_ref(),
&self.language_comments,
);
let prompt = build_prompt(
params.text_document_position.position,
&document.text,
&params.fim,
file_path,
)?;
let stop_token = params.request_params.stop_token.clone();
let result = request_completion(
&self.http_client,
&params.model,
params.request_params,
params.api_token,
prompt.clone(),
)
.await?;
Ok(parse_generations(result, &prompt, &stop_token))
}
}
#[tower_lsp::async_trait]
impl LanguageServer for Backend {
async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
tokio::fs::create_dir_all(get_cache_dir_path()?)
.await
.map_err(internal_error)?;
async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
*self.workspace_folders.write().await = params.workspace_folders;
Ok(InitializeResult {
capabilities: ServerCapabilities {
completion_provider: Some(CompletionOptions {
resolve_provider: Some(false),
trigger_characters: Some(vec![
".".to_owned(),
"(".to_owned(),
"{".to_owned(),
":".to_owned(),
":".to_owned(),
]),
..Default::default()
server_info: Some(ServerInfo {
name: "llm-ls".to_owned(),
version: Some("0.1.0".to_owned()),
}),
capabilities: ServerCapabilities {
text_document_sync: Some(TextDocumentSyncCapability::Kind(
TextDocumentSyncKind::FULL,
)),
..Default::default()
},
..Default::default()
@ -93,59 +268,62 @@ impl LanguageServer for Backend {
async fn initialized(&self, _: InitializedParams) {
self.client
.log_message(MessageType::INFO, "{ccserver} initialized")
.log_message(MessageType::INFO, "{llm-ls} initialized")
.await;
if let Ok(cache_dir) = get_cache_dir_path() {
tokio::fs::OpenOptions::new()
.create(true)
.append(true)
.open(cache_dir.join("ccserver.log"))
.await
.unwrap()
.write_all(b"initialized\n")
.await
.unwrap();
}
let _ = info!("initialized");
}
// XXX: tbd if we use code action or completion
async fn completion(&self, _: CompletionParams) -> Result<Option<CompletionResponse>> {
let result = request_completion(&self.http_client).await?;
if result.len() > 0 {
let generated_text = result[0].generated_text.clone();
// TODO:
// textDocument/didClose
tokio::fs::OpenOptions::new()
.create(true)
.append(true)
.open(get_cache_dir_path()?.join("ccserver.log"))
async fn did_open(&self, params: DidOpenTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file opened")
.await;
let rope = ropey::Rope::from_str(&params.text_document.text);
let uri = params.text_document.uri.to_string();
*self
.document_map
.write()
.await
.unwrap()
.write_all(format!("completion request: {generated_text}\n").as_bytes())
.await
.unwrap();
Ok(Some(CompletionResponse::Array(vec![CompletionItem {
label: "ccserver completion".to_owned(),
insert_text: Some(generated_text.clone()),
kind: Some(CompletionItemKind::TEXT),
detail: Some(generated_text),
..Default::default()
}])))
} else {
Ok(None)
.entry(uri.clone())
.or_insert(Document::new("unknown".to_owned(), Rope::new())) =
Document::new(params.text_document.language_id, rope);
info!("{uri} opened");
}
async fn did_change(&self, params: DidChangeTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file changed")
.await;
let rope = ropey::Rope::from_str(&params.content_changes[0].text);
let uri = params.text_document.uri.to_string();
let mut document_map = self.document_map.write().await;
let doc = document_map
.entry(uri.clone())
.or_insert(Document::new("unknown".to_owned(), Rope::new()));
doc.text = rope;
info!("{uri} changed");
}
async fn did_save(&self, params: DidSaveTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file saved")
.await;
let uri = params.text_document.uri.to_string();
info!("{uri} saved");
}
async fn did_close(&self, params: DidCloseTextDocumentParams) {
self.client
.log_message(MessageType::INFO, "{llm-ls} file closed")
.await;
let uri = params.text_document.uri.to_string();
info!("{uri} closed");
}
async fn shutdown(&self) -> Result<()> {
tokio::fs::OpenOptions::new()
.create(true)
.append(true)
.open(get_cache_dir_path()?.join("ccserver.log"))
.await
.unwrap()
.write_all(b"shutdown\n")
.await
.unwrap();
let _ = info!("shutdown");
Ok(())
}
}
@ -155,11 +333,39 @@ async fn main() {
let stdin = tokio::io::stdin();
let stdout = tokio::io::stdout();
let home_dir = home::home_dir().ok_or(()).expect("failed to find home dir");
let cache_dir = home_dir.join(".cache/llm_ls");
tokio::fs::create_dir_all(&cache_dir)
.await
.expect("failed to create cache dir");
let log_file = rolling::never(cache_dir, "llm-ls.log");
let builder = tracing_subscriber::fmt()
.with_writer(log_file)
.with_target(true)
.with_line_number(true)
.with_env_filter(
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info")),
);
builder
.json()
.flatten_event(true)
.with_current_span(false)
.with_span_list(true)
.init();
let http_client = reqwest::Client::new();
let (service, socket) = LspService::new(|client| Backend {
let (service, socket) = LspService::build(|client| Backend {
client,
document_map: Arc::new(RwLock::new(HashMap::new())),
http_client,
});
workspace_folders: Arc::new(RwLock::new(None)),
language_comments: build_language_comments(),
})
.custom_method("llm-ls/getCompletions", Backend::get_completions)
.finish();
Server::new(stdin, stdout, socket).serve(service).await;
}