fix: don't use tokenizer on config error (#22)

Author: Luc Georges, 2023-09-25 15:10:29 +02:00 (committed via GitHub)
parent 787f2a1a26
commit fbaf98203f
3 changed files with 62 additions and 21 deletions

Cargo.lock (generated): 2 changed lines

@@ -659,7 +659,7 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"

 [[package]]
 name = "llm-ls"
-version = "0.1.1"
+version = "0.2.1"
 dependencies = [
  "home",
  "reqwest",

Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "llm-ls"
-version = "0.2.0"
+version = "0.2.1"
 edition = "2021"

 [[bin]]

llm-ls source file
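
For orientation, both hunks below match on a TokenizerConfig value with Local, HuggingFace, and Download variants. The sketch below is an assumption reconstructed from those match arms, not the actual llm-ls definition; only the variant and field names (path, repository, url, to) are taken from the diff.

    use std::path::PathBuf;

    // Assumed shape, inferred from the match arms in get_tokenizer below.
    // The real definition lives elsewhere in llm-ls and likely derives
    // serde traits for the server's initialization options.
    #[derive(Clone, Debug)]
    enum TokenizerConfig {
        Local { path: PathBuf },
        HuggingFace { repository: String },
        Download { url: String, to: PathBuf },
    }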

@@ -337,27 +337,50 @@ async fn download_tokenizer_file(
     tokio::fs::create_dir_all(
         to.as_ref()
             .parent()
-            .ok_or_else(|| internal_error("tokenizer path has no parent"))?,
+            .ok_or_else(|| internal_error("invalid tokenizer path"))?,
     )
     .await
     .map_err(internal_error)?;
-    let res = http_client
-        .get(url)
-        .headers(build_headers(api_token, ide)?)
-        .send()
-        .await
-        .map_err(internal_error)?
-        .error_for_status()
-        .map_err(internal_error)?;
+    let headers = build_headers(api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
         .open(to)
         .await
         .map_err(internal_error)?;
-    file.write_all(&res.bytes().await.map_err(internal_error)?)
-        .await
-        .map_err(internal_error)?;
+    let http_client = http_client.clone();
+    let url = url.to_owned();
+    tokio::spawn(async move {
+        let res = match http_client.get(url).headers(headers).send().await {
+            Ok(res) => res,
+            Err(err) => {
+                error!("error sending download request for the tokenizer file: {err}");
+                return;
+            }
+        };
+        let res = match res.error_for_status() {
+            Ok(res) => res,
+            Err(err) => {
+                error!("API replied with error to the tokenizer file download: {err}");
+                return;
+            }
+        };
+        let bytes = match res.bytes().await {
+            Ok(bytes) => bytes,
+            Err(err) => {
+                error!("error while streaming tokenizer file bytes: {err}");
+                return;
+            }
+        };
+        match file.write_all(&bytes).await {
+            Ok(_) => (),
+            Err(err) => {
+                error!("error writing the tokenizer file to disk: {err}");
+            }
+        };
+    })
+    .await
+    .map_err(internal_error)?;
     Ok(())
 }
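
The rewrite above moves the HTTP request into a spawned task: every failure (sending the request, a non-success status, streaming the body, writing the file) is logged with error! and swallowed, so a broken tokenizer download no longer fails the caller; only the awaited JoinHandle's error (a panicked or cancelled task) still propagates through internal_error. A minimal sketch of that pattern, assuming tokio, reqwest, and tracing, with download_quietly and this internal_error stand-in being hypothetical names rather than llm-ls API:

    use tracing::error;

    // Hypothetical stand-in for llm-ls' internal_error helper.
    fn internal_error<E: std::fmt::Display>(err: E) -> String {
        format!("internal error: {err}")
    }

    // Fetch `url` inside a spawned task; network and HTTP-status failures
    // are logged and swallowed, so the only error the caller can observe is
    // a JoinError (the task panicked or was cancelled).
    async fn download_quietly(url: String) -> Result<(), String> {
        tokio::spawn(async move {
            let res = match reqwest::get(url).await {
                Ok(res) => res,
                Err(err) => {
                    error!("error sending download request: {err}");
                    return;
                }
            };
            if let Err(err) = res.error_for_status() {
                error!("server replied with an error status: {err}");
            }
        })
        .await
        .map_err(internal_error)
    }
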
@@ -375,23 +398,41 @@ async fn get_tokenizer(
     }
     if let Some(config) = tokenizer_config {
         let tokenizer = match config {
-            TokenizerConfig::Local { path } => {
-                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
-            }
+            TokenizerConfig::Local { path } => match Tokenizer::from_file(path) {
+                Ok(tokenizer) => Some(Arc::new(tokenizer)),
+                Err(err) => {
+                    error!("error loading tokenizer from file: {err}");
+                    None
+                }
+            },
             TokenizerConfig::HuggingFace { repository } => {
                 let path = cache_dir.as_ref().join(model).join("tokenizer.json");
                 let url =
                     format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
                 download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
-                Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+                match Tokenizer::from_file(path) {
+                    Ok(tokenizer) => Some(Arc::new(tokenizer)),
+                    Err(err) => {
+                        error!("error loading tokenizer from file: {err}");
+                        None
+                    }
+                }
             }
             TokenizerConfig::Download { url, to } => {
                 download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
-                Arc::new(Tokenizer::from_file(to).map_err(internal_error)?)
+                match Tokenizer::from_file(to) {
+                    Ok(tokenizer) => Some(Arc::new(tokenizer)),
+                    Err(err) => {
+                        error!("error loading tokenizer from file: {err}");
+                        None
+                    }
+                }
             }
         };
-        tokenizer_map.insert(model.to_owned(), tokenizer.clone());
-        Ok(Some(tokenizer))
+        if let Some(tokenizer) = tokenizer.clone() {
+            tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+        }
+        Ok(tokenizer)
     } else {
         Ok(None)
     }
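
With this change get_tokenizer returns Ok(None) when a tokenizer cannot be loaded (instead of surfacing an error to the editor) and only caches tokenizers that actually loaded. A hedged sketch of how a caller might consume that Option, using the tokenizers crate's encode/get_ids API but a hypothetical fits_in_context helper and a rough four-characters-per-token fallback:

    use std::sync::Arc;
    use tokenizers::Tokenizer;

    // Hypothetical caller-side helper: count tokens when a tokenizer is
    // available, otherwise degrade to a crude character-based estimate
    // rather than failing the completion request.
    fn fits_in_context(prompt: &str, max_tokens: usize, tokenizer: Option<&Arc<Tokenizer>>) -> bool {
        match tokenizer {
            Some(tokenizer) => match tokenizer.encode(prompt, false) {
                Ok(encoding) => encoding.get_ids().len() <= max_tokens,
                // Mirror the commit's approach: degrade, don't error out.
                Err(_) => prompt.len() / 4 <= max_tokens,
            },
            None => prompt.len() / 4 <= max_tokens,
        }
    }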