feat: improve tokenizer config (#21)

* feat: improve tokenizer config

* fix: add untagged decorator to `TokenizerConfig`

* feat: bump version to `0.2.0`
Luc Georges, 2023-09-21 17:57:19 +02:00, committed by GitHub
commit 787f2a1a26, parent eeb443feb3
4 changed files with 93 additions and 56 deletions

Cargo.lock (generated)

@@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
 [[package]]
 name = "llm-ls"
-version = "0.1.0"
+version = "0.1.1"
 dependencies = [
  "home",
  "reqwest",

Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "llm-ls"
-version = "0.1.1"
+version = "0.2.0"
 edition = "2021"
 
 [[bin]]

(llm-ls server source, main.rs; full path not shown)

@@ -19,6 +19,14 @@ use tracing_subscriber::EnvFilter;
 const NAME: &str = "llm-ls";
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(untagged)]
+enum TokenizerConfig {
+    Local { path: PathBuf },
+    HuggingFace { repository: String },
+    Download { url: String, to: PathBuf },
+}
+
 #[derive(Clone, Debug, Deserialize, Serialize)]
 struct RequestParams {
     max_new_tokens: u32,
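The `#[serde(untagged)]` attribute (the "untagged decorator" of the commit message; serde calls it an attribute) is what lets clients send a bare JSON object with no variant tag: serde tries the variants in declaration order and picks the first whose fields match. A self-contained sketch of how the three shapes deserialize; the paths, URL, and repository name below are illustrative, not values from this commit:

```rust
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
enum TokenizerConfig {
    Local { path: PathBuf },
    HuggingFace { repository: String },
    Download { url: String, to: PathBuf },
}

fn main() {
    // Each object matches exactly one variant by its field names.
    let local: TokenizerConfig =
        serde_json::from_str(r#"{ "path": "/tmp/tokenizer.json" }"#).unwrap();
    let hub: TokenizerConfig =
        serde_json::from_str(r#"{ "repository": "bigcode/starcoder" }"#).unwrap();
    let download: TokenizerConfig = serde_json::from_str(
        r#"{ "url": "https://example.com/tokenizer.json", "to": "/tmp/tokenizer.json" }"#,
    )
    .unwrap();
    println!("{local:?}\n{hub:?}\n{download:?}");
}
```

The trade-off of `untagged` is error reporting: a payload that matches no variant yields a generic "data did not match any variant" message instead of a field-level error.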
@@ -120,9 +128,9 @@ struct Completion {
     generated_text: String,
 }
 
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
 #[serde(rename_all = "lowercase")]
-enum IDE {
+enum Ide {
     Neovim,
     VSCode,
     JetBrains,
@@ -130,26 +138,21 @@ enum IDE {
     Jupyter,
     Sublime,
     VisualStudio,
+    #[default]
     Unknown,
 }
 
-impl Default for IDE {
-    fn default() -> Self {
-        IDE::Unknown
-    }
-}
-
-impl Display for IDE {
+impl Display for Ide {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         self.serialize(f)
     }
 }
 
-fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
+fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
 where
     D: Deserializer<'de>,
 {
-    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
+    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 
 #[derive(Debug, Deserialize, Serialize)]
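Two mechanical cleanups land in this hunk: the hand-written `impl Default for IDE` is replaced by `#[derive(Default)]` with a `#[default]` variant attribute (stable since Rust 1.62), and the type is renamed `Ide` per Rust naming conventions. The custom `parse_ide` deserializer exists so an explicit JSON `null` also falls back to `Ide::Unknown`; `#[serde(default)]` alone only covers a missing field. A runnable sketch, with the enum trimmed to two variants for brevity:

```rust
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Ide {
    Neovim,
    #[default]
    Unknown,
}

// Same shape as the commit's `parse_ide`: a JSON `null` becomes `Unknown`.
fn parse_ide<'de, D>(d: D) -> Result<Ide, D::Error>
where
    D: Deserializer<'de>,
{
    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
}

#[derive(Debug, Deserialize)]
struct Params {
    #[serde(default)] // missing field: handled by Default, not parse_ide
    #[serde(deserialize_with = "parse_ide")]
    ide: Ide,
}

fn main() {
    let p: Params = serde_json::from_str(r#"{ "ide": null }"#).unwrap();
    assert_eq!(p.ide, Ide::Unknown);
    let p: Params = serde_json::from_str(r#"{ "ide": "neovim" }"#).unwrap();
    assert_eq!(p.ide, Ide::Neovim);
    println!("ok");
}
```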
@@ -159,12 +162,12 @@ struct CompletionParams {
     request_params: RequestParams,
     #[serde(default)]
     #[serde(deserialize_with = "parse_ide")]
-    ide: IDE,
+    ide: Ide,
     fim: FimParams,
     api_token: Option<String>,
     model: String,
     tokens_to_clear: Vec<String>,
-    tokenizer_path: Option<String>,
+    tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
 }
@@ -183,7 +186,7 @@ fn build_prompt(
     pos: Position,
     text: &Rope,
     fim: &FimParams,
-    tokenizer: Arc<Tokenizer>,
+    tokenizer: Option<Arc<Tokenizer>>,
     context_window: usize,
 ) -> Result<String> {
     let t = Instant::now();
@@ -206,10 +209,14 @@ fn build_prompt(
     while before_line.is_some() || after_line.is_some() {
         if let Some(before_line) = before_line {
             let before_line = before_line.to_string();
-            let tokens = tokenizer
-                .encode(before_line.clone(), false)
-                .map_err(internal_error)?
-                .len();
+            let tokens = if let Some(tokenizer) = tokenizer.clone() {
+                tokenizer
+                    .encode(before_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len()
+            } else {
+                before_line.len()
+            };
             if tokens > token_count {
                 break;
             }
@@ -218,10 +225,14 @@
         }
         if let Some(after_line) = after_line {
             let after_line = after_line.to_string();
-            let tokens = tokenizer
-                .encode(after_line.clone(), false)
-                .map_err(internal_error)?
-                .len();
+            let tokens = if let Some(tokenizer) = tokenizer.clone() {
+                tokenizer
+                    .encode(after_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len()
+            } else {
+                after_line.len()
+            };
             if tokens > token_count {
                 break;
             }
@@ -253,10 +264,14 @@ fn build_prompt(
             first = false;
         }
         let line = line.to_string();
-        let tokens = tokenizer
-            .encode(line.clone(), false)
-            .map_err(internal_error)?
-            .len();
+        let tokens = if let Some(tokenizer) = tokenizer.clone() {
+            tokenizer
+                .encode(line.clone(), false)
+                .map_err(internal_error)?
+                .len()
+        } else {
+            line.len()
+        };
         if tokens > token_count {
             break;
         }
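The same four-line change appears three times in `build_prompt`: the tokenizer is now optional, and each site falls back to the line's byte length (`str::len`) as a cheap stand-in for its token count, which keeps the `context_window` budgeting conservative since byte count is an upper bound on token count. A hypothetical helper, not in the commit, that captures the repeated pattern (error handling via `anyhow` instead of the server's `internal_error`):

```rust
use std::sync::Arc;
use tokenizers::Tokenizer;

/// Hypothetical helper consolidating the pattern above: count tokens when a
/// tokenizer is available, else fall back to the line's byte length.
fn line_tokens(tokenizer: Option<&Arc<Tokenizer>>, line: &str) -> anyhow::Result<usize> {
    match tokenizer {
        Some(tokenizer) => Ok(tokenizer
            .encode(line.to_owned(), false)
            .map_err(|e| anyhow::anyhow!(e))?
            .len()),
        None => Ok(line.len()),
    }
}
```

Note that `tokenizer.clone()` in the hunks above only bumps an `Arc` reference count per line, so the per-iteration cost is negligible; borrowing via `as_ref()` would avoid even that.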
@@ -272,7 +287,7 @@ fn build_prompt(
 async fn request_completion(
     http_client: &reqwest::Client,
-    ide: IDE,
+    ide: Ide,
     model: &str,
     request_params: RequestParams,
     api_token: Option<&String>,
@@ -311,9 +326,10 @@ fn parse_generations(generations: Vec<Generation>, tokens_to_clear: &[String]) -
 async fn download_tokenizer_file(
     http_client: &reqwest::Client,
-    model: &str,
+    url: &str,
     api_token: Option<&String>,
     to: impl AsRef<Path>,
+    ide: Ide,
 ) -> Result<()> {
     if to.as_ref().exists() {
         return Ok(());
@@ -325,13 +341,9 @@ async fn download_tokenizer_file(
     )
     .await
     .map_err(internal_error)?;
-    let mut req = http_client.get(format!(
-        "https://huggingface.co/{model}/resolve/main/tokenizer.json"
-    ));
-    if let Some(api_token) = api_token {
-        req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
-    }
-    let res = req
+    let res = http_client
+        .get(url)
+        .headers(build_headers(api_token, ide)?)
         .send()
         .await
         .map_err(internal_error)?
@@ -352,27 +364,37 @@ async fn download_tokenizer_file(
 async fn get_tokenizer(
     model: &str,
     tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
-    tokenizer_path: Option<&String>,
+    tokenizer_config: Option<TokenizerConfig>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
     api_token: Option<&String>,
-) -> Result<Arc<Tokenizer>> {
+    ide: Ide,
+) -> Result<Option<Arc<Tokenizer>>> {
     if let Some(tokenizer) = tokenizer_map.get(model) {
-        return Ok(tokenizer.clone());
+        return Ok(Some(tokenizer.clone()));
     }
-    let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
-        match tokenizer_path {
-            Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
-            None => return Err(internal_error("`tokenizer_path` is null")),
-        }
-    } else {
-        let path = cache_dir.as_ref().join(model).join("tokenizer.json");
-        download_tokenizer_file(http_client, model, api_token, &path).await?;
-        Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
-    };
-    tokenizer_map.insert(model.to_owned(), tokenizer.clone());
-    Ok(tokenizer)
+    if let Some(config) = tokenizer_config {
+        let tokenizer = match config {
+            TokenizerConfig::Local { path } => {
+                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+            }
+            TokenizerConfig::HuggingFace { repository } => {
+                let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+                let url =
+                    format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
+                download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
+                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+            }
+            TokenizerConfig::Download { url, to } => {
+                download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
+                Arc::new(Tokenizer::from_file(to).map_err(internal_error)?)
+            }
+        };
+        tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+        Ok(Some(tokenizer))
+    } else {
+        Ok(None)
+    }
 }
 
 fn build_url(model: &str) -> String {
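Each variant resolves to a tokenizer file differently: `Local` loads a path as-is, `HuggingFace` builds a Hub URL from the repository and caches under the model's cache directory, and `Download` fetches an arbitrary URL to a caller-chosen destination. Note that the cache path is keyed by `model` while the URL is keyed by `repository`, so the two may differ. A small sketch of the `HuggingFace` resolution; repository, model, and cache dir are illustrative:

```rust
use std::path::Path;

fn main() {
    // Illustrative values; the real ones come from CompletionParams
    // and the server's cache directory.
    let repository = "bigcode/starcoder";
    let model = "bigcode/starcoder";
    let cache_dir = Path::new("/home/user/.cache/llm_ls");

    // Mirrors the URL and path construction in the hunk above.
    let url = format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
    let path = cache_dir.join(model).join("tokenizer.json");

    println!("GET {url}\n -> {}", path.display());
}
```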
@@ -394,10 +416,11 @@ impl Backend {
         let tokenizer = get_tokenizer(
             &params.model,
             &mut *self.tokenizer_map.write().await,
-            params.tokenizer_path.as_ref(),
+            params.tokenizer_config,
             &self.http_client,
             &self.cache_dir,
             params.api_token.as_ref(),
+            params.ide,
         )
         .await?;
         let prompt = build_prompt(
@@ -508,7 +531,7 @@ impl LanguageServer for Backend {
     }
 }
 
-fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
+fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     let mut headers = HeaderMap::new();
     let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
     headers.insert(

(dummy API server source; file path not shown)

@@ -1,7 +1,10 @@
 use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
 use serde::{Deserialize, Serialize};
 use std::{net::SocketAddr, sync::Arc};
-use tokio::sync::Mutex;
+use tokio::{
+    sync::Mutex,
+    time::{sleep, Duration},
+};
 
 #[derive(Clone)]
 struct AppState {
@@ -41,6 +44,16 @@ async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<GeneratedText> {
     })
 }
 
+async fn wait(state: State<AppState>) -> Json<GeneratedText> {
+    let mut lock = state.counter.lock().await;
+    *lock += 1;
+    sleep(Duration::from_millis(200)).await;
+    println!("waited for req {}", lock);
+    Json(GeneratedText {
+        generated_text: "dummy".to_owned(),
+    })
+}
+
 #[tokio::main]
 async fn main() {
     let app_state = AppState {
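The new `/wait` route gives the dummy server an endpoint with artificial latency: it holds the counter mutex across a 200 ms sleep, so concurrent requests are effectively serialized, which is handy for exercising client-side queuing and timeout behavior. A hypothetical smoke test, not part of the commit, assuming `reqwest`, `tokio`, and `serde_json` are available:

```rust
use std::time::Instant;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();
    let start = Instant::now();
    // The server binds 0.0.0.0:4242; reach it via loopback.
    let body: serde_json::Value = client
        .post("http://127.0.0.1:4242/wait")
        .send()
        .await?
        .json()
        .await?;
    // Expect at least the handler's 200 ms sleep to have elapsed.
    println!("{} after {:?}", body["generated_text"], start.elapsed());
    Ok(())
}
```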
@@ -50,11 +63,12 @@ async fn main() {
         .route("/", post(default))
         .route("/tgi", post(tgi))
         .route("/headers", post(log_headers))
+        .route("/wait", post(wait))
         .with_state(app_state);
 
     let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
         .parse()
         .expect("string to parse to socket addr");
-    println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
+    println!("starting server {}:{}", addr.ip(), addr.port(),);
     axum::Server::bind(&addr)
         .serve(app.into_make_service())