feat: add tokenizer for context window size check (#5)
* feat: add tokenizer for context window size check
* refactor: clippy
* refactor: use `format!` instead of mut
* refactor: remove unnecessary cloning
parent 6f455eca18
commit 073be09042
Cargo.lock (generated): 1059 changed lines
File diff suppressed because it is too large
Cargo.toml

@@ -10,6 +10,7 @@ home = "0.5"
 ropey = "1.6"
 serde = { version = "1", features = ["derive"] }
 reqwest = { version = "0.11", features = ["json"] }
+tokenizers = "0.13"
 tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] }
 tower-lsp = "0.20"
 tracing = "0.1"
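
The `tokenizers` crate is the only new dependency; everything below hinges on its `Tokenizer::from_file` and `encode` calls, the same ones used in the diff. A minimal sketch of that API (not part of the diff; assumes a local tokenizer.json):

    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Load a serialized tokenizer (the same file format downloaded below).
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;
        // `false` = no special tokens, matching the calls in build_prompt.
        let encoding = tokenizer.encode("let x = 42;", false)?;
        println!("{} tokens", encoding.len());
        Ok(())
    }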

src/main.rs: 234 changed lines
@@ -6,7 +6,10 @@ use ropey::Rope;
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::fmt::Display;
+use std::path::{Path, PathBuf};
 use std::sync::Arc;
+use tokenizers::Tokenizer;
+use tokio::io::AsyncWriteExt;
 use tokio::sync::RwLock;
 use tower_lsp::jsonrpc::{Error, Result};
 use tower_lsp::lsp_types::*;
@@ -104,11 +107,13 @@ impl Document {
 
 #[derive(Debug)]
 struct Backend {
+    cache_dir: PathBuf,
     client: Client,
     document_map: Arc<RwLock<HashMap<String, Document>>>,
     http_client: reqwest::Client,
     workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
     language_comments: HashMap<String, LanguageComment>,
+    tokenizer_map: Arc<RwLock<HashMap<String, Tokenizer>>>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
@@ -124,6 +129,8 @@ struct CompletionParams {
     fim: FimParams,
     api_token: Option<String>,
     model: String,
+    tokenizer_path: Option<String>,
+    context_window: usize,
 }
 
 fn internal_error<E: Display>(err: E) -> Error {
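
For reference, a hypothetical payload for the `llm-ls/getCompletions` custom method (registered further down), showing where the two new fields travel. Field names match the struct above; the values and the use of `serde_json` here are illustrative only:

    use serde_json::json;

    fn main() {
        let params = json!({
            "model": "bigcode/starcoder",   // Hub id, or a full URL to an endpoint
            "tokenizer_path": null,         // required only when "model" is a URL
            "context_window": 2048          // token budget for the whole prompt
            // ...fim, api_token, request_params, text_document_position...
        });
        println!("{params}");
    }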
@@ -138,7 +145,7 @@ fn internal_error<E: Display>(err: E) -> Error {
 
 fn file_path_comment(
     file_url: Url,
-    file_language_id: String,
+    file_language_id: &str,
     workspace_folders: Option<&Vec<WorkspaceFolder>>,
     language_comments: &HashMap<String, LanguageComment>,
 ) -> String {
@@ -155,47 +162,106 @@ fn file_path_comment(
     } else {
         file_path
     };
-    let lc = match language_comments.get(&file_language_id) {
+    let lc = match language_comments.get(file_language_id) {
         Some(id) => id.clone(),
         None => LanguageComment {
             open: "//".to_owned(),
             close: None,
         },
     };
-    let mut commented_path = lc.open;
-    commented_path.push(' ');
-    commented_path.push_str(&path_in_workspace);
-    if let Some(close) = lc.close {
-        commented_path.push(' ');
-        commented_path.push_str(&close);
-    }
-    commented_path.push('\n');
-    commented_path
+    let close = if let Some(close) = lc.close {
+        format!(" {close}")
+    } else {
+        "".to_owned()
+    };
+    format!("{} {path_in_workspace}{close}\n", lc.open)
 }
 
-fn build_prompt(pos: Position, text: &Rope, fim: &FimParams, file_path: String) -> Result<String> {
-    let mut prompt = file_path;
-    let cursor_offset = text
-        .try_line_to_char(pos.line as usize)
-        .map_err(internal_error)?
-        + pos.character as usize;
-    let text_len = text.len_chars();
-    // XXX: not sure this is useful, rather be safe than sorry
-    let cursor_offset = if cursor_offset > text_len {
-        text_len
-    } else {
-        cursor_offset
-    };
+fn build_prompt(
+    pos: Position,
+    text: &Rope,
+    fim: &FimParams,
+    file_path: String,
+    tokenizer: Tokenizer,
+    context_window: usize,
+) -> Result<String> {
     if fim.enabled {
-        prompt.push_str(&fim.prefix);
-        prompt.push_str(&text.slice(0..cursor_offset).to_string());
-        prompt.push_str(&fim.suffix);
-        prompt.push_str(&text.slice(cursor_offset..text_len).to_string());
-        prompt.push_str(&fim.middle);
-        Ok(prompt)
+        let mut token_count = context_window;
+        let mut before_iter = text.lines_at(pos.line as usize + 1).reversed();
+        let mut after_iter = text.lines_at(pos.line as usize);
+        let mut before_line = before_iter.next();
+        let col = pos.character as usize;
+        if let Some(line) = before_line {
+            before_line = Some(line.slice(0..col));
+        }
+        let mut after_line = after_iter.next();
+        if let Some(line) = after_line {
+            after_line = Some(line.slice(col..));
+        }
+        let mut before = vec![];
+        let mut after = String::new();
+        while before_line.is_some() || after_line.is_some() {
+            if let Some(before_line) = before_line {
+                let before_line = before_line.to_string();
+                let tokens = tokenizer
+                    .encode(before_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len();
+                if tokens > token_count {
+                    break;
+                }
+                token_count -= tokens;
+                before.push(before_line);
+            }
+            if let Some(after_line) = after_line {
+                let after_line = after_line.to_string();
+                let tokens = tokenizer
+                    .encode(after_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len();
+                if tokens > token_count {
+                    break;
+                }
+                token_count -= tokens;
+                after.push_str(&after_line);
+            }
+            before_line = before_iter.next();
+            after_line = after_iter.next();
+        }
+        Ok(format!(
+            "{}{}{}{}{}{}",
+            file_path,
+            fim.prefix,
+            before.into_iter().rev().collect::<Vec<_>>().join(""),
+            fim.suffix,
+            after,
+            fim.middle
+        ))
     } else {
-        prompt.push_str(&text.slice(0..cursor_offset).to_string());
-        Ok(prompt)
+        let mut token_count = context_window;
+        let mut before = vec![];
+        let mut first = true;
+        for mut line in text.lines_at(pos.line as usize).reversed() {
+            if first {
+                line = line.slice(0..pos.character as usize);
+                first = false;
+            }
+            let line = line.to_string();
+            let tokens = tokenizer
+                .encode(line.clone(), false)
+                .map_err(internal_error)?
+                .len();
+            if tokens > token_count {
+                break;
+            }
+            token_count -= tokens;
+            before.push(line);
+        }
+        Ok(format!(
+            "{}{}",
+            file_path,
+            &before.into_iter().rev().collect::<Vec<_>>().join("")
+        ))
     }
 }
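
This is the heart of the change: rather than slicing the whole buffer at the cursor offset, `build_prompt` now walks outward from the cursor line by line and stops once the `context_window` token budget is spent (the FIM branch alternates between lines above and below the cursor against a single budget). A self-contained sketch of the "before cursor" half, with a whitespace splitter standing in for the real tokenizer so it runs without a tokenizer.json:

    use ropey::Rope;

    // Stand-in for `tokenizer.encode(line, false)?.len()`.
    fn count_tokens(line: &str) -> usize {
        line.split_whitespace().count().max(1)
    }

    // Collect lines above the cursor, newest first, until the budget runs out.
    fn context_before(text: &Rope, line: usize, col: usize, mut budget: usize) -> String {
        let mut kept: Vec<String> = vec![];
        let mut first = true;
        for mut l in text.lines_at(line + 1).reversed() {
            if first {
                // The cursor line only contributes the text left of the cursor.
                l = l.slice(0..col);
                first = false;
            }
            let s = l.to_string();
            let tokens = count_tokens(&s);
            if tokens > budget {
                break;
            }
            budget -= tokens;
            kept.push(s);
        }
        // Lines were collected bottom-up; restore document order.
        kept.into_iter().rev().collect()
    }

    fn main() {
        let text = Rope::from_str("fn main() {\n    let a = 1;\n    let b = \n");
        // Cursor on line 2, column 12, with a budget of 8 "tokens": the two
        // nearest lines fit, the first line would overshoot and is dropped.
        println!("{:?}", context_before(&text, 2, 12, 8));
    }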
@@ -203,15 +269,15 @@ async fn request_completion(
     http_client: &reqwest::Client,
     model: &str,
     request_params: RequestParams,
-    api_token: Option<String>,
+    api_token: Option<&String>,
     prompt: String,
 ) -> Result<Vec<Generation>> {
-    let mut req = http_client.post(model).json(&APIRequest {
+    let mut req = http_client.post(build_url(model)).json(&APIRequest {
         inputs: prompt,
         parameters: request_params.into(),
     });
 
-    if let Some(api_token) = api_token.clone() {
+    if let Some(api_token) = api_token {
         req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
     }
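
Two changes here: `api_token` is now borrowed (`Option<&String>`), which drops a clone at the call site, and the POST target goes through `build_url` (added below) instead of assuming `model` is already a URL. A small sketch of the header logic in isolation (hypothetical `attach_token` helper; reqwest's builder API):

    use reqwest::header::AUTHORIZATION;

    // Attach a Bearer token to a request only if one is configured.
    fn attach_token(
        req: reqwest::RequestBuilder,
        api_token: Option<&String>,
    ) -> reqwest::RequestBuilder {
        match api_token {
            Some(token) => req.header(AUTHORIZATION, format!("Bearer {token}")),
            None => req,
        }
    }

    #[tokio::main]
    async fn main() {
        let client = reqwest::Client::new();
        let req = attach_token(client.post("http://localhost:8080"), None);
        let _ = req; // nothing is sent in this sketch
    }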
@@ -232,6 +298,82 @@ fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Comp
         .collect()
 }
 
+async fn download_tokenizer_file(
+    http_client: &reqwest::Client,
+    model: &str,
+    api_token: Option<&String>,
+    to: impl AsRef<Path>,
+) -> Result<()> {
+    if to.as_ref().exists() {
+        return Ok(());
+    }
+    tokio::fs::create_dir_all(
+        to.as_ref()
+            .parent()
+            .ok_or_else(|| internal_error("tokenizer path has no parent"))?,
+    )
+    .await
+    .map_err(internal_error)?;
+    let mut req = http_client.get(format!(
+        "https://huggingface.co/{model}/resolve/main/tokenizer.json"
+    ));
+    if let Some(api_token) = api_token {
+        req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
+    }
+    let res = req
+        .send()
+        .await
+        .map_err(internal_error)?
+        .error_for_status()
+        .map_err(internal_error)?;
+    let mut file = tokio::fs::OpenOptions::new()
+        .write(true)
+        .create(true)
+        .open(to)
+        .await
+        .map_err(internal_error)?;
+    file.write_all(&res.bytes().await.map_err(internal_error)?)
+        .await
+        .map_err(internal_error)?;
+    Ok(())
+}
+
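
The tokenizer file is fetched from the Hub's raw-file endpoint, and the download is skipped entirely when the file is already cached on disk. A sketch of the URL construction using the same format string (`bigcode/starcoder` is just an example id):

    fn tokenizer_url(model: &str) -> String {
        format!("https://huggingface.co/{model}/resolve/main/tokenizer.json")
    }

    fn main() {
        assert_eq!(
            tokenizer_url("bigcode/starcoder"),
            "https://huggingface.co/bigcode/starcoder/resolve/main/tokenizer.json"
        );
    }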
+async fn get_tokenizer(
+    model: &str,
+    tokenizer_map: &mut HashMap<String, Tokenizer>,
+    tokenizer_path: Option<&String>,
+    http_client: &reqwest::Client,
+    cache_dir: impl AsRef<Path>,
+    api_token: Option<&String>,
+) -> Result<Tokenizer> {
+    if model.starts_with("http://") || model.starts_with("https://") {
+        let tokenizer = match tokenizer_path {
+            Some(path) => Tokenizer::from_file(path).map_err(internal_error)?,
+            None => return Err(internal_error("`tokenizer_path` is null")),
+        };
+        Ok(tokenizer)
+    } else {
+        match tokenizer_map.get(model) {
+            Some(tokenizer) => Ok(tokenizer.clone()),
+            None => {
+                let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+                download_tokenizer_file(http_client, model, api_token, &path).await?;
+                let tokenizer = Tokenizer::from_file(path).map_err(internal_error)?;
+                tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+                Ok(tokenizer)
+            }
+        }
+    }
+}
+
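
Resolution order: a `model` that is itself a URL forces the client to supply `tokenizer_path`, since the server cannot guess the right tokenizer; a Hub id is served from the in-memory map when possible and downloaded once otherwise. The same memoization shape as a stand-alone sketch (hypothetical stand-in loader instead of `Tokenizer::from_file`):

    use std::collections::HashMap;

    // Stand-in for `Tokenizer::from_file`.
    fn load(model: &str) -> String {
        format!("tokenizer-for-{model}")
    }

    fn get_or_load(cache: &mut HashMap<String, String>, model: &str) -> String {
        match cache.get(model) {
            // Values are cloned out so the cache keeps its own copy.
            Some(tok) => tok.clone(),
            None => {
                let tok = load(model);
                cache.insert(model.to_owned(), tok.clone());
                tok
            }
        }
    }

    fn main() {
        let mut cache = HashMap::new();
        assert_eq!(get_or_load(&mut cache, "m"), "tokenizer-for-m");
        assert_eq!(get_or_load(&mut cache, "m"), "tokenizer-for-m"); // cached
    }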
+fn build_url(model: &str) -> String {
+    if model.starts_with("http://") || model.starts_with("https://") {
+        model.to_owned()
+    } else {
+        format!("https://api-inference.huggingface.co/models/{model}")
+    }
+}
+
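
With this helper, `model` does double duty: a bare Hub id is routed to the hosted Inference API, while an explicit http(s) URL (e.g. a self-hosted endpoint) is used verbatim. A runnable check (function copied from the diff; the example values are illustrative):

    fn build_url(model: &str) -> String {
        if model.starts_with("http://") || model.starts_with("https://") {
            model.to_owned()
        } else {
            format!("https://api-inference.huggingface.co/models/{model}")
        }
    }

    fn main() {
        assert_eq!(
            build_url("bigcode/starcoder"),
            "https://api-inference.huggingface.co/models/bigcode/starcoder"
        );
        assert_eq!(
            build_url("http://localhost:8080/generate"),
            "http://localhost:8080/generate"
        );
    }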
 impl Backend {
     async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> {
         info!("get_completions {params:?}");
@@ -242,23 +384,34 @@ impl Backend {
             .ok_or_else(|| internal_error("failed to find document"))?;
         let file_path = file_path_comment(
             params.text_document_position.text_document.uri,
-            document.language_id.clone(),
+            &document.language_id,
             self.workspace_folders.read().await.as_ref(),
             &self.language_comments,
         );
+        let tokenizer = get_tokenizer(
+            &params.model,
+            &mut *self.tokenizer_map.write().await,
+            params.tokenizer_path.as_ref(),
+            &self.http_client,
+            &self.cache_dir,
+            params.api_token.as_ref(),
+        )
+        .await?;
         let prompt = build_prompt(
             params.text_document_position.position,
             &document.text,
             &params.fim,
             file_path,
+            tokenizer,
+            params.context_window,
         )?;
         let stop_token = params.request_params.stop_token.clone();
         let result = request_completion(
             &self.http_client,
             &params.model,
             params.request_params,
-            params.api_token,
-            prompt.clone(),
+            params.api_token.as_ref(),
+            prompt,
         )
         .await?;
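
A wiring detail worth calling out: `&mut *self.tokenizer_map.write().await` reborrows the map behind the tokio write guard as a plain `&mut HashMap`, and that guard lives for the whole `get_tokenizer` call, so concurrent requests serialize tokenizer setup. A minimal sketch of the borrow pattern (assumes tokio with the macros feature, which the project already enables):

    use std::collections::HashMap;
    use tokio::sync::RwLock;

    async fn insert_one(lock: &RwLock<HashMap<String, u32>>) {
        let mut guard = lock.write().await;
        // Reborrow the guarded value as an ordinary mutable reference.
        let map: &mut HashMap<String, u32> = &mut *guard;
        map.insert("k".to_owned(), 1);
    } // guard dropped here; the lock is released

    #[tokio::main]
    async fn main() {
        let lock = RwLock::new(HashMap::new());
        insert_one(&lock).await;
        assert_eq!(lock.read().await.get("k"), Some(&1));
    }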
@@ -281,7 +434,6 @@ impl LanguageServer for Backend {
                 )),
                 ..Default::default()
             },
-            ..Default::default()
         })
     }
@@ -358,7 +510,7 @@ async fn main() {
         .await
         .expect("failed to create cache dir");
 
-    let log_file = rolling::never(cache_dir, "llm-ls.log");
+    let log_file = rolling::never(&cache_dir, "llm-ls.log");
     let builder = tracing_subscriber::fmt()
         .with_writer(log_file)
         .with_target(true)
@@ -377,11 +529,13 @@ async fn main() {
     let http_client = reqwest::Client::new();
 
     let (service, socket) = LspService::build(|client| Backend {
+        cache_dir,
         client,
         document_map: Arc::new(RwLock::new(HashMap::new())),
         http_client,
         workspace_folders: Arc::new(RwLock::new(None)),
         language_comments: build_language_comments(),
+        tokenizer_map: Arc::new(RwLock::new(HashMap::new())),
     })
     .custom_method("llm-ls/getCompletions", Backend::get_completions)
     .finish();