feat: add tokenizer for context window size check (#5)

* feat: add tokenizer for context window size check

* refactor: clippy

* refactor: use `format!` instead of mut

* refactor: remove unnecessary cloning
Luc Georges, 2023-08-27 00:21:55 +02:00, committed by GitHub
parent 6f455eca18
commit 073be09042
3 changed files with 1233 additions and 61 deletions

Cargo.lock (generated): 1059 changed lines; file diff suppressed because it is too large.

@@ -10,6 +10,7 @@ home = "0.5"
ropey = "1.6"
serde = { version = "1", features = ["derive"] }
reqwest = { version = "0.11", features = ["json"] }
tokenizers = "0.13"
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] }
tower-lsp = "0.20"
tracing = "0.1"


@@ -6,7 +6,10 @@ use ropey::Rope;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use tower_lsp::jsonrpc::{Error, Result};
use tower_lsp::lsp_types::*;
@@ -104,11 +107,13 @@ impl Document {
#[derive(Debug)]
struct Backend {
cache_dir: PathBuf,
client: Client,
document_map: Arc<RwLock<HashMap<String, Document>>>,
http_client: reqwest::Client,
workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
language_comments: HashMap<String, LanguageComment>,
tokenizer_map: Arc<RwLock<HashMap<String, Tokenizer>>>,
}
#[derive(Debug, Deserialize, Serialize)]
@@ -124,6 +129,8 @@ struct CompletionParams {
fim: FimParams,
api_token: Option<String>,
model: String,
tokenizer_path: Option<String>,
context_window: usize,
}
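
Clients drive the two new fields through the `llm-ls/getCompletions` request: `tokenizer_path` points at a local `tokenizer.json` (only needed when `model` is a raw URL, see `get_tokenizer` below) and `context_window` is the token budget for the prompt. A hedged sketch of how the new fields deserialize, using a stand-in struct and `serde_json` purely for illustration (neither the struct name nor the example values are part of the commit):

```rust
use serde::Deserialize;

// Stand-in for the two new fields above; the real CompletionParams also
// carries text_document_position, request_params, fim, api_token and model.
#[derive(Debug, Deserialize)]
struct TokenizerSettings {
    tokenizer_path: Option<String>,
    context_window: usize,
}

fn main() -> serde_json::Result<()> {
    let settings: TokenizerSettings = serde_json::from_str(
        r#"{ "tokenizer_path": "/path/to/tokenizer.json", "context_window": 2048 }"#,
    )?;
    println!("{settings:?}");
    Ok(())
}
```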
fn internal_error<E: Display>(err: E) -> Error {
@@ -138,7 +145,7 @@ fn internal_error<E: Display>(err: E) -> Error {
fn file_path_comment(
file_url: Url,
file_language_id: String,
file_language_id: &str,
workspace_folders: Option<&Vec<WorkspaceFolder>>,
language_comments: &HashMap<String, LanguageComment>,
) -> String {
@@ -155,47 +162,106 @@ fn file_path_comment(
} else {
file_path
};
let lc = match language_comments.get(&file_language_id) {
let lc = match language_comments.get(file_language_id) {
Some(id) => id.clone(),
None => LanguageComment {
open: "//".to_owned(),
close: None,
},
};
let mut commented_path = lc.open;
commented_path.push(' ');
commented_path.push_str(&path_in_workspace);
if let Some(close) = lc.close {
commented_path.push(' ');
commented_path.push_str(&close);
}
commented_path.push('\n');
commented_path
let close = if let Some(close) = lc.close {
format!(" {close}")
} else {
"".to_owned()
};
format!("{} {path_in_workspace}{close}\n", lc.open)
}
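
The `format!` rewrite above keeps the old behavior: the file path is rendered as a single comment line, with the optional closing token appended after a space. A small self-contained sketch with illustrative values (an HTML-style comment pair) showing the expected output:

```rust
fn main() {
    // Illustrative values, not from the commit: an HTML-style comment pair
    // shows why the optional close token is appended with a leading space.
    let open = "<!--";
    let close = Some("-->".to_owned());
    let path_in_workspace = "index.html";

    let close = match close {
        Some(close) => format!(" {close}"),
        None => String::new(),
    };
    assert_eq!(
        format!("{open} {path_in_workspace}{close}\n"),
        "<!-- index.html -->\n"
    );
}
```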
fn build_prompt(pos: Position, text: &Rope, fim: &FimParams, file_path: String) -> Result<String> {
let mut prompt = file_path;
let cursor_offset = text
.try_line_to_char(pos.line as usize)
.map_err(internal_error)?
+ pos.character as usize;
let text_len = text.len_chars();
// XXX: not sure this is useful, rather be safe than sorry
let cursor_offset = if cursor_offset > text_len {
text_len
} else {
cursor_offset
};
fn build_prompt(
pos: Position,
text: &Rope,
fim: &FimParams,
file_path: String,
tokenizer: Tokenizer,
context_window: usize,
) -> Result<String> {
if fim.enabled {
prompt.push_str(&fim.prefix);
prompt.push_str(&text.slice(0..cursor_offset).to_string());
prompt.push_str(&fim.suffix);
prompt.push_str(&text.slice(cursor_offset..text_len).to_string());
prompt.push_str(&fim.middle);
Ok(prompt)
let mut token_count = context_window;
let mut before_iter = text.lines_at(pos.line as usize + 1).reversed();
let mut after_iter = text.lines_at(pos.line as usize);
let mut before_line = before_iter.next();
let col = pos.character as usize;
if let Some(line) = before_line {
before_line = Some(line.slice(0..col));
}
let mut after_line = after_iter.next();
if let Some(line) = after_line {
after_line = Some(line.slice(col..));
}
let mut before = vec![];
let mut after = String::new();
while before_line.is_some() || after_line.is_some() {
if let Some(before_line) = before_line {
let before_line = before_line.to_string();
let tokens = tokenizer
.encode(before_line.clone(), false)
.map_err(internal_error)?
.len();
if tokens > token_count {
break;
}
token_count -= tokens;
before.push(before_line);
}
if let Some(after_line) = after_line {
let after_line = after_line.to_string();
let tokens = tokenizer
.encode(after_line.clone(), false)
.map_err(internal_error)?
.len();
if tokens > token_count {
break;
}
token_count -= tokens;
after.push_str(&after_line);
}
before_line = before_iter.next();
after_line = after_iter.next();
}
Ok(format!(
"{}{}{}{}{}{}",
file_path,
fim.prefix,
before.into_iter().rev().collect::<Vec<_>>().join(""),
fim.suffix,
after,
fim.middle
))
} else {
prompt.push_str(&text.slice(0..cursor_offset).to_string());
Ok(prompt)
let mut token_count = context_window;
let mut before = vec![];
let mut first = true;
for mut line in text.lines_at(pos.line as usize).reversed() {
if first {
line = line.slice(0..pos.character as usize);
first = false;
}
let line = line.to_string();
let tokens = tokenizer
.encode(line.clone(), false)
.map_err(internal_error)?
.len();
if tokens > token_count {
break;
}
token_count -= tokens;
before.push(line);
}
Ok(format!(
"{}{}",
file_path,
&before.into_iter().rev().collect::<Vec<_>>().join("")
))
}
}
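
The rewritten `build_prompt` grows the prompt outward from the cursor one line at a time, alternating between the lines above (collected closest-first, then reversed back into file order) and the lines below, and stops as soon as the next candidate line would overflow the `context_window` token budget. Below is a simplified, self-contained sketch of that budgeting loop, not the commit's code: it works on plain string slices instead of a `Rope`, and whitespace word count stands in for the real tokenizer.

```rust
/// Simplified sketch of the truncation above: spend a shared token budget on
/// lines before and after the cursor, alternating between the two sides, and
/// stop once the next line no longer fits.
fn count_tokens(line: &str) -> usize {
    // Stand-in for tokenizer.encode(line, false)?.len().
    line.split_whitespace().count()
}

fn budgeted_context(before_lines: &[&str], after_lines: &[&str], budget: usize) -> (String, String) {
    let mut token_count = budget;
    // `before_lines` is expected to end with the cursor line truncated at the
    // cursor column; iterate it closest-to-cursor first.
    let mut before_iter = before_lines.iter().rev();
    let mut after_iter = after_lines.iter();
    let mut before_line = before_iter.next();
    let mut after_line = after_iter.next();
    let mut before = vec![];
    let mut after = String::new();
    while before_line.is_some() || after_line.is_some() {
        if let Some(line) = before_line {
            let tokens = count_tokens(line);
            if tokens > token_count {
                break;
            }
            token_count -= tokens;
            before.push(*line);
        }
        if let Some(line) = after_line {
            let tokens = count_tokens(line);
            if tokens > token_count {
                break;
            }
            token_count -= tokens;
            after.push_str(line);
        }
        before_line = before_iter.next();
        after_line = after_iter.next();
    }
    // Lines above the cursor were collected closest-first; restore file order.
    (before.into_iter().rev().collect::<Vec<_>>().join(""), after)
}

fn main() {
    let before = ["fn add(a: i32, b: i32) -> i32 {\n", "    let sum = "];
    let after = [";\n", "}\n"];
    let (ctx_before, ctx_after) = budgeted_context(&before, &after, 16);
    // The FIM token strings here are placeholders; the real ones come from FimParams.
    println!("<fim_prefix>{ctx_before}<fim_suffix>{ctx_after}<fim_middle>");
}
```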
@@ -203,15 +269,15 @@ async fn request_completion(
http_client: &reqwest::Client,
model: &str,
request_params: RequestParams,
api_token: Option<String>,
api_token: Option<&String>,
prompt: String,
) -> Result<Vec<Generation>> {
let mut req = http_client.post(model).json(&APIRequest {
let mut req = http_client.post(build_url(model)).json(&APIRequest {
inputs: prompt,
parameters: request_params.into(),
});
if let Some(api_token) = api_token.clone() {
if let Some(api_token) = api_token {
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
}
@@ -232,6 +298,82 @@ fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Comp
.collect()
}
async fn download_tokenizer_file(
http_client: &reqwest::Client,
model: &str,
api_token: Option<&String>,
to: impl AsRef<Path>,
) -> Result<()> {
if to.as_ref().exists() {
return Ok(());
}
tokio::fs::create_dir_all(
to.as_ref()
.parent()
.ok_or_else(|| internal_error("tokenizer path has no parent"))?,
)
.await
.map_err(internal_error)?;
let mut req = http_client.get(format!(
"https://huggingface.co/{model}/resolve/main/tokenizer.json"
));
if let Some(api_token) = api_token {
req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
}
let res = req
.send()
.await
.map_err(internal_error)?
.error_for_status()
.map_err(internal_error)?;
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.create(true)
.open(to)
.await
.map_err(internal_error)?;
file.write_all(&res.bytes().await.map_err(internal_error)?)
.await
.map_err(internal_error)?;
Ok(())
}
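
`download_tokenizer_file` is a download-once helper: it returns early if the target file already exists, otherwise it creates the parent directory and fetches the model's `tokenizer.json` from the Hugging Face Hub, forwarding the API token when one is set. A hedged usage sketch, assuming the function above is in scope; the model id and cache location are example values:

```rust
// Illustrative call; the first invocation downloads
// https://huggingface.co/bigcode/starcoder/resolve/main/tokenizer.json,
// later invocations return immediately because the file already exists.
async fn warm_tokenizer_cache(
    http_client: &reqwest::Client,
    api_token: Option<&String>,
) -> tower_lsp::jsonrpc::Result<()> {
    let cache_dir = std::path::PathBuf::from("/tmp/llm-ls-cache");
    let path = cache_dir.join("bigcode/starcoder").join("tokenizer.json");
    download_tokenizer_file(http_client, "bigcode/starcoder", api_token, &path).await
}
```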
async fn get_tokenizer(
model: &str,
tokenizer_map: &mut HashMap<String, Tokenizer>,
tokenizer_path: Option<&String>,
http_client: &reqwest::Client,
cache_dir: impl AsRef<Path>,
api_token: Option<&String>,
) -> Result<Tokenizer> {
if model.starts_with("http://") || model.starts_with("https://") {
let tokenizer = match tokenizer_path {
Some(path) => Tokenizer::from_file(path).map_err(internal_error)?,
None => return Err(internal_error("`tokenizer_path` is null")),
};
Ok(tokenizer)
} else {
match tokenizer_map.get(model) {
Some(tokenizer) => Ok(tokenizer.clone()),
None => {
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
download_tokenizer_file(http_client, model, api_token, &path).await?;
let tokenizer = Tokenizer::from_file(path).map_err(internal_error)?;
tokenizer_map.insert(model.to_owned(), tokenizer.clone());
Ok(tokenizer)
}
}
}
}
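
`get_tokenizer` covers two cases: when `model` is a raw `http(s)://` endpoint there is nothing to download, so a local `tokenizer_path` is required; when it is a Hub model id, `tokenizer.json` is fetched into `cache_dir` once and the parsed `Tokenizer` is memoized in `tokenizer_map` for later requests. A sketch of the call site, mirroring `get_completions` below (names assume the surrounding module is in scope):

```rust
// Illustrative wrapper around the lookup above; holding the write lock across
// the call means a missing tokenizer is downloaded and parsed only once.
async fn tokenizer_for(
    backend: &Backend,
    params: &CompletionParams,
) -> tower_lsp::jsonrpc::Result<tokenizers::Tokenizer> {
    get_tokenizer(
        &params.model,                              // Hub id or full endpoint URL
        &mut *backend.tokenizer_map.write().await,  // in-memory cache
        params.tokenizer_path.as_ref(),             // required when model is a URL
        &backend.http_client,
        &backend.cache_dir,
        params.api_token.as_ref(),
    )
    .await
}
```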
fn build_url(model: &str) -> String {
if model.starts_with("http://") || model.starts_with("https://") {
model.to_owned()
} else {
format!("https://api-inference.huggingface.co/models/{model}")
}
}
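
`build_url` lets `model` double as either a full inference endpoint or a bare Hub id; only the latter is expanded to the hosted Inference API. A quick check of both cases, assuming `build_url` above is in scope (the id and endpoint are illustrative):

```rust
fn main() {
    // A bare model id is expanded to the hosted Inference API URL.
    assert_eq!(
        build_url("bigcode/starcoder"),
        "https://api-inference.huggingface.co/models/bigcode/starcoder"
    );
    // A full URL is passed through unchanged.
    assert_eq!(
        build_url("http://localhost:8080/generate"),
        "http://localhost:8080/generate"
    );
}
```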
impl Backend {
async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> {
info!("get_completions {params:?}");
@@ -242,23 +384,34 @@ impl Backend {
.ok_or_else(|| internal_error("failed to find document"))?;
let file_path = file_path_comment(
params.text_document_position.text_document.uri,
document.language_id.clone(),
&document.language_id,
self.workspace_folders.read().await.as_ref(),
&self.language_comments,
);
let tokenizer = get_tokenizer(
&params.model,
&mut *self.tokenizer_map.write().await,
params.tokenizer_path.as_ref(),
&self.http_client,
&self.cache_dir,
params.api_token.as_ref(),
)
.await?;
let prompt = build_prompt(
params.text_document_position.position,
&document.text,
&params.fim,
file_path,
tokenizer,
params.context_window,
)?;
let stop_token = params.request_params.stop_token.clone();
let result = request_completion(
&self.http_client,
&params.model,
params.request_params,
params.api_token,
prompt.clone(),
params.api_token.as_ref(),
prompt,
)
.await?;
@@ -281,7 +434,6 @@ impl LanguageServer for Backend {
)),
..Default::default()
},
..Default::default()
})
}
@@ -358,7 +510,7 @@ async fn main() {
.await
.expect("failed to create cache dir");
let log_file = rolling::never(cache_dir, "llm-ls.log");
let log_file = rolling::never(&cache_dir, "llm-ls.log");
let builder = tracing_subscriber::fmt()
.with_writer(log_file)
.with_target(true)
@@ -377,11 +529,13 @@
let http_client = reqwest::Client::new();
let (service, socket) = LspService::build(|client| Backend {
cache_dir,
client,
document_map: Arc::new(RwLock::new(HashMap::new())),
http_client,
workspace_folders: Arc::new(RwLock::new(None)),
language_comments: build_language_comments(),
tokenizer_map: Arc::new(RwLock::new(HashMap::new())),
})
.custom_method("llm-ls/getCompletions", Backend::get_completions)
.finish();
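
The backend now also carries `cache_dir` and a `tokenizer_map` shared across requests behind `Arc<RwLock<...>>`. A minimal, stand-alone sketch of that shared-cache shape (with `String` standing in for `Tokenizer`; names and values are illustrative):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

#[tokio::main]
async fn main() {
    // Same shape as the tokenizer_map field above, with String standing in for Tokenizer.
    let cache: Arc<RwLock<HashMap<String, String>>> = Arc::new(RwLock::new(HashMap::new()));

    // A completion request takes the write lock, checks the map and inserts on miss...
    {
        let mut map = cache.write().await;
        map.entry("bigcode/starcoder".to_owned())
            .or_insert_with(|| "parsed tokenizer".to_owned());
    }

    // ...and the cached value is reused by later requests.
    println!("cached: {:?}", cache.read().await.get("bigcode/starcoder"));
}
```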