feat: add tokenizer for context window size check (#5)

* feat: add tokenizer for context window size check

* refactor: clippy

* refactor: use `format!` instead of mut

* refactor: remove unnecessary cloning
Luc Georges 2023-08-27 00:21:55 +02:00 committed by GitHub
parent 6f455eca18
commit 073be09042
3 changed files with 1233 additions and 61 deletions
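The gist of the change: `build_prompt` now receives a `tokenizers::Tokenizer` and a `context_window` budget, and only keeps lines around the cursor while their token counts still fit, instead of pushing the whole document into the prompt. A condensed, standalone sketch of that trimming idea (the helper name, its inputs, and the local `tokenizer.json` path are illustrative, not code from this commit):

use tokenizers::Tokenizer;

// Hypothetical helper: keep the lines closest to the cursor while they fit in
// the token budget. The committed `build_prompt` below does the same over a
// `ropey::Rope`, and interleaves lines before/after the cursor when FIM is on.
fn trim_to_context_window(
    lines_before_cursor: &[&str],
    tokenizer: &Tokenizer,
    context_window: usize,
) -> tokenizers::Result<String> {
    let mut budget = context_window;
    let mut kept = vec![];
    // Walk from the line nearest the cursor outward.
    for line in lines_before_cursor.iter().rev() {
        let tokens = tokenizer.encode(*line, false)?.len();
        if tokens > budget {
            break;
        }
        budget -= tokens;
        kept.push(*line);
    }
    // Restore top-to-bottom order before joining.
    Ok(kept.into_iter().rev().collect::<Vec<_>>().join(""))
}

fn main() -> tokenizers::Result<()> {
    // Assumes a tokenizer.json on disk, e.g. one downloaded from the Hub.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let lines = ["fn add(a: u32, b: u32) -> u32 {\n", "    a + b\n", "}\n"];
    println!("{}", trim_to_context_window(&lines, &tokenizer, 64)?);
    Ok(())
}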

Cargo.lock (generated, 1059 changes): file diff suppressed because it is too large.

Cargo.toml

@@ -10,6 +10,7 @@ home = "0.5"
 ropey = "1.6"
 serde = { version = "1", features = ["derive"] }
 reqwest = { version = "0.11", features = ["json"] }
+tokenizers = "0.13"
 tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] }
 tower-lsp = "0.20"
 tracing = "0.1"

src/main.rs

@@ -6,7 +6,10 @@ use ropey::Rope;
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::fmt::Display;
+use std::path::{Path, PathBuf};
 use std::sync::Arc;
+use tokenizers::Tokenizer;
+use tokio::io::AsyncWriteExt;
 use tokio::sync::RwLock;
 use tower_lsp::jsonrpc::{Error, Result};
 use tower_lsp::lsp_types::*;
@@ -104,11 +107,13 @@ impl Document {
 #[derive(Debug)]
 struct Backend {
+    cache_dir: PathBuf,
     client: Client,
     document_map: Arc<RwLock<HashMap<String, Document>>>,
     http_client: reqwest::Client,
     workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
     language_comments: HashMap<String, LanguageComment>,
+    tokenizer_map: Arc<RwLock<HashMap<String, Tokenizer>>>,
 }

 #[derive(Debug, Deserialize, Serialize)]
@@ -124,6 +129,8 @@ struct CompletionParams {
     fim: FimParams,
     api_token: Option<String>,
     model: String,
+    tokenizer_path: Option<String>,
+    context_window: usize,
 }

 fn internal_error<E: Display>(err: E) -> Error {
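Since `CompletionParams` derives `Deserialize`, clients calling the custom `llm-ls/getCompletions` method now have to send these two extra fields. A minimal illustration of what deserializes into them (the mirror struct, the example values, and the use of `serde_json` are assumptions for the sketch, not part of the crate):

use serde::Deserialize;

// Mirrors only the two new fields of `CompletionParams`, for illustration.
#[derive(Debug, Deserialize)]
struct NewCompletionFields {
    tokenizer_path: Option<String>,
    context_window: usize,
}

fn main() {
    // `tokenizer_path` is only needed when `model` is a raw http(s) URL;
    // 2000 is just an example token budget.
    let params = r#"{ "tokenizer_path": null, "context_window": 2000 }"#;
    let fields: NewCompletionFields = serde_json::from_str(params).unwrap();
    println!("{fields:?}");
}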
@@ -138,7 +145,7 @@ fn internal_error<E: Display>(err: E) -> Error {
 fn file_path_comment(
     file_url: Url,
-    file_language_id: String,
+    file_language_id: &str,
     workspace_folders: Option<&Vec<WorkspaceFolder>>,
     language_comments: &HashMap<String, LanguageComment>,
 ) -> String {
@@ -155,47 +162,106 @@ fn file_path_comment(
     } else {
         file_path
     };
-    let lc = match language_comments.get(&file_language_id) {
+    let lc = match language_comments.get(file_language_id) {
         Some(id) => id.clone(),
         None => LanguageComment {
             open: "//".to_owned(),
             close: None,
         },
     };
-    let mut commented_path = lc.open;
-    commented_path.push(' ');
-    commented_path.push_str(&path_in_workspace);
-    if let Some(close) = lc.close {
-        commented_path.push(' ');
-        commented_path.push_str(&close);
-    }
-    commented_path.push('\n');
-    commented_path
+    let close = if let Some(close) = lc.close {
+        format!(" {close}")
+    } else {
+        "".to_owned()
+    };
+    format!("{} {path_in_workspace}{close}\n", lc.open)
 }

-fn build_prompt(pos: Position, text: &Rope, fim: &FimParams, file_path: String) -> Result<String> {
-    let mut prompt = file_path;
-    let cursor_offset = text
-        .try_line_to_char(pos.line as usize)
-        .map_err(internal_error)?
-        + pos.character as usize;
-    let text_len = text.len_chars();
-    // XXX: not sure this is useful, rather be safe than sorry
-    let cursor_offset = if cursor_offset > text_len {
-        text_len
-    } else {
-        cursor_offset
-    };
+fn build_prompt(
+    pos: Position,
+    text: &Rope,
+    fim: &FimParams,
+    file_path: String,
+    tokenizer: Tokenizer,
+    context_window: usize,
+) -> Result<String> {
     if fim.enabled {
-        prompt.push_str(&fim.prefix);
-        prompt.push_str(&text.slice(0..cursor_offset).to_string());
-        prompt.push_str(&fim.suffix);
-        prompt.push_str(&text.slice(cursor_offset..text_len).to_string());
-        prompt.push_str(&fim.middle);
-        Ok(prompt)
+        let mut token_count = context_window;
+        let mut before_iter = text.lines_at(pos.line as usize + 1).reversed();
+        let mut after_iter = text.lines_at(pos.line as usize);
+        let mut before_line = before_iter.next();
+        let col = pos.character as usize;
+        if let Some(line) = before_line {
+            before_line = Some(line.slice(0..col));
+        }
+        let mut after_line = after_iter.next();
+        if let Some(line) = after_line {
+            after_line = Some(line.slice(col..));
+        }
+        let mut before = vec![];
+        let mut after = String::new();
+        while before_line.is_some() || after_line.is_some() {
+            if let Some(before_line) = before_line {
+                let before_line = before_line.to_string();
+                let tokens = tokenizer
+                    .encode(before_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len();
+                if tokens > token_count {
+                    break;
+                }
+                token_count -= tokens;
+                before.push(before_line);
+            }
+            if let Some(after_line) = after_line {
+                let after_line = after_line.to_string();
+                let tokens = tokenizer
+                    .encode(after_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len();
+                if tokens > token_count {
+                    break;
+                }
+                token_count -= tokens;
+                after.push_str(&after_line);
+            }
+            before_line = before_iter.next();
+            after_line = after_iter.next();
+        }
+        Ok(format!(
+            "{}{}{}{}{}{}",
+            file_path,
+            fim.prefix,
+            before.into_iter().rev().collect::<Vec<_>>().join(""),
+            fim.suffix,
+            after,
+            fim.middle
+        ))
     } else {
-        prompt.push_str(&text.slice(0..cursor_offset).to_string());
-        Ok(prompt)
+        let mut token_count = context_window;
+        let mut before = vec![];
+        let mut first = true;
+        for mut line in text.lines_at(pos.line as usize).reversed() {
+            if first {
+                line = line.slice(0..pos.character as usize);
+                first = false;
+            }
+            let line = line.to_string();
+            let tokens = tokenizer
+                .encode(line.clone(), false)
+                .map_err(internal_error)?
+                .len();
+            if tokens > token_count {
+                break;
+            }
+            token_count -= tokens;
+            before.push(line);
+        }
+        Ok(format!(
+            "{}{}",
+            file_path,
+            &before.into_iter().rev().collect::<Vec<_>>().join("")
+        ))
     }
 }
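For reference, with FIM enabled the prompt is assembled in the order: file-path comment, FIM prefix, trimmed text before the cursor, FIM suffix, trimmed text after the cursor, FIM middle. A toy illustration of that ordering (the sentinel strings come from the client's `fim` params; the values below are placeholders, not something this commit defines):

fn main() {
    // Placeholder FIM sentinels; the real ones are supplied by the editor client.
    let (prefix, suffix, middle) = ("<fim_prefix>", "<fim_suffix>", "<fim_middle>");
    let file_path = "// src/lib.rs\n";
    let before = "fn add(a: u32, b: u32) -> u32 {\n    ";
    let after = "\n}\n";
    // Same concatenation order as the `format!` call in `build_prompt`.
    let prompt = format!("{file_path}{prefix}{before}{suffix}{after}{middle}");
    println!("{prompt}");
}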
@@ -203,15 +269,15 @@ async fn request_completion(
     http_client: &reqwest::Client,
     model: &str,
     request_params: RequestParams,
-    api_token: Option<String>,
+    api_token: Option<&String>,
     prompt: String,
 ) -> Result<Vec<Generation>> {
-    let mut req = http_client.post(model).json(&APIRequest {
+    let mut req = http_client.post(build_url(model)).json(&APIRequest {
         inputs: prompt,
         parameters: request_params.into(),
     });

-    if let Some(api_token) = api_token.clone() {
+    if let Some(api_token) = api_token {
         req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
     }
@@ -232,6 +298,82 @@ fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Comp
         .collect()
 }

+async fn download_tokenizer_file(
+    http_client: &reqwest::Client,
+    model: &str,
+    api_token: Option<&String>,
+    to: impl AsRef<Path>,
+) -> Result<()> {
+    if to.as_ref().exists() {
+        return Ok(());
+    }
+    tokio::fs::create_dir_all(
+        to.as_ref()
+            .parent()
+            .ok_or_else(|| internal_error("tokenizer path has no parent"))?,
+    )
+    .await
+    .map_err(internal_error)?;
+    let mut req = http_client.get(format!(
+        "https://huggingface.co/{model}/resolve/main/tokenizer.json"
+    ));
+    if let Some(api_token) = api_token {
+        req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
+    }
+    let res = req
+        .send()
+        .await
+        .map_err(internal_error)?
+        .error_for_status()
+        .map_err(internal_error)?;
+    let mut file = tokio::fs::OpenOptions::new()
+        .write(true)
+        .create(true)
+        .open(to)
+        .await
+        .map_err(internal_error)?;
+    file.write_all(&res.bytes().await.map_err(internal_error)?)
+        .await
+        .map_err(internal_error)?;
+    Ok(())
+}
+
+async fn get_tokenizer(
+    model: &str,
+    tokenizer_map: &mut HashMap<String, Tokenizer>,
+    tokenizer_path: Option<&String>,
+    http_client: &reqwest::Client,
+    cache_dir: impl AsRef<Path>,
+    api_token: Option<&String>,
+) -> Result<Tokenizer> {
+    if model.starts_with("http://") || model.starts_with("https://") {
+        let tokenizer = match tokenizer_path {
+            Some(path) => Tokenizer::from_file(path).map_err(internal_error)?,
+            None => return Err(internal_error("`tokenizer_path` is null")),
+        };
+        Ok(tokenizer)
+    } else {
+        match tokenizer_map.get(model) {
+            Some(tokenizer) => Ok(tokenizer.clone()),
+            None => {
+                let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+                download_tokenizer_file(http_client, model, api_token, &path).await?;
+                let tokenizer = Tokenizer::from_file(path).map_err(internal_error)?;
+                tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+                Ok(tokenizer)
+            }
+        }
+    }
+}
+
+fn build_url(model: &str) -> String {
+    if model.starts_with("http://") || model.starts_with("https://") {
+        model.to_owned()
+    } else {
+        format!("https://api-inference.huggingface.co/models/{model}")
+    }
+}
+
 impl Backend {
     async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> {
         info!("get_completions {params:?}");
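A note on the routing these helpers introduce: `build_url` leaves full http(s) URLs untouched (self-hosted endpoints, which then require `tokenizer_path`) and maps bare model IDs to the Hugging Face Inference API, while `get_tokenizer` caches Hub-downloaded `tokenizer.json` files per model under `cache_dir`. A quick sketch of the URL behaviour, copying `build_url` from the hunk above (the model names are placeholders):

fn build_url(model: &str) -> String {
    if model.starts_with("http://") || model.starts_with("https://") {
        model.to_owned()
    } else {
        format!("https://api-inference.huggingface.co/models/{model}")
    }
}

fn main() {
    // Bare model IDs go through the hosted Inference API.
    assert_eq!(
        build_url("bigcode/starcoder"),
        "https://api-inference.huggingface.co/models/bigcode/starcoder"
    );
    // Full URLs are used verbatim; the tokenizer must then come from `tokenizer_path`.
    assert_eq!(
        build_url("https://my-endpoint.example/generate"),
        "https://my-endpoint.example/generate"
    );
}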
@@ -242,23 +384,34 @@ impl Backend {
             .ok_or_else(|| internal_error("failed to find document"))?;
         let file_path = file_path_comment(
             params.text_document_position.text_document.uri,
-            document.language_id.clone(),
+            &document.language_id,
             self.workspace_folders.read().await.as_ref(),
             &self.language_comments,
         );
+        let tokenizer = get_tokenizer(
+            &params.model,
+            &mut *self.tokenizer_map.write().await,
+            params.tokenizer_path.as_ref(),
+            &self.http_client,
+            &self.cache_dir,
+            params.api_token.as_ref(),
+        )
+        .await?;
         let prompt = build_prompt(
             params.text_document_position.position,
             &document.text,
             &params.fim,
             file_path,
+            tokenizer,
+            params.context_window,
         )?;

         let stop_token = params.request_params.stop_token.clone();
         let result = request_completion(
             &self.http_client,
             &params.model,
             params.request_params,
-            params.api_token,
-            prompt.clone(),
+            params.api_token.as_ref(),
+            prompt,
         )
         .await?;
@@ -281,7 +434,6 @@ impl LanguageServer for Backend {
                 )),
                 ..Default::default()
             },
-            ..Default::default()
         })
     }
@@ -358,7 +510,7 @@ async fn main() {
         .await
         .expect("failed to create cache dir");

-    let log_file = rolling::never(cache_dir, "llm-ls.log");
+    let log_file = rolling::never(&cache_dir, "llm-ls.log");
     let builder = tracing_subscriber::fmt()
         .with_writer(log_file)
         .with_target(true)
@@ -377,11 +529,13 @@ async fn main() {
     let http_client = reqwest::Client::new();
     let (service, socket) = LspService::build(|client| Backend {
+        cache_dir,
         client,
         document_map: Arc::new(RwLock::new(HashMap::new())),
         http_client,
         workspace_folders: Arc::new(RwLock::new(None)),
         language_comments: build_language_comments(),
+        tokenizer_map: Arc::new(RwLock::new(HashMap::new())),
     })
     .custom_method("llm-ls/getCompletions", Backend::get_completions)
     .finish();