feat: add tokenizer for context window size check (#5)
* feat: add tokenizer for context window size check
* refactor: clippy
* refactor: use `format!` instead of mut
* refactor: remove unnecessary cloning
parent 6f455eca18
commit 073be09042
Cargo.lock (generated): 1059 changes, diff suppressed because it is too large.
Cargo.toml

@@ -10,6 +10,7 @@ home = "0.5"
 ropey = "1.6"
 serde = { version = "1", features = ["derive"] }
 reqwest = { version = "0.11", features = ["json"] }
+tokenizers = "0.13"
 tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "rt-multi-thread"] }
 tower-lsp = "0.20"
 tracing = "0.1"
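Note: the only new dependency is the `tokenizers` crate, which loads a serialized Hugging Face `tokenizer.json` and is what the new `build_prompt` uses to count tokens per line. A minimal standalone sketch of that API, assuming a local `tokenizer.json` (the file path and snippet are placeholders, not something this commit ships):

    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Load a serialized tokenizer (placeholder path) and count the tokens in a
        // snippet, using the same `encode(..., false).len()` call that build_prompt
        // relies on for its budget accounting.
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;
        let token_count = tokenizer.encode("fn main() {}", false)?.len();
        println!("token count: {token_count}");
        Ok(())
    }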
src/main.rs: 234 changes
@@ -6,7 +6,10 @@ use ropey::Rope;
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::fmt::Display;
+use std::path::{Path, PathBuf};
 use std::sync::Arc;
+use tokenizers::Tokenizer;
+use tokio::io::AsyncWriteExt;
 use tokio::sync::RwLock;
 use tower_lsp::jsonrpc::{Error, Result};
 use tower_lsp::lsp_types::*;
@@ -104,11 +107,13 @@ impl Document {

 #[derive(Debug)]
 struct Backend {
+    cache_dir: PathBuf,
     client: Client,
     document_map: Arc<RwLock<HashMap<String, Document>>>,
     http_client: reqwest::Client,
     workspace_folders: Arc<RwLock<Option<Vec<WorkspaceFolder>>>>,
     language_comments: HashMap<String, LanguageComment>,
+    tokenizer_map: Arc<RwLock<HashMap<String, Tokenizer>>>,
 }

 #[derive(Debug, Deserialize, Serialize)]
@@ -124,6 +129,8 @@ struct CompletionParams {
     fim: FimParams,
     api_token: Option<String>,
     model: String,
+    tokenizer_path: Option<String>,
+    context_window: usize,
 }

 fn internal_error<E: Display>(err: E) -> Error {
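Note: `CompletionParams` gains two client-supplied fields, `tokenizer_path` (optional path to a local `tokenizer.json`) and `context_window` (the token budget handed to `build_prompt`). A hedged sketch of how a client might fill them in an `llm-ls/getCompletions` payload; the field casing assumes serde's default naming, the model id and values are placeholders, and `serde_json` is used here purely for illustration:

    fn main() {
        // Only the new fields are spelled out; the existing ones
        // (text_document_position, request_params, fim, api_token) are elided.
        let params = serde_json::json!({
            "model": "bigcode/starcoder",   // placeholder model id
            "tokenizer_path": null,         // optional local tokenizer.json override
            "context_window": 4096          // token budget enforced by build_prompt
        });
        println!("{params}");
    }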
@@ -138,7 +145,7 @@ fn internal_error<E: Display>(err: E) -> Error {

 fn file_path_comment(
     file_url: Url,
-    file_language_id: String,
+    file_language_id: &str,
     workspace_folders: Option<&Vec<WorkspaceFolder>>,
     language_comments: &HashMap<String, LanguageComment>,
 ) -> String {
@@ -155,47 +162,106 @@ fn file_path_comment(
     } else {
         file_path
     };
-    let lc = match language_comments.get(&file_language_id) {
+    let lc = match language_comments.get(file_language_id) {
         Some(id) => id.clone(),
         None => LanguageComment {
             open: "//".to_owned(),
             close: None,
         },
     };
-    let mut commented_path = lc.open;
-    commented_path.push(' ');
-    commented_path.push_str(&path_in_workspace);
-    if let Some(close) = lc.close {
-        commented_path.push(' ');
-        commented_path.push_str(&close);
-    }
-    commented_path.push('\n');
-    commented_path
+    let close = if let Some(close) = lc.close {
+        format!(" {close}")
+    } else {
+        "".to_owned()
+    };
+    format!("{} {path_in_workspace}{close}\n", lc.open)
 }

-fn build_prompt(pos: Position, text: &Rope, fim: &FimParams, file_path: String) -> Result<String> {
-    let mut prompt = file_path;
-    let cursor_offset = text
-        .try_line_to_char(pos.line as usize)
-        .map_err(internal_error)?
-        + pos.character as usize;
-    let text_len = text.len_chars();
-    // XXX: not sure this is useful, rather be safe than sorry
-    let cursor_offset = if cursor_offset > text_len {
-        text_len
-    } else {
-        cursor_offset
-    };
+fn build_prompt(
+    pos: Position,
+    text: &Rope,
+    fim: &FimParams,
+    file_path: String,
+    tokenizer: Tokenizer,
+    context_window: usize,
+) -> Result<String> {
     if fim.enabled {
-        prompt.push_str(&fim.prefix);
-        prompt.push_str(&text.slice(0..cursor_offset).to_string());
-        prompt.push_str(&fim.suffix);
-        prompt.push_str(&text.slice(cursor_offset..text_len).to_string());
-        prompt.push_str(&fim.middle);
-        Ok(prompt)
+        let mut token_count = context_window;
+        let mut before_iter = text.lines_at(pos.line as usize + 1).reversed();
+        let mut after_iter = text.lines_at(pos.line as usize);
+        let mut before_line = before_iter.next();
+        let col = pos.character as usize;
+        if let Some(line) = before_line {
+            before_line = Some(line.slice(0..col));
+        }
+        let mut after_line = after_iter.next();
+        if let Some(line) = after_line {
+            after_line = Some(line.slice(col..));
+        }
+        let mut before = vec![];
+        let mut after = String::new();
+        while before_line.is_some() || after_line.is_some() {
+            if let Some(before_line) = before_line {
+                let before_line = before_line.to_string();
+                let tokens = tokenizer
+                    .encode(before_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len();
+                if tokens > token_count {
+                    break;
+                }
+                token_count -= tokens;
+                before.push(before_line);
+            }
+            if let Some(after_line) = after_line {
+                let after_line = after_line.to_string();
+                let tokens = tokenizer
+                    .encode(after_line.clone(), false)
+                    .map_err(internal_error)?
+                    .len();
+                if tokens > token_count {
+                    break;
+                }
+                token_count -= tokens;
+                after.push_str(&after_line);
+            }
+            before_line = before_iter.next();
+            after_line = after_iter.next();
+        }
+        Ok(format!(
+            "{}{}{}{}{}{}",
+            file_path,
+            fim.prefix,
+            before.into_iter().rev().collect::<Vec<_>>().join(""),
+            fim.suffix,
+            after,
+            fim.middle
+        ))
     } else {
-        prompt.push_str(&text.slice(0..cursor_offset).to_string());
-        Ok(prompt)
+        let mut token_count = context_window;
+        let mut before = vec![];
+        let mut first = true;
+        for mut line in text.lines_at(pos.line as usize).reversed() {
+            if first {
+                line = line.slice(0..pos.character as usize);
+                first = false;
+            }
+            let line = line.to_string();
+            let tokens = tokenizer
+                .encode(line.clone(), false)
+                .map_err(internal_error)?
+                .len();
+            if tokens > token_count {
+                break;
+            }
+            token_count -= tokens;
+            before.push(line);
+        }
+        Ok(format!(
+            "{}{}",
+            file_path,
+            &before.into_iter().rev().collect::<Vec<_>>().join("")
+        ))
     }
 }

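Note: the rewritten `build_prompt` no longer slices the rope at a single character offset; it walks outwards from the cursor, alternating one line above with one line below, and charges each line's token count against `context_window` until the budget runs out. A standalone sketch of that windowing loop over plain `String` lines (an assumption for illustration, not the commit's code, which operates on `ropey` slices):

    use tokenizers::Tokenizer;

    fn window(
        before_lines: Vec<String>, // lines above the cursor, closest to the cursor first
        after_lines: Vec<String>,  // lines below the cursor, closest to the cursor first
        tokenizer: &Tokenizer,
        mut budget: usize,         // the context_window token budget
    ) -> tokenizers::Result<(String, String)> {
        let mut before = vec![];
        let mut after = String::new();
        let mut b = before_lines.into_iter();
        let mut a = after_lines.into_iter();
        let (mut bl, mut al) = (b.next(), a.next());
        while bl.is_some() || al.is_some() {
            if let Some(line) = bl {
                // charge this line against the budget, stop once it no longer fits
                let tokens = tokenizer.encode(line.clone(), false)?.len();
                if tokens > budget {
                    break;
                }
                budget -= tokens;
                before.push(line);
            }
            if let Some(line) = al {
                let tokens = tokenizer.encode(line.clone(), false)?.len();
                if tokens > budget {
                    break;
                }
                budget -= tokens;
                after.push_str(&line);
            }
            bl = b.next();
            al = a.next();
        }
        // before lines were collected closest-first, so flip them back into
        // document order before joining, as the format! call above does
        Ok((before.into_iter().rev().collect::<Vec<_>>().join(""), after))
    }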
@@ -203,15 +269,15 @@ async fn request_completion(
     http_client: &reqwest::Client,
     model: &str,
     request_params: RequestParams,
-    api_token: Option<String>,
+    api_token: Option<&String>,
     prompt: String,
 ) -> Result<Vec<Generation>> {
-    let mut req = http_client.post(model).json(&APIRequest {
+    let mut req = http_client.post(build_url(model)).json(&APIRequest {
         inputs: prompt,
         parameters: request_params.into(),
     });

-    if let Some(api_token) = api_token.clone() {
+    if let Some(api_token) = api_token {
         req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
     }

@@ -232,6 +298,82 @@ fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Comp
         .collect()
 }

+async fn download_tokenizer_file(
+    http_client: &reqwest::Client,
+    model: &str,
+    api_token: Option<&String>,
+    to: impl AsRef<Path>,
+) -> Result<()> {
+    if to.as_ref().exists() {
+        return Ok(());
+    }
+    tokio::fs::create_dir_all(
+        to.as_ref()
+            .parent()
+            .ok_or_else(|| internal_error("tokenizer path has no parent"))?,
+    )
+    .await
+    .map_err(internal_error)?;
+    let mut req = http_client.get(format!(
+        "https://huggingface.co/{model}/resolve/main/tokenizer.json"
+    ));
+    if let Some(api_token) = api_token {
+        req = req.header(AUTHORIZATION, format!("Bearer {api_token}"))
+    }
+    let res = req
+        .send()
+        .await
+        .map_err(internal_error)?
+        .error_for_status()
+        .map_err(internal_error)?;
+    let mut file = tokio::fs::OpenOptions::new()
+        .write(true)
+        .create(true)
+        .open(to)
+        .await
+        .map_err(internal_error)?;
+    file.write_all(&res.bytes().await.map_err(internal_error)?)
+        .await
+        .map_err(internal_error)?;
+    Ok(())
+}
+
+async fn get_tokenizer(
+    model: &str,
+    tokenizer_map: &mut HashMap<String, Tokenizer>,
+    tokenizer_path: Option<&String>,
+    http_client: &reqwest::Client,
+    cache_dir: impl AsRef<Path>,
+    api_token: Option<&String>,
+) -> Result<Tokenizer> {
+    if model.starts_with("http://") || model.starts_with("https://") {
+        let tokenizer = match tokenizer_path {
+            Some(path) => Tokenizer::from_file(path).map_err(internal_error)?,
+            None => return Err(internal_error("`tokenizer_path` is null")),
+        };
+        Ok(tokenizer)
+    } else {
+        match tokenizer_map.get(model) {
+            Some(tokenizer) => Ok(tokenizer.clone()),
+            None => {
+                let path = cache_dir.as_ref().join(model).join("tokenizer.json");
+                download_tokenizer_file(http_client, model, api_token, &path).await?;
+                let tokenizer = Tokenizer::from_file(path).map_err(internal_error)?;
+                tokenizer_map.insert(model.to_owned(), tokenizer.clone());
+                Ok(tokenizer)
+            }
+        }
+    }
+}
+
+fn build_url(model: &str) -> String {
+    if model.starts_with("http://") || model.starts_with("https://") {
+        model.to_owned()
+    } else {
+        format!("https://api-inference.huggingface.co/models/{model}")
+    }
+}
+
 impl Backend {
     async fn get_completions(&self, params: CompletionParams) -> Result<Vec<Completion>> {
         info!("get_completions {params:?}");
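Note: `get_tokenizer` resolves a tokenizer in two ways. If `model` is an `http(s)://` endpoint it requires a local `tokenizer_path`; otherwise it treats `model` as a Hugging Face Hub id, downloads `tokenizer.json` into the cache directory once, and memoizes the parsed tokenizer in `tokenizer_map`. A small sketch of the cache location and download URL involved, mirroring the path join and URL above (the model id and cache directory are placeholders):

    use std::path::{Path, PathBuf};

    // Where a hub tokenizer ends up on disk, per get_tokenizer's cache path.
    fn tokenizer_cache_path(cache_dir: impl AsRef<Path>, model: &str) -> PathBuf {
        cache_dir.as_ref().join(model).join("tokenizer.json")
    }

    // The file download_tokenizer_file fetches for a hub model id.
    fn tokenizer_url(model: &str) -> String {
        format!("https://huggingface.co/{model}/resolve/main/tokenizer.json")
    }

    fn main() {
        let model = "bigcode/starcoder"; // placeholder hub id
        println!("{}", tokenizer_cache_path("/tmp/llm-ls", model).display());
        println!("{}", tokenizer_url(model));
    }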
@@ -242,23 +384,34 @@ impl Backend {
             .ok_or_else(|| internal_error("failed to find document"))?;
         let file_path = file_path_comment(
             params.text_document_position.text_document.uri,
-            document.language_id.clone(),
+            &document.language_id,
             self.workspace_folders.read().await.as_ref(),
             &self.language_comments,
         );
+        let tokenizer = get_tokenizer(
+            &params.model,
+            &mut *self.tokenizer_map.write().await,
+            params.tokenizer_path.as_ref(),
+            &self.http_client,
+            &self.cache_dir,
+            params.api_token.as_ref(),
+        )
+        .await?;
         let prompt = build_prompt(
             params.text_document_position.position,
             &document.text,
             &params.fim,
             file_path,
+            tokenizer,
+            params.context_window,
         )?;
         let stop_token = params.request_params.stop_token.clone();
         let result = request_completion(
             &self.http_client,
             &params.model,
             params.request_params,
-            params.api_token,
-            prompt.clone(),
+            params.api_token.as_ref(),
+            prompt,
         )
         .await?;

@@ -281,7 +434,6 @@ impl LanguageServer for Backend {
                 )),
                 ..Default::default()
             },
-            ..Default::default()
         })
     }

@@ -358,7 +510,7 @@ async fn main() {
         .await
         .expect("failed to create cache dir");

-    let log_file = rolling::never(cache_dir, "llm-ls.log");
+    let log_file = rolling::never(&cache_dir, "llm-ls.log");
     let builder = tracing_subscriber::fmt()
         .with_writer(log_file)
         .with_target(true)
@@ -377,11 +529,13 @@ async fn main() {
     let http_client = reqwest::Client::new();

     let (service, socket) = LspService::build(|client| Backend {
+        cache_dir,
         client,
         document_map: Arc::new(RwLock::new(HashMap::new())),
         http_client,
         workspace_folders: Arc::new(RwLock::new(None)),
         language_comments: build_language_comments(),
+        tokenizer_map: Arc::new(RwLock::new(HashMap::new())),
     })
     .custom_method("llm-ls/getCompletions", Backend::get_completions)
     .finish();