feat: improve tokenizer config (#21)
* feat: improve tokenizer config * fix: add untagged decorator to `TokenizerConfig` * feat: bump version to `0.2.0`
This commit is contained in:
parent
eeb443feb3
commit
787f2a1a26
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
|
|||
|
||||
[[package]]
|
||||
name = "llm-ls"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
dependencies = [
|
||||
"home",
|
||||
"reqwest",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "llm-ls"
|
||||
version = "0.1.1"
|
||||
version = "0.2.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
|
|
|
@ -19,6 +19,14 @@ use tracing_subscriber::EnvFilter;
|
|||
const NAME: &str = "llm-ls";
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
enum TokenizerConfig {
|
||||
Local { path: PathBuf },
|
||||
HuggingFace { repository: String },
|
||||
Download { url: String, to: PathBuf },
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
struct RequestParams {
|
||||
max_new_tokens: u32,
|
||||
|
@ -120,9 +128,9 @@ struct Completion {
|
|||
generated_text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum IDE {
|
||||
enum Ide {
|
||||
Neovim,
|
||||
VSCode,
|
||||
JetBrains,
|
||||
|
@ -130,26 +138,21 @@ enum IDE {
|
|||
Jupyter,
|
||||
Sublime,
|
||||
VisualStudio,
|
||||
#[default]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Default for IDE {
|
||||
fn default() -> Self {
|
||||
IDE::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for IDE {
|
||||
impl Display for Ide {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.serialize(f)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ide<'de, D>(d: D) -> std::result::Result<IDE, D::Error>
|
||||
fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(IDE::Unknown))
|
||||
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
|
@ -159,12 +162,12 @@ struct CompletionParams {
|
|||
request_params: RequestParams,
|
||||
#[serde(default)]
|
||||
#[serde(deserialize_with = "parse_ide")]
|
||||
ide: IDE,
|
||||
ide: Ide,
|
||||
fim: FimParams,
|
||||
api_token: Option<String>,
|
||||
model: String,
|
||||
tokens_to_clear: Vec<String>,
|
||||
tokenizer_path: Option<String>,
|
||||
tokenizer_config: Option<TokenizerConfig>,
|
||||
context_window: usize,
|
||||
tls_skip_verify_insecure: bool,
|
||||
}
|
||||
|
@ -183,7 +186,7 @@ fn build_prompt(
|
|||
pos: Position,
|
||||
text: &Rope,
|
||||
fim: &FimParams,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
tokenizer: Option<Arc<Tokenizer>>,
|
||||
context_window: usize,
|
||||
) -> Result<String> {
|
||||
let t = Instant::now();
|
||||
|
@ -206,10 +209,14 @@ fn build_prompt(
|
|||
while before_line.is_some() || after_line.is_some() {
|
||||
if let Some(before_line) = before_line {
|
||||
let before_line = before_line.to_string();
|
||||
let tokens = tokenizer
|
||||
.encode(before_line.clone(), false)
|
||||
.map_err(internal_error)?
|
||||
.len();
|
||||
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
||||
tokenizer
|
||||
.encode(before_line.clone(), false)
|
||||
.map_err(internal_error)?
|
||||
.len()
|
||||
} else {
|
||||
before_line.len()
|
||||
};
|
||||
if tokens > token_count {
|
||||
break;
|
||||
}
|
||||
|
@ -218,10 +225,14 @@ fn build_prompt(
|
|||
}
|
||||
if let Some(after_line) = after_line {
|
||||
let after_line = after_line.to_string();
|
||||
let tokens = tokenizer
|
||||
.encode(after_line.clone(), false)
|
||||
.map_err(internal_error)?
|
||||
.len();
|
||||
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
||||
tokenizer
|
||||
.encode(after_line.clone(), false)
|
||||
.map_err(internal_error)?
|
||||
.len()
|
||||
} else {
|
||||
after_line.len()
|
||||
};
|
||||
if tokens > token_count {
|
||||
break;
|
||||
}
|
||||
|
@ -253,10 +264,14 @@ fn build_prompt(
|
|||
first = false;
|
||||
}
|
||||
let line = line.to_string();
|
||||
let tokens = tokenizer
|
||||
.encode(line.clone(), false)
|
||||
.map_err(internal_error)?
|
||||
.len();
|
||||
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
||||
tokenizer
|
||||
.encode(line.clone(), false)
|
||||
.map_err(internal_error)?
|
||||
.len()
|
||||
} else {
|
||||
line.len()
|
||||
};
|
||||
if tokens > token_count {
|
||||
break;
|
||||
}
|
||||
|
@ -272,7 +287,7 @@ fn build_prompt(
|
|||
|
||||
async fn request_completion(
|
||||
http_client: &reqwest::Client,
|
||||
ide: IDE,
|
||||
ide: Ide,
|
||||
model: &str,
|
||||
request_params: RequestParams,
|
||||
api_token: Option<&String>,
|
||||
|
@ -311,9 +326,10 @@ fn parse_generations(generations: Vec<Generation>, tokens_to_clear: &[String]) -
|
|||
|
||||
async fn download_tokenizer_file(
|
||||
http_client: &reqwest::Client,
|
||||
model: &str,
|
||||
url: &str,
|
||||
api_token: Option<&String>,
|
||||
to: impl AsRef<Path>,
|
||||
ide: Ide,
|
||||
) -> Result<()> {
|
||||
if to.as_ref().exists() {
|
||||
return Ok(());
|
||||
|
@ -325,13 +341,9 @@ async fn download_tokenizer_file(
|
|||
)
|
||||
.await
|
||||
.map_err(internal_error)?;
|
||||
let mut req = http_client.get(format!(
|
||||
"https://huggingface.co/{model}/resolve/main/tokenizer.json"
|
||||
));
|
||||
if let Some(api_token) = api_token {
|
||||
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
|
||||
}
|
||||
let res = req
|
||||
let res = http_client
|
||||
.get(url)
|
||||
.headers(build_headers(api_token, ide)?)
|
||||
.send()
|
||||
.await
|
||||
.map_err(internal_error)?
|
||||
|
@ -352,27 +364,37 @@ async fn download_tokenizer_file(
|
|||
async fn get_tokenizer(
|
||||
model: &str,
|
||||
tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
|
||||
tokenizer_path: Option<&String>,
|
||||
tokenizer_config: Option<TokenizerConfig>,
|
||||
http_client: &reqwest::Client,
|
||||
cache_dir: impl AsRef<Path>,
|
||||
api_token: Option<&String>,
|
||||
) -> Result<Arc<Tokenizer>> {
|
||||
ide: Ide,
|
||||
) -> Result<Option<Arc<Tokenizer>>> {
|
||||
if let Some(tokenizer) = tokenizer_map.get(model) {
|
||||
return Ok(tokenizer.clone());
|
||||
return Ok(Some(tokenizer.clone()));
|
||||
}
|
||||
let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
|
||||
match tokenizer_path {
|
||||
Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
|
||||
None => return Err(internal_error("`tokenizer_path` is null")),
|
||||
}
|
||||
if let Some(config) = tokenizer_config {
|
||||
let tokenizer = match config {
|
||||
TokenizerConfig::Local { path } => {
|
||||
Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
|
||||
}
|
||||
TokenizerConfig::HuggingFace { repository } => {
|
||||
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
|
||||
let url =
|
||||
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
|
||||
download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
|
||||
Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
|
||||
}
|
||||
TokenizerConfig::Download { url, to } => {
|
||||
download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
|
||||
Arc::new(Tokenizer::from_file(to).map_err(internal_error)?)
|
||||
}
|
||||
};
|
||||
tokenizer_map.insert(model.to_owned(), tokenizer.clone());
|
||||
Ok(Some(tokenizer))
|
||||
} else {
|
||||
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
|
||||
download_tokenizer_file(http_client, model, api_token, &path).await?;
|
||||
Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
|
||||
};
|
||||
|
||||
tokenizer_map.insert(model.to_owned(), tokenizer.clone());
|
||||
Ok(tokenizer)
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
fn build_url(model: &str) -> String {
|
||||
|
@ -394,10 +416,11 @@ impl Backend {
|
|||
let tokenizer = get_tokenizer(
|
||||
¶ms.model,
|
||||
&mut *self.tokenizer_map.write().await,
|
||||
params.tokenizer_path.as_ref(),
|
||||
params.tokenizer_config,
|
||||
&self.http_client,
|
||||
&self.cache_dir,
|
||||
params.api_token.as_ref(),
|
||||
params.ide,
|
||||
)
|
||||
.await?;
|
||||
let prompt = build_prompt(
|
||||
|
@ -508,7 +531,7 @@ impl LanguageServer for Backend {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_headers(api_token: Option<&String>, ide: IDE) -> Result<HeaderMap> {
|
||||
fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
|
||||
headers.insert(
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
use axum::{extract::State, http::HeaderMap, routing::post, Json, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::{
|
||||
sync::Mutex,
|
||||
time::{sleep, Duration},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
|
@ -41,6 +44,16 @@ async fn log_headers(headers: HeaderMap, state: State<AppState>) -> Json<Generat
|
|||
})
|
||||
}
|
||||
|
||||
async fn wait(state: State<AppState>) -> Json<GeneratedText> {
|
||||
let mut lock = state.counter.lock().await;
|
||||
*lock += 1;
|
||||
sleep(Duration::from_millis(200)).await;
|
||||
println!("waited for req {}", lock);
|
||||
Json(GeneratedText {
|
||||
generated_text: "dummy".to_owned(),
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let app_state = AppState {
|
||||
|
@ -50,11 +63,12 @@ async fn main() {
|
|||
.route("/", post(default))
|
||||
.route("/tgi", post(tgi))
|
||||
.route("/headers", post(log_headers))
|
||||
.route("/wait", post(wait))
|
||||
.with_state(app_state);
|
||||
let addr: SocketAddr = format!("{}:{}", "0.0.0.0", 4242)
|
||||
.parse()
|
||||
.expect("string to parse to socket addr");
|
||||
println!("starting server {}:{}", addr.ip().to_string(), addr.port(),);
|
||||
println!("starting server {}:{}", addr.ip(), addr.port(),);
|
||||
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
|
|
Loading…
Reference in a new issue