feat: improve tokenizer config (#21)
* feat: improve tokenizer config
* fix: add untagged decorator to `TokenizerConfig`
* feat: bump version to `0.2.0`
parent eeb443feb3
commit 787f2a1a26

Cargo.lock (generated): 2 lines changed
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
 
 [[package]]
 name = "llm-ls"
-version = "0.1.0"
+version = "0.1.1"
 dependencies = [
  "home",
  "reqwest",
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "llm-ls"
-version = "0.1.1"
+version = "0.2.0"
 edition = "2021"
 
 [[bin]]
@@ -19,6 +19,14 @@ use tracing_subscriber::EnvFilter;
 const NAME: &str = "llm-ls";
 const VERSION: &str = env!("CARGO_PKG_VERSION");
 
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(untagged)]
+enum TokenizerConfig {
+    Local { path: PathBuf },
+    HuggingFace { repository: String },
+    Download { url: String, to: PathBuf },
+}
+
 #[derive(Clone, Debug, Deserialize, Serialize)]
 struct RequestParams {
     max_new_tokens: u32,
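The new `TokenizerConfig` enum is what the "untagged decorator" bullet in the commit message refers to: with `#[serde(untagged)]`, serde tries each variant in declaration order and picks the first whose fields match the incoming JSON, so clients select a variant purely by shape. A minimal sketch of that behavior, assuming serde and serde_json as dependencies (the repository name is illustrative):

// Sketch: untagged deserialization picks the variant by field names alone.
use serde::Deserialize;
use std::path::PathBuf;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum TokenizerConfig {
    Local { path: PathBuf },
    HuggingFace { repository: String },
    Download { url: String, to: PathBuf },
}

fn main() -> serde_json::Result<()> {
    // `{"path": ...}` matches Local; `{"repository": ...}` fails Local
    // (no `path` field) and falls through to HuggingFace.
    let local: TokenizerConfig =
        serde_json::from_str(r#"{ "path": "/tmp/tokenizer.json" }"#)?;
    let hf: TokenizerConfig =
        serde_json::from_str(r#"{ "repository": "bigcode/starcoder" }"#)?;
    println!("{local:?}");
    println!("{hf:?}");
    Ok(())
}

Note that declaration order matters for untagged enums: a payload that happens to match two variants resolves to whichever is declared first.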
@@ -120,9 +128,9 @@ struct Completion {
     generated_text: String,
 }
 
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
 #[serde(rename_all = "lowercase")]
-enum IDE {
+enum Ide {
     Neovim,
     VSCode,
     JetBrains,
@@ -130,26 +138,21 @@ enum IDE {
     Jupyter,
     Sublime,
     VisualStudio,
+    #[default]
     Unknown,
 }
 
-impl Default for IDE {
-    fn default() -> Self {
-        IDE::Unknown
-    }
-}
-
-impl Display for IDE {
+impl Display for Ide {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         self.serialize(f)
     }
 }
 
-fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
+fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
 where
     D: Deserializer<'de>,
 {
-    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
+    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 
 #[derive(Debug, Deserialize, Serialize)]
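Dropping the manual `impl Default for IDE` in favor of a `#[default]` marker relies on derive support for enum defaults, stable since Rust 1.62. A small sketch of the equivalence:

// Sketch: `#[default]` on a unit variant generates the same impl the
// removed manual block provided.
#[derive(Debug, Default, PartialEq)]
enum Ide {
    Neovim,
    #[default]
    Unknown,
}

fn main() {
    assert_eq!(Ide::default(), Ide::Unknown);
}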
@@ -159,12 +162,12 @@ struct CompletionParams {
     request_params: RequestParams,
     #[serde(default)]
     #[serde(deserialize_with = "parse_ide")]
-    ide: IDE,
+    ide: Ide,
     fim: FimParams,
     api_token: Option<String>,
     model: String,
     tokens_to_clear: Vec<String>,
-    tokenizer_path: Option<String>,
+    tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
 }
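Since `tokenizer_path: Option<String>` becomes `tokenizer_config: Option<TokenizerConfig>`, clients now send a structured object instead of a bare path. A hypothetical request payload sketched with `serde_json::json!`; the model name, the FIM fields, and all values are illustrative, not taken from the diff:

use serde_json::json;

fn main() {
    // Hypothetical llm-ls completion params. `tokenizer_config` uses the
    // HuggingFace variant, selected by its `repository` field.
    let params = json!({
        "request_params": { "max_new_tokens": 60 },
        "ide": "neovim",
        "fim": { "enabled": false },
        "api_token": null,
        "model": "bigcode/starcoder",
        "tokens_to_clear": ["<|endoftext|>"],
        "tokenizer_config": { "repository": "bigcode/starcoder" },
        "context_window": 2048,
        "tls_skip_verify_insecure": false
    });
    println!("{}", serde_json::to_string_pretty(&params).unwrap());
}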
@@ -183,7 +186,7 @@ fn build_prompt(
     pos: Position,
     text: &Rope,
     fim: &FimParams,
-    tokenizer: Arc<Tokenizer>,
+    tokenizer: Option<Arc<Tokenizer>>,
     context_window: usize,
 ) -> Result<String> {
     let t = Instant::now();
@@ -206,10 +209,14 @@ fn build_prompt(
     while before_line.is_some() || after_line.is_some() {
         if let Some(before_line) = before_line {
             let before_line = before_line.to_string();
-            let tokens = tokenizer
-                .encode(before_line.clone(), false)
-                .map_err(internal_error)?
-                .len();
+            let tokens = if let Some(tokenizer) = tokenizer.clone() {
+                tokenizer
+                    .encode(before_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len()
+            } else {
+                before_line.len()
+            };
             if tokens > token_count {
                 break;
             }
@@ -218,10 +225,14 @@ fn build_prompt(
         }
         if let Some(after_line) = after_line {
             let after_line = after_line.to_string();
-            let tokens = tokenizer
-                .encode(after_line.clone(), false)
-                .map_err(internal_error)?
-                .len();
+            let tokens = if let Some(tokenizer) = tokenizer.clone() {
+                tokenizer
+                    .encode(after_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len()
+            } else {
+                after_line.len()
+            };
             if tokens > token_count {
                 break;
             }
@@ -253,10 +264,14 @@ fn build_prompt(
             first = false;
         }
         let line = line.to_string();
-        let tokens = tokenizer
-            .encode(line.clone(), false)
-            .map_err(internal_error)?
-            .len();
+        let tokens = if let Some(tokenizer) = tokenizer.clone() {
+            tokenizer
+                .encode(line.clone(), false)
+                .map_err(internal_error)?
+                .len()
+        } else {
+            line.len()
+        };
         if tokens > token_count {
             break;
         }
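All three loops above apply the same new fallback: measure a line by token count when a tokenizer is configured, or by its byte length otherwise. A sketch of that logic as a standalone helper, assuming the `tokenizers` crate and simplifying the error type to `String`:

use std::sync::Arc;
use tokenizers::Tokenizer;

// Cost of a line against the context window: token count when a tokenizer
// is available, raw byte length otherwise.
fn line_cost(tokenizer: Option<&Arc<Tokenizer>>, line: &str) -> Result<usize, String> {
    match tokenizer {
        Some(tok) => Ok(tok
            .encode(line.to_string(), false)
            .map_err(|e| e.to_string())?
            .len()),
        // No tokenizer: approximate with `len()`, as the diff does with
        // `before_line.len()`.
        None => Ok(line.len()),
    }
}

Since a line's byte length is usually larger than its token count, the fallback errs toward shorter prompts rather than overflowing the context window.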
@@ -272,7 +287,7 @@ fn build_prompt(
 
 async fn request_completion(
     http_client: &reqwest::Client,
-    ide: IDE,
+    ide: Ide,
     model: &str,
     request_params: RequestParams,
     api_token: Option<&String>,
@@ -311,9 +326,10 @@ fn parse_generations(generations: Vec<Generation>, tokens_to_clear: &[String]) -
 
 async fn download_tokenizer_file(
     http_client: &reqwest::Client,
-    model: &str,
+    url: &str,
     api_token: Option<&String>,
     to: impl AsRef<Path>,
+    ide: Ide,
 ) -> Result<()> {
     if to.as_ref().exists() {
         return Ok(());
@@ -325,13 +341,9 @@ async fn download_tokenizer_file(
     )
     .await
     .map_err(internal_error)?;
-    let mut req = http_client.get(format!(
-        "https://huggingface.co/{model}/resolve/main/tokenizer.json"
-    ));
-    if let Some(api_token) = api_token {
-        req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
-    }
-    let res = req
+    let res = http_client
+        .get(url)
+        .headers(build_headers(api_token, ide)?)
         .send()
         .await
         .map_err(internal_error)?
@@ -352,27 +364,37 @@ async fn download_tokenizer_file(
 async fn get_tokenizer(
     model: &str,
     tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
-    tokenizer_path: Option<&String>,
+    tokenizer_config: Option<TokenizerConfig>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
     api_token: Option<&String>,
-) -> Result<Arc<Tokenizer>> {
+    ide: Ide,
+) -> Result<Option<Arc<Tokenizer>>> {
     if let Some(tokenizer) = tokenizer_map.get(model) {
-        return Ok(tokenizer.clone());
+        return Ok(Some(tokenizer.clone()));
     }
-    let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
-        match tokenizer_path {
-            Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
-            None => return Err(internal_error("`tokenizer_path` is null")),
-        }
-    } else {
-        let path = cache_dir.as_ref().join(model).join("tokenizer.json");
-        download_tokenizer_file(http_client, model, api_token, &path).await?;
-        Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
-    };
-    tokenizer_map.insert(model.to_owned(), tokenizer.clone());
-    Ok(tokenizer)
+    if let Some(config) = tokenizer_config {
+        let tokenizer = match config {
+            TokenizerConfig::Local { path } => {
+                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+            }
+            TokenizerConfig::HuggingFace { repository } => {
+                let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+                let url =
+                    format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
+                download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
+                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+            }
+            TokenizerConfig::Download { url, to } => {
+                download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
+                Arc::new(Tokenizer::from_file(to).map_err(internal_error)?)
+            }
+        };
+        tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+        Ok(Some(tokenizer))
+    } else {
+        Ok(None)
+    }
 }
 
 fn build_url(model: &str) -> String {
@@ -394,10 +416,11 @@ impl Backend {
         let tokenizer = get_tokenizer(
             &params.model,
             &mut *self.tokenizer_map.write().await,
-            params.tokenizer_path.as_ref(),
+            params.tokenizer_config,
             &self.http_client,
             &self.cache_dir,
             params.api_token.as_ref(),
+            params.ide,
         )
         .await?;
         let prompt = build_prompt(
@@ -508,7 +531,7 @@ impl LanguageServer for Backend {
     }
 }
 
-fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
+fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     let mut headers = HeaderMap::new();
     let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
     headers.insert(
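`build_headers` is now shared with `download_tokenizer_file`, replacing the inline `AUTHORIZATION` handling removed above. The diff cuts off after the first `headers.insert`, so this is only a plausible sketch of the assembled map, with `ide` flattened to a string and the error type simplified:

use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};

// Plausible shape of build_headers: user agent always, bearer auth when a
// token is present (as the removed inline code did).
fn build_headers(api_token: Option<&String>, ide: &str) -> Result<HeaderMap, String> {
    let mut headers = HeaderMap::new();
    let user_agent = format!("llm-ls/0.2.0; rust/unknown; ide/{ide}");
    headers.insert(
        USER_AGENT,
        HeaderValue::from_str(&user_agent).map_err(|e| e.to_string())?,
    );
    if let Some(token) = api_token {
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| e.to_string())?,
        );
    }
    Ok(headers)
}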
@@ -1,7 +1,10 @@
 use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
 use serde::{Deserialize, Serialize};
 use std::{net::SocketAddr, sync::Arc};
-use tokio::sync::Mutex;
+use tokio::{
+    sync::Mutex,
+    time::{sleep, Duration},
+};
 
 #[derive(Clone)]
 struct AppState {
@@ -41,6 +44,16 @@ async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<GeneratedText> {
     })
 }
 
+async fn wait(state: State<AppState>) -> Json<GeneratedText> {
+    let mut lock = state.counter.lock().await;
+    *lock += 1;
+    sleep(Duration::from_millis(200)).await;
+    println!("waited for req {}", lock);
+    Json(GeneratedText {
+        generated_text: "dummy".to_owned(),
+    })
+}
+
 #[tokio::main]
 async fn main() {
     let app_state = AppState {
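The new `/wait` route makes the stub server hold each request for 200 ms before answering, presumably to exercise client-side latency and cancellation handling. A hypothetical client call against the server above, assuming it is listening on port 4242 and that reqwest's `json` feature is enabled:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GeneratedText {
    generated_text: String,
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // POST to the new route; the stub handler ignores any request body.
    let res: GeneratedText = reqwest::Client::new()
        .post("http://127.0.0.1:4242/wait")
        .send()
        .await?
        .json()
        .await?;
    // Prints "dummy" after roughly 200 ms.
    println!("{}", res.generated_text);
    Ok(())
}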
@@ -50,11 +63,12 @@ async fn main() {
         .route("/", post(default))
         .route("/tgi", post(tgi))
         .route("/headers", post(log_headers))
+        .route("/wait", post(wait))
         .with_state(app_state);
     let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
         .parse()
         .expect("string to parse to socket addr");
-    println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
+    println!("starting server {}:{}", addr.ip(), addr.port(),);
 
     axum::Server::bind(&addr)
         .serve(app.into_make_service())