feat: add user agent (#20)
* feat: add user agent * feat: add mock_server to repo * feat: bump to `0.1.1`
This commit is contained in:
parent
7f9c7855d5
commit
eeb443feb3
112
Cargo.lock
generated
112
Cargo.lock
generated
|
@ -70,6 +70,55 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.6.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"mime",
|
||||
"rustversion",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.69"
|
||||
|
@ -610,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
|
|||
|
||||
[[package]]
|
||||
name = "llm-ls"
|
||||
version = "0.0.0"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"home",
|
||||
"reqwest",
|
||||
|
@ -678,6 +727,12 @@ dependencies = [
|
|||
"regex-automata 0.1.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.6.3"
|
||||
|
@ -725,6 +780,15 @@ dependencies = [
|
|||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mock_server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"serde",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "monostate"
|
||||
version = "0.1.9"
|
||||
|
@ -819,6 +883,16 @@ version = "0.1.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.8"
|
||||
|
@ -1147,6 +1221,12 @@ dependencies = [
|
|||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.15"
|
||||
|
@ -1200,6 +1280,16 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4beec8bce849d58d06238cb50db2e1c417cfeafa4c63f692b15c82b7c80f8335"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_repr"
|
||||
version = "0.1.16"
|
||||
|
@ -1232,6 +1322,15 @@ dependencies = [
|
|||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
|
@ -1319,6 +1418,12 @@ dependencies = [
|
|||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sync_wrapper"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.48"
|
||||
|
@ -1434,7 +1539,9 @@ dependencies = [
|
|||
"libc",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2 0.5.3",
|
||||
"tokio-macros",
|
||||
"windows-sys",
|
||||
|
@ -1485,8 +1592,10 @@ dependencies = [
|
|||
"futures-util",
|
||||
"pin-project",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1542,6 +1651,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
|
|
|
@ -3,7 +3,6 @@ members = ["xtask/", "crates/*"]
|
|||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
rust-version = "1.71"
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
authors = ["Luc Georges <luc@huggingface.co>"]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "llm-ls"
|
||||
version = "0.0.0"
|
||||
version = "0.1.1"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
|
||||
use ropey::Rope;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
@ -16,6 +16,9 @@ use tracing::{debug, error, info};
|
|||
use tracing_appender::rolling;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
const NAME: &str = "llm-ls";
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
struct RequestParams {
|
||||
max_new_tokens: u32,
|
||||
|
@ -117,11 +120,46 @@ struct Completion {
|
|||
generated_text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum IDE {
|
||||
Neovim,
|
||||
VSCode,
|
||||
JetBrains,
|
||||
Emacs,
|
||||
Jupyter,
|
||||
Sublime,
|
||||
VisualStudio,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Default for IDE {
|
||||
fn default() -> Self {
|
||||
IDE::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for IDE {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.serialize(f)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct CompletionParams {
|
||||
#[serde(flatten)]
|
||||
text_document_position: TextDocumentPositionParams,
|
||||
request_params: RequestParams,
|
||||
#[serde(default)]
|
||||
#[serde(deserialize_with = "parse_ide")]
|
||||
ide: IDE,
|
||||
fim: FimParams,
|
||||
api_token: Option<String>,
|
||||
model: String,
|
||||
|
@ -234,21 +272,22 @@ fn build_prompt(
|
|||
|
||||
async fn request_completion(
|
||||
http_client: &reqwest::Client,
|
||||
ide: IDE,
|
||||
model: &str,
|
||||
request_params: RequestParams,
|
||||
api_token: Option<&String>,
|
||||
prompt: String,
|
||||
) -> Result<Vec<Generation>> {
|
||||
let mut req = http_client.post(build_url(model)).json(&APIRequest {
|
||||
inputs: prompt,
|
||||
parameters: request_params.into(),
|
||||
});
|
||||
|
||||
if let Some(api_token) = api_token {
|
||||
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
|
||||
}
|
||||
|
||||
let res = req.send().await.map_err(internal_error)?;
|
||||
let res = http_client
|
||||
.post(build_url(model))
|
||||
.json(&APIRequest {
|
||||
inputs: prompt,
|
||||
parameters: request_params.into(),
|
||||
})
|
||||
.headers(build_headers(api_token, ide)?)
|
||||
.send()
|
||||
.await
|
||||
.map_err(internal_error)?;
|
||||
|
||||
match res.json().await.map_err(internal_error)? {
|
||||
APIResponse::Generation(gen) => Ok(vec![gen]),
|
||||
|
@ -377,6 +416,7 @@ impl Backend {
|
|||
};
|
||||
let result = request_completion(
|
||||
http_client,
|
||||
params.ide,
|
||||
¶ms.model,
|
||||
params.request_params,
|
||||
params.api_token.as_ref(),
|
||||
|
@ -395,7 +435,7 @@ impl LanguageServer for Backend {
|
|||
Ok(InitializeResult {
|
||||
server_info: Some(ServerInfo {
|
||||
name: "llm-ls".to_owned(),
|
||||
version: Some("0.1.0".to_owned()),
|
||||
version: Some(VERSION.to_owned()),
|
||||
}),
|
||||
capabilities: ServerCapabilities {
|
||||
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
||||
|
@ -468,6 +508,24 @@ impl LanguageServer for Backend {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
|
||||
headers.insert(
|
||||
USER_AGENT,
|
||||
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
|
||||
);
|
||||
|
||||
if let Some(api_token) = api_token {
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let stdin = tokio::io::stdin();
|
||||
|
|
1
crates/mock_server/.gitignore
vendored
Normal file
1
crates/mock_server/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
/target
|
14
crates/mock_server/Cargo.toml
Normal file
14
crates/mock_server/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "mock_server"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "mock_server"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.6"
|
||||
# use this is you need axum::debug_handler
|
||||
# axum = { version = "0.6", features = ["macros"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
63
crates/mock_server/src/main.rs
Normal file
63
crates/mock_server/src/main.rs
Normal file
|
@ -0,0 +1,63 @@
|
|||
use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
counter: Arc<Mutex<u32>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct GeneratedText {
|
||||
generated_text: String,
|
||||
}
|
||||
|
||||
async fn default(state: State<AppState>) -> Json<Vec<GeneratedText>> {
|
||||
let mut lock = state.counter.lock().await;
|
||||
*lock += 1;
|
||||
println!("got request {}", lock);
|
||||
Json(vec![GeneratedText {
|
||||
generated_text: "dummy".to_owned(),
|
||||
}])
|
||||
}
|
||||
|
||||
async fn tgi(state: State<AppState>) -> Json<GeneratedText> {
|
||||
let mut lock = state.counter.lock().await;
|
||||
*lock += 1;
|
||||
Json(GeneratedText {
|
||||
generated_text: "dummy".to_owned(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<GeneratedText> {
|
||||
let mut lock = state.counter.lock().await;
|
||||
*lock += 1;
|
||||
for (name, value) in headers.iter() {
|
||||
println!("{lock} - {}: {}", name, value.to_str().unwrap());
|
||||
}
|
||||
Json(GeneratedText {
|
||||
generated_text: "dummy".to_owned(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let app_state = AppState {
|
||||
counter: Arc::new(Mutex::new(0)),
|
||||
};
|
||||
let app = Router::new()
|
||||
.route("/", post(default))
|
||||
.route("/tgi", post(tgi))
|
||||
.route("/headers", post(log_headers))
|
||||
.with_state(app_state);
|
||||
let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
|
||||
.parse()
|
||||
.expect("string to parse to socket addr");
|
||||
println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
|
||||
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.expect("server to start");
|
||||
}
|
|
@ -2,7 +2,6 @@
|
|||
name = "xtask"
|
||||
version = "0.1.0"
|
||||
publish = false
|
||||
rust-version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
|
|
Loading…
Reference in a new issue