perf: put Tokenizer in an Arc to avoid cloning (#11)

This commit is contained in:
Luc Georges 2023-09-08 09:02:44 +02:00 committed by GitHub
parent d8b7e05a20
commit 8c92eaa994
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -41,6 +41,7 @@ struct APIParams {
temperature: f32,
do_sample: bool,
top_p: f32,
#[allow(dead_code)]
#[serde(skip_serializing)]
stop: Option<Vec<String>>,
return_full_text: bool,
@ -109,7 +110,7 @@ struct Backend {
document_map: Arc<RwLock<HashMap<String, Document>>>,
http_client: reqwest::Client,
workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
tokenizer_map: Arc<RwLock<HashMap<String, Tokenizer>>>,
tokenizer_map: Arc<RwLock<HashMap<String, Arc<Tokenizer>>>>,
}
#[derive(Debug, Deserialize, Serialize)]
@ -144,7 +145,7 @@ fn build_prompt(
pos: Position,
text: &Rope,
fim: &FimParams,
tokenizer: Tokenizer,
tokenizer: Arc<Tokenizer>,
context_window: usize,
) -> Result<String> {
if fim.enabled {
@ -298,30 +299,28 @@ async fn download_tokenizer_file(
async fn get_tokenizer(
model: &str,
tokenizer_map: &mut HashMap<String, Tokenizer>,
tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
tokenizer_path: Option<&String>,
http_client: &reqwest::Client,
cache_dir: impl AsRef<Path>,
api_token: Option<&String>,
) -> Result<Tokenizer> {
if model.starts_with("http://") || model.starts_with("https://") {
let tokenizer = match tokenizer_path {
Some(path) => Tokenizer::from_file(path).map_err(internal_error)?,
None => return Err(internal_error("`tokenizer_path` is null")),
};
Ok(tokenizer)
} else {
match tokenizer_map.get(model) {
Some(tokenizer) => Ok(tokenizer.clone()),
None => {
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
download_tokenizer_file(http_client, model, api_token, &path).await?;
let tokenizer = Tokenizer::from_file(path).map_err(internal_error)?;
tokenizer_map.insert(model.to_owned(), tokenizer.clone());
Ok(tokenizer)
}
}
) -> Result<Arc<Tokenizer>> {
if let Some(tokenizer) = tokenizer_map.get(model) {
return Ok(tokenizer.clone());
}
let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
match tokenizer_path {
Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
None => return Err(internal_error("`tokenizer_path` is null")),
}
} else {
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
download_tokenizer_file(http_client, model, api_token, &path).await?;
Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
};
tokenizer_map.insert(model.to_owned(), tokenizer.clone());
Ok(tokenizer)
}
fn build_url(model: &str) -> String {