diff --git a/Cargo.lock b/Cargo.lock index 05c8c8a..30a3e49 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -70,6 +70,55 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -610,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "llm-ls" -version = "0.0.0" +version = "0.1.0" dependencies = [ "home", "reqwest", @@ -678,6 +727,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef" + [[package]] name = "memchr" version = "2.6.3" @@ -725,6 +780,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "mock_server" +version = "0.1.0" +dependencies = [ + "axum", + "serde", + "tokio", +] + [[package]] name = "monostate" version = "0.1.9" @@ -819,6 +883,16 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + [[package]] name = "parking_lot_core" version = "0.9.8" @@ -1147,6 +1221,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + [[package]] name = "ryu" version = "1.0.15" @@ -1200,6 +1280,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4beec8bce849d58d06238cb50db2e1c417cfeafa4c63f692b15c82b7c80f8335" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_repr" version = "0.1.16" @@ -1232,6 +1322,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.9" @@ -1319,6 +1418,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + [[package]] name = "thiserror" version = "1.0.48" @@ -1434,7 +1539,9 @@ dependencies = [ "libc", "mio", "num_cpus", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2 0.5.3", "tokio-macros", "windows-sys", @@ -1485,8 +1592,10 @@ dependencies = [ "futures-util", "pin-project", "pin-project-lite", + "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -1542,6 +1651,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/Cargo.toml b/Cargo.toml index e407b91..5a0df45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,6 @@ members = ["xtask/", "crates/*"] resolver = "2" [workspace.package] -rust-version = "1.71" edition = "2021" license = "Apache-2.0" authors = ["Luc Georges "] diff --git a/crates/llm-ls/Cargo.toml b/crates/llm-ls/Cargo.toml index fd803be..87809a3 100644 --- a/crates/llm-ls/Cargo.toml +++ b/crates/llm-ls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "llm-ls" -version = "0.0.0" +version = "0.1.1" edition = "2021" [[bin]] diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index 2218814..44b838b 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -1,6 +1,6 @@ -use reqwest::header::AUTHORIZATION; +use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use ropey::Rope; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use std::collections::HashMap; use std::fmt::Display; use std::path::{Path, PathBuf}; @@ -16,6 +16,9 @@ use tracing::{debug, error, info}; use tracing_appender::rolling; use tracing_subscriber::EnvFilter; +const NAME: &str = "llm-ls"; +const VERSION: &str = env!("CARGO_PKG_VERSION"); + #[derive(Clone, Debug, Deserialize, Serialize)] struct RequestParams { max_new_tokens: u32, @@ -117,11 +120,46 @@ struct Completion { generated_text: String, } +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +enum IDE { + Neovim, + VSCode, + JetBrains, + Emacs, + Jupyter, + Sublime, + VisualStudio, + Unknown, +} + +impl Default for IDE { + fn default() -> Self { + IDE::Unknown + } +} + +impl Display for IDE { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.serialize(f) + } +} + +fn parse_ide<'de, D>(d: D) -> std::result::Result +where + D: Deserializer<'de>, +{ + Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown)) +} + #[derive(Debug, Deserialize, Serialize)] struct CompletionParams { #[serde(flatten)] text_document_position: TextDocumentPositionParams, request_params: RequestParams, + #[serde(default)] + #[serde(deserialize_with = "parse_ide")] + ide: IDE, fim: FimParams, api_token: Option, model: String, @@ -234,21 +272,22 @@ fn build_prompt( async fn request_completion( http_client: &reqwest::Client, + ide: IDE, model: &str, request_params: RequestParams, api_token: Option<&String>, prompt: String, ) -> Result> { - let mut req = http_client.post(build_url(model)).json(&APIRequest { - inputs: prompt, - parameters: request_params.into(), - }); - - if let Some(api_token) = api_token { - req = req.header(AUTHORIZATION, format!("Bearer {api_token}")) - } - - let res = req.send().await.map_err(internal_error)?; + let res = http_client + .post(build_url(model)) + .json(&APIRequest { + inputs: prompt, + parameters: request_params.into(), + }) + .headers(build_headers(api_token, ide)?) + .send() + .await + .map_err(internal_error)?; match res.json().await.map_err(internal_error)? { APIResponse::Generation(gen) => Ok(vec![gen]), @@ -377,6 +416,7 @@ impl Backend { }; let result = request_completion( http_client, + params.ide, ¶ms.model, params.request_params, params.api_token.as_ref(), @@ -395,7 +435,7 @@ impl LanguageServer for Backend { Ok(InitializeResult { server_info: Some(ServerInfo { name: "llm-ls".to_owned(), - version: Some("0.1.0".to_owned()), + version: Some(VERSION.to_owned()), }), capabilities: ServerCapabilities { text_document_sync: Some(TextDocumentSyncCapability::Kind( @@ -468,6 +508,24 @@ impl LanguageServer for Backend { } } +fn build_headers(api_token: Option<&String>, ide: IDE) -> Result { + let mut headers = HeaderMap::new(); + let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}"); + headers.insert( + USER_AGENT, + HeaderValue::from_str(&user_agent).map_err(internal_error)?, + ); + + if let Some(api_token) = api_token { + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?, + ); + } + + Ok(headers) +} + #[tokio::main] async fn main() { let stdin = tokio::io::stdin(); diff --git a/crates/mock_server/.gitignore b/crates/mock_server/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/crates/mock_server/.gitignore @@ -0,0 +1 @@ +/target diff --git a/crates/mock_server/Cargo.toml b/crates/mock_server/Cargo.toml new file mode 100644 index 0000000..1aef7d2 --- /dev/null +++ b/crates/mock_server/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "mock_server" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "mock_server" + +[dependencies] +axum = "0.6" +# use this is you need axum::debug_handler +# axum = { version = "0.6", features = ["macros"] } +serde = { version = "1", features = ["derive"] } +tokio = { version = "1", features = ["full"] } diff --git a/crates/mock_server/src/main.rs b/crates/mock_server/src/main.rs new file mode 100644 index 0000000..d5b3c1d --- /dev/null +++ b/crates/mock_server/src/main.rs @@ -0,0 +1,63 @@ +use axum::{extract::State, http::HeaderMap, routing::post, Json, Router}; +use serde::{Deserialize, Serialize}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::sync::Mutex; + +#[derive(Clone)] +struct AppState { + counter: Arc>, +} + +#[derive(Deserialize, Serialize)] +struct GeneratedText { + generated_text: String, +} + +async fn default(state: State) -> Json> { + let mut lock = state.counter.lock().await; + *lock += 1; + println!("got request {}", lock); + Json(vec![GeneratedText { + generated_text: "dummy".to_owned(), + }]) +} + +async fn tgi(state: State) -> Json { + let mut lock = state.counter.lock().await; + *lock += 1; + Json(GeneratedText { + generated_text: "dummy".to_owned(), + }) +} + +async fn log_headers(headers: HeaderMap, state: State) -> Json { + let mut lock = state.counter.lock().await; + *lock += 1; + for (name, value) in headers.iter() { + println!("{lock} - {}: {}", name, value.to_str().unwrap()); + } + Json(GeneratedText { + generated_text: "dummy".to_owned(), + }) +} + +#[tokio::main] +async fn main() { + let app_state = AppState { + counter: Arc::new(Mutex::new(0)), + }; + let app = Router::new() + .route("/", post(default)) + .route("/tgi", post(tgi)) + .route("/headers", post(log_headers)) + .with_state(app_state); + let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242) + .parse() + .expect("string to parse to socket addr"); + println!("starting server {}:{}", addr.ip().to_string(), addr.port(),); + + axum::Server::bind(&addr) + .serve(app.into_make_service()) + .await + .expect("server to start"); +} diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index 5e0101a..d488ef0 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -2,7 +2,6 @@ name = "xtask" version = "0.1.0" publish = false -rust-version.workspace = true edition.workspace = true license.workspace = true authors.workspace = true