perf: put Tokenizer in an Arc to avoid cloning (#11)
This commit is contained in:
parent
d8b7e05a20
commit
8c92eaa994
|
@@ -41,6 +41,7 @@ struct APIParams {
|
|||
temperature: f32,
|
||||
do_sample: bool,
|
||||
top_p: f32,
|
||||
#[allow(dead_code)]
|
||||
#[serde(skip_serializing)]
|
||||
stop: Option<Vec<String>>,
|
||||
return_full_text: bool,
|
||||
|
@@ -109,7 +110,7 @@ struct Backend {
|
|||
document_map: Arc<RwLock<HashMap<String, Document>>>,
|
||||
http_client: reqwest::Client,
|
||||
workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
|
||||
- tokenizer_map: Arc<RwLock<HashMap<String, Tokenizer>>>,
|
||||
+ tokenizer_map: Arc<RwLock<HashMap<String, Arc<Tokenizer>>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
|
@@ -144,7 +145,7 @@ fn build_prompt(
|
|||
pos: Position,
|
||||
text: &Rope,
|
||||
fim: &FimParams,
|
||||
- tokenizer: Tokenizer,
|
||||
+ tokenizer: Arc<Tokenizer>,
|
||||
context_window: usize,
|
||||
) -> Result<String> {
|
||||
if fim.enabled {
|
||||
|
@@ -298,30 +299,28 @@ async fn download_tokenizer_file(
|
|||
|
||||
async fn get_tokenizer(
|
||||
model: &str,
|
||||
- tokenizer_map: &mut HashMap<String, Tokenizer>,
|
||||
+ tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
|
||||
tokenizer_path: Option<&String>,
|
||||
http_client: &reqwest::Client,
|
||||
cache_dir: impl AsRef<Path>,
|
||||
api_token: Option<&String>,
|
||||
- ) -> Result<Tokenizer> {
|
||||
- if model.starts_with("http://") || model.starts_with("https://") {
|
||||
- let tokenizer = match tokenizer_path {
|
||||
- Some(path) => Tokenizer::from_file(path).map_err(internal_error)?,
|
||||
- None => return Err(internal_error("`tokenizer_path` is null")),
|
||||
- };
|
||||
- Ok(tokenizer)
|
||||
- } else {
|
||||
- match tokenizer_map.get(model) {
|
||||
- Some(tokenizer) => Ok(tokenizer.clone()),
|
||||
- None => {
|
||||
- let path = cache_dir.as_ref().join(model).join("tokenizer.json");
|
||||
- download_tokenizer_file(http_client, model, api_token, &path).await?;
|
||||
- let tokenizer = Tokenizer::from_file(path).map_err(internal_error)?;
|
||||
- tokenizer_map.insert(model.to_owned(), tokenizer.clone());
|
||||
- Ok(tokenizer)
|
||||
- }
|
||||
- }
|
||||
+ ) -> Result<Arc<Tokenizer>> {
|
||||
+ if let Some(tokenizer) = tokenizer_map.get(model) {
|
||||
+ return Ok(tokenizer.clone());
|
||||
+ }
|
||||
+ let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
|
||||
+ match tokenizer_path {
|
||||
+ Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
|
||||
+ None => return Err(internal_error("`tokenizer_path` is null")),
|
||||
+ }
|
||||
+ } else {
|
||||
+ let path = cache_dir.as_ref().join(model).join("tokenizer.json");
|
||||
+ download_tokenizer_file(http_client, model, api_token, &path).await?;
|
||||
+ Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
|
||||
+ };
|
||||
|
||||
+ tokenizer_map.insert(model.to_owned(), tokenizer.clone());
|
||||
+ Ok(tokenizer)
|
||||
}
|
||||
|
||||
fn build_url(model: &str) -> String {
|
||||
|
|
Loading…
Reference in a new issue