feat: add user agent (#20)
* feat: add user agent * feat: add mock_server to repo * feat: bump to `0.1.1`
This commit is contained in:
parent
7f9c7855d5
commit
eeb443feb3
112
Cargo.lock
generated
112
Cargo.lock
generated
|
@ -70,6 +70,55 @@ version = "1.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum"
|
||||||
|
version = "0.6.20"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"axum-core",
|
||||||
|
"bitflags",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"hyper",
|
||||||
|
"itoa",
|
||||||
|
"matchit",
|
||||||
|
"memchr",
|
||||||
|
"mime",
|
||||||
|
"percent-encoding",
|
||||||
|
"pin-project-lite",
|
||||||
|
"rustversion",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"serde_path_to_error",
|
||||||
|
"serde_urlencoded",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tokio",
|
||||||
|
"tower",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-core"
|
||||||
|
version = "0.3.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"bytes",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"mime",
|
||||||
|
"rustversion",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "backtrace"
|
name = "backtrace"
|
||||||
version = "0.3.69"
|
version = "0.3.69"
|
||||||
|
@ -610,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "llm-ls"
|
name = "llm-ls"
|
||||||
version = "0.0.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"home",
|
"home",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
|
@ -678,6 +727,12 @@ dependencies = [
|
||||||
"regex-automata 0.1.10",
|
"regex-automata 0.1.10",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matchit"
|
||||||
|
version = "0.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ed1202b2a6f884ae56f04cff409ab315c5ce26b5e58d7412e484f01fd52f52ef"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
version = "2.6.3"
|
version = "2.6.3"
|
||||||
|
@ -725,6 +780,15 @@ dependencies = [
|
||||||
"windows-sys",
|
"windows-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mock_server"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"axum",
|
||||||
|
"serde",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "monostate"
|
name = "monostate"
|
||||||
version = "0.1.9"
|
version = "0.1.9"
|
||||||
|
@ -819,6 +883,16 @@ version = "0.1.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "parking_lot"
|
||||||
|
version = "0.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
|
||||||
|
dependencies = [
|
||||||
|
"lock_api",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "parking_lot_core"
|
name = "parking_lot_core"
|
||||||
version = "0.9.8"
|
version = "0.9.8"
|
||||||
|
@ -1147,6 +1221,12 @@ dependencies = [
|
||||||
"untrusted",
|
"untrusted",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustversion"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.15"
|
version = "1.0.15"
|
||||||
|
@ -1200,6 +1280,16 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_path_to_error"
|
||||||
|
version = "0.1.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4beec8bce849d58d06238cb50db2e1c417cfeafa4c63f692b15c82b7c80f8335"
|
||||||
|
dependencies = [
|
||||||
|
"itoa",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_repr"
|
name = "serde_repr"
|
||||||
version = "0.1.16"
|
version = "0.1.16"
|
||||||
|
@ -1232,6 +1322,15 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "signal-hook-registry"
|
||||||
|
version = "1.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "slab"
|
name = "slab"
|
||||||
version = "0.4.9"
|
version = "0.4.9"
|
||||||
|
@ -1319,6 +1418,12 @@ dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sync_wrapper"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.48"
|
version = "1.0.48"
|
||||||
|
@ -1434,7 +1539,9 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"mio",
|
"mio",
|
||||||
"num_cpus",
|
"num_cpus",
|
||||||
|
"parking_lot",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"signal-hook-registry",
|
||||||
"socket2 0.5.3",
|
"socket2 0.5.3",
|
||||||
"tokio-macros",
|
"tokio-macros",
|
||||||
"windows-sys",
|
"windows-sys",
|
||||||
|
@ -1485,8 +1592,10 @@ dependencies = [
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"pin-project",
|
"pin-project",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"tokio",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1542,6 +1651,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
|
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
|
"log",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tracing-attributes",
|
"tracing-attributes",
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
|
|
|
@ -3,7 +3,6 @@ members = ["xtask/", "crates/*"]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
rust-version = "1.71"
|
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
authors = ["Luc Georges <luc@huggingface.co>"]
|
authors = ["Luc Georges <luc@huggingface.co>"]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "llm-ls"
|
name = "llm-ls"
|
||||||
version = "0.0.0"
|
version = "0.1.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use reqwest::header::AUTHORIZATION;
|
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
|
||||||
use ropey::Rope;
|
use ropey::Rope;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
@ -16,6 +16,9 @@ use tracing::{debug, error, info};
|
||||||
use tracing_appender::rolling;
|
use tracing_appender::rolling;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
const NAME: &str = "llm-ls";
|
||||||
|
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
struct RequestParams {
|
struct RequestParams {
|
||||||
max_new_tokens: u32,
|
max_new_tokens: u32,
|
||||||
|
@ -117,11 +120,46 @@ struct Completion {
|
||||||
generated_text: String,
|
generated_text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
enum IDE {
|
||||||
|
Neovim,
|
||||||
|
VSCode,
|
||||||
|
JetBrains,
|
||||||
|
Emacs,
|
||||||
|
Jupyter,
|
||||||
|
Sublime,
|
||||||
|
VisualStudio,
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IDE {
|
||||||
|
fn default() -> Self {
|
||||||
|
IDE::Unknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for IDE {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
self.serialize(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
struct CompletionParams {
|
struct CompletionParams {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
text_document_position: TextDocumentPositionParams,
|
text_document_position: TextDocumentPositionParams,
|
||||||
request_params: RequestParams,
|
request_params: RequestParams,
|
||||||
|
#[serde(default)]
|
||||||
|
#[serde(deserialize_with = "parse_ide")]
|
||||||
|
ide: IDE,
|
||||||
fim: FimParams,
|
fim: FimParams,
|
||||||
api_token: Option<String>,
|
api_token: Option<String>,
|
||||||
model: String,
|
model: String,
|
||||||
|
@ -234,21 +272,22 @@ fn build_prompt(
|
||||||
|
|
||||||
async fn request_completion(
|
async fn request_completion(
|
||||||
http_client: &reqwest::Client,
|
http_client: &reqwest::Client,
|
||||||
|
ide: IDE,
|
||||||
model: &str,
|
model: &str,
|
||||||
request_params: RequestParams,
|
request_params: RequestParams,
|
||||||
api_token: Option<&String>,
|
api_token: Option<&String>,
|
||||||
prompt: String,
|
prompt: String,
|
||||||
) -> Result<Vec<Generation>> {
|
) -> Result<Vec<Generation>> {
|
||||||
let mut req = http_client.post(build_url(model)).json(&APIRequest {
|
let res = http_client
|
||||||
inputs: prompt,
|
.post(build_url(model))
|
||||||
parameters: request_params.into(),
|
.json(&APIRequest {
|
||||||
});
|
inputs: prompt,
|
||||||
|
parameters: request_params.into(),
|
||||||
if let Some(api_token) = api_token {
|
})
|
||||||
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
|
.headers(build_headers(api_token, ide)?)
|
||||||
}
|
.send()
|
||||||
|
.await
|
||||||
let res = req.send().await.map_err(internal_error)?;
|
.map_err(internal_error)?;
|
||||||
|
|
||||||
match res.json().await.map_err(internal_error)? {
|
match res.json().await.map_err(internal_error)? {
|
||||||
APIResponse::Generation(gen) => Ok(vec![gen]),
|
APIResponse::Generation(gen) => Ok(vec![gen]),
|
||||||
|
@ -377,6 +416,7 @@ impl Backend {
|
||||||
};
|
};
|
||||||
let result = request_completion(
|
let result = request_completion(
|
||||||
http_client,
|
http_client,
|
||||||
|
params.ide,
|
||||||
¶ms.model,
|
¶ms.model,
|
||||||
params.request_params,
|
params.request_params,
|
||||||
params.api_token.as_ref(),
|
params.api_token.as_ref(),
|
||||||
|
@ -395,7 +435,7 @@ impl LanguageServer for Backend {
|
||||||
Ok(InitializeResult {
|
Ok(InitializeResult {
|
||||||
server_info: Some(ServerInfo {
|
server_info: Some(ServerInfo {
|
||||||
name: "llm-ls".to_owned(),
|
name: "llm-ls".to_owned(),
|
||||||
version: Some("0.1.0".to_owned()),
|
version: Some(VERSION.to_owned()),
|
||||||
}),
|
}),
|
||||||
capabilities: ServerCapabilities {
|
capabilities: ServerCapabilities {
|
||||||
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
text_document_sync: Some(TextDocumentSyncCapability::Kind(
|
||||||
|
@ -468,6 +508,24 @@ impl LanguageServer for Backend {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
|
||||||
|
headers.insert(
|
||||||
|
USER_AGENT,
|
||||||
|
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(api_token) = api_token {
|
||||||
|
headers.insert(
|
||||||
|
AUTHORIZATION,
|
||||||
|
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(headers)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let stdin = tokio::io::stdin();
|
let stdin = tokio::io::stdin();
|
||||||
|
|
1
crates/mock_server/.gitignore
vendored
Normal file
1
crates/mock_server/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
||||||
|
/target
|
14
crates/mock_server/Cargo.toml
Normal file
14
crates/mock_server/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "mock_server"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "mock_server"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
axum = "0.6"
|
||||||
|
# use this is you need axum::debug_handler
|
||||||
|
# axum = { version = "0.6", features = ["macros"] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
63
crates/mock_server/src/main.rs
Normal file
63
crates/mock_server/src/main.rs
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {
|
||||||
|
counter: Arc<Mutex<u32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, Serialize)]
|
||||||
|
struct GeneratedText {
|
||||||
|
generated_text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn default(state: State<AppState>) -> Json<Vec<GeneratedText>> {
|
||||||
|
let mut lock = state.counter.lock().await;
|
||||||
|
*lock += 1;
|
||||||
|
println!("got request {}", lock);
|
||||||
|
Json(vec![GeneratedText {
|
||||||
|
generated_text: "dummy".to_owned(),
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn tgi(state: State<AppState>) -> Json<GeneratedText> {
|
||||||
|
let mut lock = state.counter.lock().await;
|
||||||
|
*lock += 1;
|
||||||
|
Json(GeneratedText {
|
||||||
|
generated_text: "dummy".to_owned(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<GeneratedText> {
|
||||||
|
let mut lock = state.counter.lock().await;
|
||||||
|
*lock += 1;
|
||||||
|
for (name, value) in headers.iter() {
|
||||||
|
println!("{lock} - {}: {}", name, value.to_str().unwrap());
|
||||||
|
}
|
||||||
|
Json(GeneratedText {
|
||||||
|
generated_text: "dummy".to_owned(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() {
|
||||||
|
let app_state = AppState {
|
||||||
|
counter: Arc::new(Mutex::new(0)),
|
||||||
|
};
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/", post(default))
|
||||||
|
.route("/tgi", post(tgi))
|
||||||
|
.route("/headers", post(log_headers))
|
||||||
|
.with_state(app_state);
|
||||||
|
let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
|
||||||
|
.parse()
|
||||||
|
.expect("string to parse to socket addr");
|
||||||
|
println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
|
||||||
|
|
||||||
|
axum::Server::bind(&addr)
|
||||||
|
.serve(app.into_make_service())
|
||||||
|
.await
|
||||||
|
.expect("server to start");
|
||||||
|
}
|
|
@ -2,7 +2,6 @@
|
||||||
name = "xtask"
|
name = "xtask"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
publish = false
|
publish = false
|
||||||
rust-version.workspace = true
|
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
authors.workspace = true
|
authors.workspace = true
|
||||||
|
|
Loading…
Reference in a new issue