perf: put Tokenizer in an Arc to avoid cloning (#11)
parent d8b7e05a20
commit 8c92eaa994
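`tokenizers::Tokenizer` owns its full vocabulary, merge tables, and added-token state, so its `Clone` impl deep-copies all of it; before this change that copy happened on every completion request. Wrapping the tokenizer in `Arc` turns each clone into an atomic refcount bump. A minimal sketch of the difference, assuming a `tokenizer.json` on disk (the file path is illustrative):

```rust
use std::sync::Arc;
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Deep copy: duplicates vocab, merges, and configuration.
    let expensive = tokenizer.clone();

    // Shared ownership: cloning the Arc only increments a refcount;
    // the underlying Tokenizer is never copied.
    let shared = Arc::new(tokenizer);
    let cheap = Arc::clone(&shared);

    drop((expensive, cheap));
    Ok(())
}
```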
@@ -41,6 +41,7 @@ struct APIParams {
     temperature: f32,
     do_sample: bool,
     top_p: f32,
+    #[allow(dead_code)]
     #[serde(skip_serializing)]
     stop: Option<Vec<String>>,
     return_full_text: bool,
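Why the new attribute in the hunk above: with `#[serde(skip_serializing)]`, `stop` is populated on deserialization but never written back out, so if no code reads it rustc raises a "field is never read" warning; `#[allow(dead_code)]` acknowledges that. A standalone sketch of the same pattern (the struct name is illustrative, not the file's code):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
struct Params {
    temperature: f32,
    // Accepted from incoming configuration, omitted from outgoing
    // requests, and never read elsewhere: hence both attributes.
    #[allow(dead_code)]
    #[serde(skip_serializing)]
    stop: Option<Vec<String>>,
}

fn main() -> Result<(), serde_json::Error> {
    let p: Params = serde_json::from_str(r#"{"temperature":0.5,"stop":["\n"]}"#)?;
    // `stop` is dropped on serialization:
    assert_eq!(serde_json::to_string(&p)?, r#"{"temperature":0.5}"#);
    Ok(())
}
```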
@@ -109,7 +110,7 @@ struct Backend {
     document_map: Arc<RwLock<HashMap<String, Document>>>,
     http_client: reqwest::Client,
     workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
-    tokenizer_map: Arc<RwLock<HashMap<String, Tokenizer>>>,
+    tokenizer_map: Arc<RwLock<HashMap<String, Arc<Tokenizer>>>>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
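With `Arc<Tokenizer>` as the map's value type, handing a tokenizer to a request handler costs one refcount increment under the read lock instead of a full copy. A sketch of that read path, assuming tokio's async `RwLock` (the diff does not show which `RwLock` is imported):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::RwLock;

// Mirrors the new field type on Backend.
type TokenizerMap = Arc<RwLock<HashMap<String, Arc<Tokenizer>>>>;

async fn lookup(map: &TokenizerMap, model: &str) -> Option<Arc<Tokenizer>> {
    // `.cloned()` clones the Arc, not the Tokenizer, so the read
    // guard is held only for the duration of a HashMap lookup.
    map.read().await.get(model).cloned()
}
```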
@@ -144,7 +145,7 @@ fn build_prompt(
     pos: Position,
     text: &Rope,
     fim: &FimParams,
-    tokenizer: Tokenizer,
+    tokenizer: Arc<Tokenizer>,
     context_window: usize,
 ) -> Result<String> {
     if fim.enabled {
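Only the parameter type changes here; the body of `build_prompt` can stay as-is because `Arc<T>` implements `Deref<Target = T>`, so method calls on the tokenizer auto-deref through the `Arc`. A minimal illustration (hypothetical helper, not from the diff):

```rust
use std::sync::Arc;
use tokenizers::Tokenizer;

// &self methods on Tokenizer are reachable straight through the Arc.
fn count_tokens(tokenizer: Arc<Tokenizer>, text: &str) -> usize {
    tokenizer
        .encode(text, false) // auto-derefs to Tokenizer::encode
        .map(|encoding| encoding.len())
        .unwrap_or(0)
}
```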
@@ -298,30 +299,28 @@ async fn download_tokenizer_file(
 
 async fn get_tokenizer(
     model: &str,
-    tokenizer_map: &mut HashMap<String, Tokenizer>,
+    tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
     tokenizer_path: Option<&String>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
     api_token: Option<&String>,
-) -> Result<Tokenizer> {
-    if model.starts_with("http://") || model.starts_with("https://") {
-        let tokenizer = match tokenizer_path {
-            Some(path) => Tokenizer::from_file(path).map_err(internal_error)?,
-            None => return Err(internal_error("`tokenizer_path` is null")),
-        };
-        Ok(tokenizer)
-    } else {
-        match tokenizer_map.get(model) {
-            Some(tokenizer) => Ok(tokenizer.clone()),
-            None => {
-                let path = cache_dir.as_ref().join(model).join("tokenizer.json");
-                download_tokenizer_file(http_client, model, api_token, &path).await?;
-                let tokenizer = Tokenizer::from_file(path).map_err(internal_error)?;
-                tokenizer_map.insert(model.to_owned(), tokenizer.clone());
-                Ok(tokenizer)
-            }
-        }
-    }
+) -> Result<Arc<Tokenizer>> {
+    if let Some(tokenizer) = tokenizer_map.get(model) {
+        return Ok(tokenizer.clone());
+    }
+    let tokenizer = if model.starts_with("http://") || model.starts_with("https://") {
+        match tokenizer_path {
+            Some(path) => Arc::new(Tokenizer::from_file(path).map_err(internal_error)?),
+            None => return Err(internal_error("`tokenizer_path` is null")),
+        }
+    } else {
+        let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+        download_tokenizer_file(http_client, model, api_token, &path).await?;
+        Arc::new(Tokenizer::from_file(path).map_err(internal_error)?)
+    };
+
+    tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+    Ok(tokenizer)
 }
 
 fn build_url(model: &str) -> String {
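The `get_tokenizer` rewrite is more than a type swap: the cache lookup now runs first for both kinds of models, so a tokenizer loaded from `tokenizer_path` for an `http(s)://` model is memoized too, where the old code re-read it from disk on every call. Repeat calls return through the early `if let` with one `HashMap` lookup and one `Arc` clone. A call-site sketch placed in the same module as `get_tokenizer` (the model id, cache path, and use of the file's `Result` alias are assumptions):

```rust
async fn demo(http_client: &reqwest::Client) -> Result<()> {
    let mut tokenizer_map: HashMap<String, Arc<Tokenizer>> = HashMap::new();

    // First call misses the cache, downloads tokenizer.json, and caches it.
    let first = get_tokenizer(
        "bigcode/starcoder", // illustrative model id
        &mut tokenizer_map,
        None, // tokenizer_path: only consulted for http(s):// models
        http_client,
        "/tmp/llm-ls-cache", // illustrative cache_dir
        None, // api_token
    )
    .await?;

    // Second call takes the early-return fast path and hands back
    // another handle to the same allocation.
    let second = get_tokenizer(
        "bigcode/starcoder",
        &mut tokenizer_map,
        None,
        http_client,
        "/tmp/llm-ls-cache",
        None,
    )
    .await?;
    assert!(Arc::ptr_eq(&first, &second));
    Ok(())
}
```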