feat: add user agent (#20)

* feat: add user agent

* feat: add mock_server to repo

* feat: bump to `0.1.1`
This commit is contained in:
Luc Georges 2023-09-21 14:32:21 +02:00 committed by GitHub
parent 7f9c7855d5
commit eeb443feb3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 261 additions and 17 deletions

112
Cargo.lock generated
View file

@ -70,6 +70,55 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
"async-trait",
"axum-core",
"bitflags",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http-body",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.69"
@ -610,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "llm-ls"
version = "0.0.0"
version = "0.1.0"
dependencies = [
"home",
"reqwest",
@ -678,6 +727,12 @@ dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matchit"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef"
[[package]]
name = "memchr"
version = "2.6.3"
@ -725,6 +780,15 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "mock_server"
version = "0.1.0"
dependencies = [
"axum",
"serde",
"tokio",
]
[[package]]
name = "monostate"
version = "0.1.9"
@ -819,6 +883,16 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "parking_lot"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.8"
@ -1147,6 +1221,12 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
[[package]]
name = "ryu"
version = "1.0.15"
@ -1200,6 +1280,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4beec8bce849d58d06238cb50db2e1c417cfeafa4c63f692b15c82b7c80f8335"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_repr"
version = "0.1.16"
@ -1232,6 +1322,15 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
dependencies = [
"libc",
]
[[package]]
name = "slab"
version = "0.4.9"
@ -1319,6 +1418,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "thiserror"
version = "1.0.48"
@ -1434,7 +1539,9 @@ dependencies = [
"libc",
"mio",
"num_cpus",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2 0.5.3",
"tokio-macros",
"windows-sys",
@ -1485,8 +1592,10 @@ dependencies = [
"futures-util",
"pin-project",
"pin-project-lite",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@ -1542,6 +1651,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",

View file

@ -3,7 +3,6 @@ members = ["xtask/", "crates/*"]
resolver = "2"
[workspace.package]
rust-version = "1.71"
edition = "2021"
license = "Apache-2.0"
authors = ["Luc Georges <luc@huggingface.co>"]

View file

@ -1,6 +1,6 @@
[package]
name = "llm-ls"
version = "0.0.0"
version = "0.1.1"
edition = "2021"
[[bin]]

View file

@ -1,6 +1,6 @@
use reqwest::header::AUTHORIZATION;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use ropey::Rope;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use std::path::{Path, PathBuf};
@ -16,6 +16,9 @@ use tracing::{debug, error, info};
use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
const NAME: &str = "llm-ls";
const VERSION: &str = env!("CARGO_PKG_VERSION");
#[derive(Clone, Debug, Deserialize, Serialize)]
struct RequestParams {
max_new_tokens: u32,
@ -117,11 +120,46 @@ struct Completion {
generated_text: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
enum IDE {
Neovim,
VSCode,
JetBrains,
Emacs,
Jupyter,
Sublime,
VisualStudio,
Unknown,
}
impl Default for IDE {
fn default() -> Self {
IDE::Unknown
}
}
impl Display for IDE {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.serialize(f)
}
}
fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
where
D: Deserializer<'de>,
{
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
}
#[derive(Debug, Deserialize, Serialize)]
struct CompletionParams {
#[serde(flatten)]
text_document_position: TextDocumentPositionParams,
request_params: RequestParams,
#[serde(default)]
#[serde(deserialize_with = "parse_ide")]
ide: IDE,
fim: FimParams,
api_token: Option<String>,
model: String,
@ -234,21 +272,22 @@ fn build_prompt(
async fn request_completion(
http_client: &reqwest::Client,
ide: IDE,
model: &str,
request_params: RequestParams,
api_token: Option<&String>,
prompt: String,
) -> Result<Vec<Generation>> {
let mut req = http_client.post(build_url(model)).json(&APIRequest {
inputs: prompt,
parameters: request_params.into(),
});
if let Some(api_token) = api_token {
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
}
let res = req.send().await.map_err(internal_error)?;
let res = http_client
.post(build_url(model))
.json(&APIRequest {
inputs: prompt,
parameters: request_params.into(),
})
.headers(build_headers(api_token, ide)?)
.send()
.await
.map_err(internal_error)?;
match res.json().await.map_err(internal_error)? {
APIResponse::Generation(gen) => Ok(vec![gen]),
@ -377,6 +416,7 @@ impl Backend {
};
let result = request_completion(
http_client,
params.ide,
&params.model,
params.request_params,
params.api_token.as_ref(),
@ -395,7 +435,7 @@ impl LanguageServer for Backend {
Ok(InitializeResult {
server_info: Some(ServerInfo {
name: "llm-ls".to_owned(),
version: Some("0.1.0".to_owned()),
version: Some(VERSION.to_owned()),
}),
capabilities: ServerCapabilities {
text_document_sync: Some(TextDocumentSyncCapability::Kind(
@ -468,6 +508,24 @@ impl LanguageServer for Backend {
}
}
fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
headers.insert(
USER_AGENT,
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
);
if let Some(api_token) = api_token {
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
);
}
Ok(headers)
}
#[tokio::main]
async fn main() {
let stdin = tokio::io::stdin();

1
crates/mock_server/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/target

View file

@ -0,0 +1,14 @@
[package]
name = "mock_server"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "mock_server"
[dependencies]
axum = "0.6"
# use this is you need axum::debug_handler
# axum = { version = "0.6", features = ["macros"] }
serde = { version = "1", features = ["derive"] }
tokio = { version = "1", features = ["full"] }

View file

@ -0,0 +1,63 @@
use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
use serde::{Deserialize, Serialize};
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
#[derive(Clone)]
struct AppState {
counter: Arc<Mutex<u32>>,
}
#[derive(Deserialize, Serialize)]
struct GeneratedText {
generated_text: String,
}
async fn default(state: State<AppState>) -> Json<Vec<GeneratedText>> {
let mut lock = state.counter.lock().await;
*lock += 1;
println!("got request {}", lock);
Json(vec![GeneratedText {
generated_text: "dummy".to_owned(),
}])
}
async fn tgi(state: State<AppState>) -> Json<GeneratedText> {
let mut lock = state.counter.lock().await;
*lock += 1;
Json(GeneratedText {
generated_text: "dummy".to_owned(),
})
}
async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<GeneratedText> {
let mut lock = state.counter.lock().await;
*lock += 1;
for (name, value) in headers.iter() {
println!("{lock} - {}: {}", name, value.to_str().unwrap());
}
Json(GeneratedText {
generated_text: "dummy".to_owned(),
})
}
#[tokio::main]
async fn main() {
let app_state = AppState {
counter: Arc::new(Mutex::new(0)),
};
let app = Router::new()
.route("/", post(default))
.route("/tgi", post(tgi))
.route("/headers", post(log_headers))
.with_state(app_state);
let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
.parse()
.expect("string to parse to socket addr");
println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.expect("server to start");
}

View file

@ -2,7 +2,6 @@
name = "xtask"
version = "0.1.0"
publish = false
rust-version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true