fix: don't use tokenizer on config error (#22)

Luc Georges 2023-09-25 15:10:29 +02:00 committed by GitHub
parent 787f2a1a26
commit fbaf98203f
3 changed files with 62 additions and 21 deletions
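
In short: before this change, any failure to download or load the tokenizer was converted into an LSP error with `internal_error` and aborted the whole request; after it, failures are logged with `error!` and `get_tokenizer` returns `Ok(None)`, so the server keeps working without a tokenizer. A minimal sketch of the fallback pattern the diff applies (assuming the `tokenizers` and `tracing` crates, which the code below appears to use; `try_load_tokenizer` is an illustrative name, not from the diff):

    use std::sync::Arc;
    use tokenizers::Tokenizer;
    use tracing::error;

    // Log-and-fall-back instead of `?`-propagation: a load failure yields
    // None rather than an error that would abort the LSP request.
    fn try_load_tokenizer(path: &str) -> Option<Arc<Tokenizer>> {
        match Tokenizer::from_file(path) {
            Ok(tokenizer) => Some(Arc::new(tokenizer)),
            Err(err) => {
                error!("error loading tokenizer from file: {err}");
                None
            }
        }
    }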

Cargo.lock (generated)

@@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
 [[package]]
 name = "llm-ls"
-version = "0.1.1"
+version = "0.2.1"
 dependencies = [
  "home",
  "reqwest",

Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "llm-ls"
-version = "0.2.0"
+version = "0.2.1"
 edition = "2021"
 
 [[bin]]


@@ -337,27 +337,50 @@ async fn download_tokenizer_file(
     tokio::fs::create_dir_all(
         to.as_ref()
             .parent()
-            .ok_or_else(|| internal_error("tokenizer path has no parent"))?,
+            .ok_or_else(|| internal_error("invalid tokenizer path"))?,
     )
     .await
     .map_err(internal_error)?;
-    let res = http_client
-        .get(url)
-        .headers(build_headers(api_token, ide)?)
-        .send()
-        .await
-        .map_err(internal_error)?
-        .error_for_status()
-        .map_err(internal_error)?;
+    let headers = build_headers(api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
         .open(to)
         .await
         .map_err(internal_error)?;
-    file.write_all(&res.bytes().await.map_err(internal_error)?)
-        .await
-        .map_err(internal_error)?;
+    let http_client = http_client.clone();
+    let url = url.to_owned();
+    tokio::spawn(async move {
+        let res = match http_client.get(url).headers(headers).send().await {
+            Ok(res) => res,
+            Err(err) => {
+                error!("error sending download request for the tokenizer file: {err}");
+                return;
+            }
+        };
+        let res = match res.error_for_status() {
+            Ok(res) => res,
+            Err(err) => {
+                error!("API replied with error to the tokenizer file download: {err}");
+                return;
+            }
+        };
+        let bytes = match res.bytes().await {
+            Ok(bytes) => bytes,
+            Err(err) => {
+                error!("error while streaming tokenizer file bytes: {err}");
+                return;
+            }
+        };
+        match file.write_all(&bytes).await {
+            Ok(_) => (),
+            Err(err) => {
+                error!("error writing the tokenizer file to disk: {err}");
+            }
+        };
+    })
+    .await
+    .map_err(internal_error)?;
     Ok(())
 }
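
The download now happens inside a spawned task that is awaited immediately: network and I/O errors are logged inside the task instead of bubbling up, and only a panicked or cancelled task (a `JoinError`) still reaches the caller. A standalone sketch of that pattern with hypothetical names (`fetch_to_file` is not from the diff):

    use tokio::io::AsyncWriteExt;
    use tracing::error;

    // Errors inside the task are logged and swallowed; the caller only sees
    // a JoinError if the task panicked or was aborted.
    async fn fetch_to_file(url: String, mut file: tokio::fs::File) -> Result<(), String> {
        tokio::spawn(async move {
            let res = match reqwest::get(url).await {
                Ok(res) => res,
                Err(err) => {
                    error!("request failed: {err}");
                    return;
                }
            };
            let bytes = match res.bytes().await {
                Ok(bytes) => bytes,
                Err(err) => {
                    error!("reading body failed: {err}");
                    return;
                }
            };
            if let Err(err) = file.write_all(&bytes).await {
                error!("writing file failed: {err}");
            }
        })
        .await
        .map_err(|err| err.to_string())
    }

Because the handle is awaited right away, the spawn adds no concurrency here; its effect is to decouple the caller from errors inside the task (and to isolate panics), which matches the intent of the fix.
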
@@ -375,23 +398,41 @@ async fn get_tokenizer(
     }
     if let Some(config) = tokenizer_config {
         let tokenizer = match config {
-            TokenizerConfig::Local { path } => {
-                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
-            }
+            TokenizerConfig::Local { path } => match Tokenizer::from_file(path) {
+                Ok(tokenizer) => Some(Arc::new(tokenizer)),
+                Err(err) => {
+                    error!("error loading tokenizer from file: {err}");
+                    None
+                }
+            },
             TokenizerConfig::HuggingFace { repository } => {
                 let path = cache_dir.as_ref().join(model).join("tokenizer.json");
                 let url =
                     format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
                 download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
-                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+                match Tokenizer::from_file(path) {
+                    Ok(tokenizer) => Some(Arc::new(tokenizer)),
+                    Err(err) => {
+                        error!("error loading tokenizer from file: {err}");
+                        None
+                    }
+                }
             }
             TokenizerConfig::Download { url, to } => {
                 download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
-                Arc::new(Tokenizer::from_file(to).map_err(internal_error)?)
+                match Tokenizer::from_file(to) {
+                    Ok(tokenizer) => Some(Arc::new(tokenizer)),
+                    Err(err) => {
+                        error!("error loading tokenizer from file: {err}");
+                        None
+                    }
+                }
             }
         };
-        tokenizer_map.insert(model.to_owned(), tokenizer.clone());
-        Ok(Some(tokenizer))
+        if let Some(tokenizer) = tokenizer.clone() {
+            tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+        }
+        Ok(tokenizer)
     } else {
         Ok(None)
     }
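
With this change `get_tokenizer` returns `Ok(None)` not only when no tokenizer is configured but also when the configured tokenizer cannot be loaded, so callers have to treat the tokenizer as optional. A hedged sketch of the caller side (`prompt_token_count` is illustrative and not from the diff; `encode` and `get_ids` are the `tokenizers` crate's API):

    use std::sync::Arc;
    use tokenizers::Tokenizer;

    // With a tokenizer we can measure the prompt in tokens; without one we
    // return None and the caller skips token accounting instead of failing.
    fn prompt_token_count(tokenizer: Option<&Arc<Tokenizer>>, prompt: &str) -> Option<usize> {
        let tokenizer = tokenizer?;
        let encoding = tokenizer.encode(prompt, false).ok()?;
        Some(encoding.get_ids().len())
    }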