feat: improve tokenizer config (#21)

* feat: improve tokenizer config

* fix: add untagged decorator to `TokenizerConfig`

* feat: bump version to `0.2.0`
Luc Georges, 2023-09-21 17:57:19 +02:00, committed by GitHub
parent eeb443feb3
commit 787f2a1a26
4 changed files with 93 additions and 56 deletions

Cargo.lock (generated)

@@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
 [[package]]
 name = "llm-ls"
-version = "0.1.0"
+version = "0.1.1"
 dependencies = [
  "home",
  "reqwest",

Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "llm-ls"
-version = "0.1.1"
+version = "0.2.0"
 edition = "2021"
 [[bin]]

src/main.rs

@@ -19,6 +19,14 @@ use tracing_subscriber::EnvFilter;
 const NAME: &str = "llm-ls";
 const VERSION: &str = env!("CARGO_PKG_VERSION");
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(untagged)]
+enum TokenizerConfig {
+    Local { path: PathBuf },
+    HuggingFace { repository: String },
+    Download { url: String, to: PathBuf },
+}
 #[derive(Clone, Debug, Deserialize, Serialize)]
 struct RequestParams {
     max_new_tokens: u32,
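Because `TokenizerConfig` is `#[serde(untagged)]`, serde picks the variant from the shape of the JSON alone; clients never send a tag field. A minimal sketch of the three accepted payloads (the repository name, URL, and paths are placeholders, not values from the commit):

```rust
use serde::Deserialize;
use std::path::PathBuf;

// Trimmed copy of the enum above, for illustration only.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum TokenizerConfig {
    Local { path: PathBuf },
    HuggingFace { repository: String },
    Download { url: String, to: PathBuf },
}

fn main() -> serde_json::Result<()> {
    // Untagged: serde tries the variants in declaration order and keeps the
    // first one whose fields match, so each JSON shape maps to one variant.
    let local: TokenizerConfig =
        serde_json::from_str(r#"{ "path": "/opt/tokenizer.json" }"#)?;
    let hub: TokenizerConfig =
        serde_json::from_str(r#"{ "repository": "some-org/some-model" }"#)?;
    let download: TokenizerConfig = serde_json::from_str(
        r#"{ "url": "https://example.com/tokenizer.json", "to": "/tmp/tokenizer.json" }"#,
    )?;
    println!("{local:?}\n{hub:?}\n{download:?}");
    Ok(())
}
```

One consequence of going untagged: the variants' field names must stay distinguishable, otherwise an earlier variant would shadow a later one.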
@@ -120,9 +128,9 @@ struct Completion {
     generated_text: String,
 }
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
 #[serde(rename_all = "lowercase")]
-enum IDE {
+enum Ide {
     Neovim,
     VSCode,
     JetBrains,
@@ -130,26 +138,21 @@ enum IDE {
     Jupyter,
     Sublime,
     VisualStudio,
+    #[default]
     Unknown,
 }
-impl Default for IDE {
-    fn default() -> Self {
-        IDE::Unknown
-    }
-}
-impl Display for IDE {
+impl Display for Ide {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         self.serialize(f)
     }
 }
-fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
+fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
 where
     D: Deserializer<'de>,
 {
-    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
+    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 #[derive(Debug, Deserialize, Serialize)]
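Alongside the `IDE` → `Ide` rename, the hand-written `Default` impl is replaced by the derive with `#[default]` on `Unknown`, while `parse_ide` keeps accepting an explicit JSON `null`. A small self-contained sketch of both paths (enum and struct trimmed to the relevant fields):

```rust
use serde::{Deserialize, Deserializer};

// Trimmed to two variants; `#[default]` replaces the manual `impl Default`.
#[derive(Clone, Copy, Debug, Default, Deserialize)]
#[serde(rename_all = "lowercase")]
enum Ide {
    Neovim,
    #[default]
    Unknown,
}

// Same shape as parse_ide above: a JSON null deserializes to None and then
// falls back to Ide::Unknown instead of failing.
fn parse_ide<'de, D>(d: D) -> Result<Ide, D::Error>
where
    D: Deserializer<'de>,
{
    Deserialize::deserialize(d).map(|b: Option<Ide>| b.unwrap_or(Ide::Unknown))
}

#[derive(Debug, Deserialize)]
struct Params {
    // `default` covers a missing field; `parse_ide` covers `"ide": null`.
    #[serde(default, deserialize_with = "parse_ide")]
    ide: Ide,
}

fn main() {
    let missing: Params = serde_json::from_str("{}").unwrap();
    let null: Params = serde_json::from_str(r#"{ "ide": null }"#).unwrap();
    let set: Params = serde_json::from_str(r#"{ "ide": "neovim" }"#).unwrap();
    println!("{:?} {:?} {:?}", missing.ide, null.ide, set.ide); // Unknown Unknown Neovim
}
```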
@@ -159,12 +162,12 @@ struct CompletionParams {
     request_params: RequestParams,
     #[serde(default)]
     #[serde(deserialize_with = "parse_ide")]
-    ide: IDE,
+    ide: Ide,
     fim: FimParams,
     api_token: Option<String>,
     model: String,
     tokens_to_clear: Vec<String>,
-    tokenizer_path: Option<String>,
+    tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
 }
@@ -183,7 +186,7 @@
     pos: Position,
     text: &Rope,
     fim: &FimParams,
-    tokenizer: Arc<Tokenizer>,
+    tokenizer: Option<Arc<Tokenizer>>,
     context_window: usize,
 ) -> Result<String> {
     let t = Instant::now();
@@ -206,10 +209,14 @@
     while before_line.is_some() || after_line.is_some() {
         if let Some(before_line) = before_line {
             let before_line = before_line.to_string();
-            let tokens = tokenizer
-                .encode(before_line.clone(), false)
-                .map_err(internal_error)?
-                .len();
+            let tokens = if let Some(tokenizer) = tokenizer.clone() {
+                tokenizer
+                    .encode(before_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len()
+            } else {
+                before_line.len()
+            };
             if tokens > token_count {
                 break;
             }
@@ -218,10 +225,14 @@
         }
         if let Some(after_line) = after_line {
             let after_line = after_line.to_string();
-            let tokens = tokenizer
-                .encode(after_line.clone(), false)
-                .map_err(internal_error)?
-                .len();
+            let tokens = if let Some(tokenizer) = tokenizer.clone() {
+                tokenizer
+                    .encode(after_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len()
+            } else {
+                after_line.len()
+            };
             if tokens > token_count {
                 break;
             }
@@ -253,10 +264,14 @@
             first = false;
         }
         let line = line.to_string();
-        let tokens = tokenizer
-            .encode(line.clone(), false)
-            .map_err(internal_error)?
-            .len();
+        let tokens = if let Some(tokenizer) = tokenizer.clone() {
+            tokenizer
+                .encode(line.clone(), false)
+                .map_err(internal_error)?
+                .len()
+        } else {
+            line.len()
+        };
         if tokens > token_count {
             break;
         }
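The same pattern now appears three times in `build_prompt`: with a tokenizer configured, a line's cost against `context_window` is its token count; with `tokenizer_config` unset it degrades to the line's byte length. A hypothetical helper (not part of the commit; `internal_error` swapped for a plain `String` error to stay self-contained) showing the fallback in one place:

```rust
use std::sync::Arc;
use tokenizers::Tokenizer;

// Hypothetical refactor of the repeated block above: tokenize when a
// tokenizer is available, otherwise use byte length as a rough stand-in.
// Overestimating a line's cost only shrinks the prompt, so the fallback
// errs on the safe side of the context window.
fn line_cost(tokenizer: Option<&Arc<Tokenizer>>, line: &str) -> Result<usize, String> {
    match tokenizer {
        Some(tokenizer) => Ok(tokenizer
            .encode(line.to_owned(), false) // false: no special tokens
            .map_err(|e| e.to_string())?
            .len()),
        None => Ok(line.len()),
    }
}
```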
@@ -272,7 +287,7 @@
 async fn request_completion(
     http_client: &reqwest::Client,
-    ide: IDE,
+    ide: Ide,
     model: &str,
     request_params: RequestParams,
     api_token: Option<&String>,
@@ -311,9 +326,10 @@ fn parse_generations(generations: Vec<Generation>, tokens_to_clear: &[String]) -
 async fn download_tokenizer_file(
     http_client: &reqwest::Client,
-    model: &str,
+    url: &str,
     api_token: Option<&String>,
     to: impl AsRef<Path>,
+    ide: Ide,
 ) -> Result<()> {
     if to.as_ref().exists() {
         return Ok(());
@@ -325,13 +341,9 @@ async fn download_tokenizer_file(
     )
     .await
     .map_err(internal_error)?;
-    let mut req = http_client.get(format!(
-        "https://huggingface.co/{model}/resolve/main/tokenizer.json"
-    ));
-    if let Some(api_token) = api_token {
-        req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
-    }
-    let res = req
+    let res = http_client
+        .get(url)
+        .headers(build_headers(api_token, ide)?)
         .send()
         .await
         .map_err(internal_error)?
@@ -352,27 +364,37 @@
 async fn get_tokenizer(
     model: &str,
     tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
-    tokenizer_path: Option<&String>,
+    tokenizer_config: Option<TokenizerConfig>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
     api_token: Option<&String>,
-) -> Result<Arc<Tokenizer>> {
+    ide: Ide,
+) -> Result<Option<Arc<Tokenizer>>> {
     if let Some(tokenizer) = tokenizer_map.get(model) {
-        return Ok(tokenizer.clone());
+        return Ok(Some(tokenizer.clone()));
     }
-    let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
-        match tokenizer_path {
-            Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
-            None => return Err(internal_error("`tokenizer_path` is null")),
-        }
-    } else {
-        let path = cache_dir.as_ref().join(model).join("tokenizer.json");
-        download_tokenizer_file(http_client, model, api_token, &path).await?;
+    if let Some(config) = tokenizer_config {
+        let tokenizer = match config {
+            TokenizerConfig::Local { path } => {
+                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+            }
+            TokenizerConfig::HuggingFace { repository } => {
+                let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+                let url =
+                    format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
+                download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
+                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+            }
+            TokenizerConfig::Download { url, to } => {
+                download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
+                Arc::new(Tokenizer::from_file(to).map_err(internal_error)?)
+            }
+        };
         tokenizer_map.insert(model.to_owned(), tokenizer.clone());
-    Ok(tokenizer)
+        Ok(Some(tokenizer))
+    } else {
+        Ok(None)
+    }
 }
 fn build_url(
@@ -394,10 +416,11 @@ impl Backend {
         let tokenizer = get_tokenizer(
             &params.model,
             &mut *self.tokenizer_map.write().await,
-            params.tokenizer_path.as_ref(),
+            params.tokenizer_config,
             &self.http_client,
             &self.cache_dir,
             params.api_token.as_ref(),
+            params.ide,
         )
         .await?;
         let prompt = build_prompt(
@@ -508,7 +531,7 @@
     }
 }
-fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
+fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     let mut headers = HeaderMap::new();
     let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
     headers.insert(
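The hunk above is cut off mid-call, so the full body of `build_headers` isn't shown in this diff. For context, a sketch of what a builder with this signature plausibly assembles, given that the `User-Agent` string is shown and the old inline `AUTHORIZATION` header was replaced by a `build_headers` call: a client/IDE-identifying `User-Agent` plus a bearer token when one is set. The error type and the `Ide` stub are simplifications for a self-contained example.

```rust
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};

// Stand-in for the full Ide enum above; only Debug formatting is needed here.
#[derive(Clone, Copy, Debug)]
enum Ide {
    Unknown,
}

// Sketch only: the real function returns the crate's Result and uses
// internal_error; plain String errors keep the example compilable alone.
fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, String> {
    let mut headers = HeaderMap::new();
    // Mirrors the user_agent format string shown in the hunk above.
    let user_agent = format!("llm-ls/0.2.0; rust/unknown; ide/{ide:?}");
    headers.insert(
        USER_AGENT,
        HeaderValue::from_str(&user_agent).map_err(|e| e.to_string())?,
    );
    // Subsumes the old inline `Bearer` header from download_tokenizer_file.
    if let Some(api_token) = api_token {
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(|e| e.to_string())?,
        );
    }
    Ok(headers)
}
```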


@@ -1,7 +1,10 @@
 use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
 use serde::{Deserialize, Serialize};
 use std::{net::SocketAddr, sync::Arc};
-use tokio::sync::Mutex;
+use tokio::{
+    sync::Mutex,
+    time::{sleep, Duration},
+};
 #[derive(Clone)]
 struct AppState {
@@ -41,6 +44,16 @@ async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<GeneratedText> {
     })
 }
+async fn wait(state: State<AppState>) -> Json<GeneratedText> {
+    let mut lock = state.counter.lock().await;
+    *lock += 1;
+    sleep(Duration::from_millis(200)).await;
+    println!("waited for req {}", lock);
+    Json(GeneratedText {
+        generated_text: "dummy".to_owned(),
+    })
+}
 #[tokio::main]
 async fn main() {
     let app_state = AppState {
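One detail of the new `wait` handler worth noting: the counter mutex is held across the 200 ms sleep, so concurrent requests to `/wait` are served strictly one at a time. A hypothetical client (not part of the commit) that makes this serialization visible:

```rust
use std::time::Instant;

// Hypothetical driver for the /wait route: because the handler keeps the
// mutex locked while it sleeps, five parallel requests take roughly
// 5 * 200 ms = ~1 s, not ~200 ms.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();
    let start = Instant::now();
    let mut handles = Vec::new();
    for _ in 0..5 {
        let client = client.clone();
        handles.push(tokio::spawn(async move {
            client.post("http://127.0.0.1:4242/wait").send().await
        }));
    }
    for handle in handles {
        handle.await.expect("task panicked")?;
    }
    println!("5 requests to /wait took {:?}", start.elapsed());
    Ok(())
}
```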
@@ -50,11 +63,12 @@
         .route("/", post(default))
         .route("/tgi", post(tgi))
         .route("/headers", post(log_headers))
+        .route("/wait", post(wait))
         .with_state(app_state);
     let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
         .parse()
         .expect("string to parse to socket addr");
-    println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
+    println!("starting server {}:{}", addr.ip(), addr.port(),);
     axum::Server::bind(&addr)
         .serve(app.into_make_service())