refactor: error handling (#71)

Luc Georges 2024-02-07 12:34:17 +01:00 committed by GitHub
parent a9831d5720
commit 455b085c96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 186 additions and 218 deletions
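
The shape of the refactor: call sites used to convert every failure into a `tower_lsp::jsonrpc::Error` on the spot via `.map_err(internal_error)`; this commit routes failures through a `thiserror`-derived enum plus a crate-wide `Result` alias, so a bare `?` does the conversion. A condensed sketch of the pattern (stand-in variants; the full enum is in error.rs below):

```rust
// Condensed sketch, not the commit's full enum: #[from] derives a From impl
// for each wrapped error type, so `?` converts it without map_err.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("rope error: {0}")]
    Rope(#[from] ropey::Error),
    #[error("serde json error: {0}")]
    SerdeJson(#[from] serde_json::Error),
}

pub type Result<T> = std::result::Result<T, Error>;

// Call sites shrink from `rope.try_line_to_char(row).map_err(internal_error)?`
// to a bare `?`:
fn line_start_char(rope: &ropey::Rope, row: usize) -> Result<usize> {
    Ok(rope.try_line_to_char(row)?) // ropey::Error -> Error::Rope via #[from]
}

fn main() {
    let rope = ropey::Rope::from_str("hello\nworld");
    assert_eq!(line_start_char(&rope, 1).unwrap(), 6);
}
```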

Cargo.lock (generated)

@@ -979,6 +979,7 @@ dependencies = [
  "ropey",
  "serde",
  "serde_json",
+ "thiserror",
  "tokenizers",
  "tokio",
  "tower-lsp",

Cargo.toml

@@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = [
 ] }
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
+thiserror = "1"
 tokenizers = { version = "0.14", default-features = false, features = ["onig"] }
 tokio = { version = "1", features = [
     "fs",

backend.rs

@@ -1,59 +1,44 @@
-use super::{
-    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
-};
+use super::{APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION};
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
 use std::fmt::Display;
-use tower_lsp::jsonrpc;
 
-fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+use crate::error::{Error, Result};
+
+fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     let mut headers = HeaderMap::new();
     let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
-    headers.insert(
-        USER_AGENT,
-        HeaderValue::from_str(&user_agent).map_err(internal_error)?,
-    );
+    headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
     if let Some(api_token) = api_token {
         headers.insert(
             AUTHORIZATION,
-            HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
+            HeaderValue::from_str(&format!("Bearer {api_token}"))?,
         );
     }
     Ok(headers)
 }
 
-fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
-    let generations =
-        match serde_json::from_str(text).map_err(internal_error)? {
-            APIResponse::Generation(gen) => vec![gen],
-            APIResponse::Generations(_) => {
-                return Err(internal_error(
-                    "You are attempting to parse a result in the API inference format when using the `tgi` backend",
-                ))
-            }
-            APIResponse::Error(err) => return Err(internal_error(err)),
-        };
-    Ok(generations)
+fn parse_tgi_text(text: &str) -> Result<Vec<Generation>> {
+    match serde_json::from_str(text)? {
+        APIResponse::Generation(gen) => Ok(vec![gen]),
+        APIResponse::Generations(_) => Err(Error::InvalidBackend),
+        APIResponse::Error(err) => Err(Error::Tgi(err)),
+    }
 }
 
-fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
+fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     build_tgi_headers(api_token, ide)
 }
 
-fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
-    let generations = match serde_json::from_str(text).map_err(internal_error)? {
-        APIResponse::Generation(gen) => vec![gen],
-        APIResponse::Generations(gens) => gens,
-        APIResponse::Error(err) => return Err(internal_error(err)),
-    };
-    Ok(generations)
-}
-
-fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
-    Ok(HeaderMap::new())
+fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
+    match serde_json::from_str(text)? {
+        APIResponse::Generation(gen) => Ok(vec![gen]),
+        APIResponse::Generations(gens) => Ok(gens),
+        APIResponse::Error(err) => Err(Error::InferenceApi(err)),
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -76,16 +61,15 @@ enum OllamaAPIResponse {
     Error(APIError),
 }
 
-fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
-    let generations = match serde_json::from_str(text).map_err(internal_error)? {
-        OllamaAPIResponse::Generation(gen) => vec![gen.into()],
-        OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
-    };
-    Ok(generations)
+fn build_ollama_headers() -> HeaderMap {
+    HeaderMap::new()
 }
 
-fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
-    build_api_headers(api_token, ide)
+fn parse_ollama_text(text: &str) -> Result<Vec<Generation>> {
+    match serde_json::from_str(text)? {
+        OllamaAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
+        OllamaAPIResponse::Error(err) => Err(Error::Ollama(err)),
+    }
 }
 
 #[derive(Debug, Deserialize)]
@@ -130,7 +114,7 @@ struct OpenAIErrorDetail {
 }
 
 #[derive(Debug, Deserialize)]
-struct OpenAIError {
+pub struct OpenAIError {
     detail: Vec<OpenAIErrorDetail>,
 }
@@ -153,13 +137,16 @@ enum OpenAIAPIResponse {
     Error(OpenAIError),
 }
 
-fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
-    match serde_json::from_str(text).map_err(internal_error) {
-        Ok(OpenAIAPIResponse::Generation(completion)) => {
+fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
+    build_api_headers(api_token, ide)
+}
+
+fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
+    match serde_json::from_str(text)? {
+        OpenAIAPIResponse::Generation(completion) => {
             Ok(completion.choices.into_iter().map(|x| x.into()).collect())
         }
-        Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
-        Err(err) => Err(internal_error(err)),
+        OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)),
     }
 }
@@ -186,20 +173,16 @@ pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Value> {
     body
 }
 
-pub fn build_headers(
-    backend: &Backend,
-    api_token: Option<&String>,
-    ide: Ide,
-) -> Result<HeaderMap, jsonrpc::Error> {
+pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     match backend {
         Backend::HuggingFace => build_api_headers(api_token, ide),
-        Backend::Ollama => build_ollama_headers(),
+        Backend::Ollama => Ok(build_ollama_headers()),
         Backend::OpenAi => build_openai_headers(api_token, ide),
         Backend::Tgi => build_tgi_headers(api_token, ide),
     }
 }
 
-pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result<Vec<Generation>> {
+pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
         Backend::HuggingFace => parse_api_text(text),
         Backend::Ollama => parse_ollama_text(text),
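
Note how each backend now maps an API-level error payload to its own variant (`Error::Tgi`, `Error::InferenceApi`, `Error::Ollama`, `Error::OpenAI`) instead of a pre-stringified JSON-RPC error. A self-contained sketch of that parsing pattern with stand-in types (the real `APIResponse`'s serde attributes are not visible in this diff; an untagged enum is assumed here):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ApiError {
    error: String,
}

// Stand-in for APIResponse; assumes untagged deserialization.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse {
    Generation { generated_text: String },
    Error(ApiError),
}

#[derive(Debug, thiserror::Error)]
enum Error {
    #[error("tgi error: {0:?}")]
    Tgi(ApiError),
    #[error("serde json error: {0}")]
    SerdeJson(#[from] serde_json::Error),
}

fn parse_tgi(text: &str) -> Result<String, Error> {
    // `?` uses the #[from] impl for malformed JSON; a well-formed error
    // payload becomes a typed Error::Tgi instead of a stringified message.
    match serde_json::from_str(text)? {
        ApiResponse::Generation { generated_text } => Ok(generated_text),
        ApiResponse::Error(err) => Err(Error::Tgi(err)),
    }
}

fn main() {
    assert!(parse_tgi(r#"{"generated_text":"fn main() {}"}"#).is_ok());
    assert!(matches!(
        parse_tgi(r#"{"error":"model overloaded"}"#),
        Err(Error::Tgi(_))
    ));
}
```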

document.rs

@@ -1,172 +1,126 @@
 use ropey::Rope;
-use tower_lsp::jsonrpc::Result;
 use tower_lsp::lsp_types::Range;
 use tree_sitter::{InputEdit, Parser, Point, Tree};
 
+use crate::error::Result;
+use crate::get_position_idx;
 use crate::language_id::LanguageId;
-use crate::{get_position_idx, internal_error};
 
 fn get_parser(language_id: LanguageId) -> Result<Parser> {
     match language_id {
         LanguageId::Bash => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_bash::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_bash::language())?;
             Ok(parser)
         }
         LanguageId::C => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_c::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_c::language())?;
             Ok(parser)
         }
         LanguageId::Cpp => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_cpp::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_cpp::language())?;
             Ok(parser)
         }
         LanguageId::CSharp => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_c_sharp::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_c_sharp::language())?;
             Ok(parser)
         }
         LanguageId::Elixir => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_elixir::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_elixir::language())?;
             Ok(parser)
         }
         LanguageId::Erlang => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_erlang::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_erlang::language())?;
             Ok(parser)
         }
         LanguageId::Go => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_go::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_go::language())?;
             Ok(parser)
         }
         LanguageId::Html => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_html::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_html::language())?;
             Ok(parser)
         }
         LanguageId::Java => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_java::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_java::language())?;
             Ok(parser)
         }
         LanguageId::JavaScript | LanguageId::JavaScriptReact => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_javascript::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_javascript::language())?;
             Ok(parser)
         }
         LanguageId::Json => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_json::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_json::language())?;
             Ok(parser)
         }
         LanguageId::Kotlin => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_kotlin::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_kotlin::language())?;
             Ok(parser)
         }
         LanguageId::Lua => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_lua::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_lua::language())?;
             Ok(parser)
         }
         LanguageId::Markdown => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_md::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_md::language())?;
             Ok(parser)
         }
         LanguageId::ObjectiveC => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_objc::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_objc::language())?;
             Ok(parser)
         }
         LanguageId::Python => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_python::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_python::language())?;
             Ok(parser)
         }
         LanguageId::R => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_r::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_r::language())?;
             Ok(parser)
         }
         LanguageId::Ruby => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_ruby::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_ruby::language())?;
             Ok(parser)
         }
         LanguageId::Rust => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_rust::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_rust::language())?;
             Ok(parser)
         }
         LanguageId::Scala => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_scala::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_scala::language())?;
             Ok(parser)
         }
         LanguageId::Swift => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_swift::language())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_swift::language())?;
             Ok(parser)
         }
         LanguageId::TypeScript => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_typescript::language_typescript())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_typescript::language_typescript())?;
             Ok(parser)
         }
         LanguageId::TypeScriptReact => {
             let mut parser = Parser::new();
-            parser
-                .set_language(tree_sitter_typescript::language_tsx())
-                .map_err(internal_error)?;
+            parser.set_language(tree_sitter_typescript::language_tsx())?;
             Ok(parser)
         }
         LanguageId::Unknown => Ok(Parser::new()),
@@ -200,19 +154,13 @@ impl Document {
             range.start.line as usize,
             range.start.character as usize,
         )?;
-        let start_byte = self
-            .text
-            .try_char_to_byte(start_idx)
-            .map_err(internal_error)?;
+        let start_byte = self.text.try_char_to_byte(start_idx)?;
         let old_end_idx = get_position_idx(
             &self.text,
             range.end.line as usize,
             range.end.character as usize,
         )?;
-        let old_end_byte = self
-            .text
-            .try_char_to_byte(old_end_idx)
-            .map_err(internal_error)?;
+        let old_end_byte = self.text.try_char_to_byte(old_end_idx)?;
         let start_position = Point {
             row: range.start.line as usize,
             column: range.start.character as usize,
@@ -224,7 +172,7 @@ impl Document {
         let (new_end_idx, new_end_position) = if range.start == range.end {
             let row = range.start.line as usize;
             let column = range.start.character as usize;
-            let idx = self.text.try_line_to_char(row).map_err(internal_error)? + column;
+            let idx = self.text.try_line_to_char(row)? + column;
             let rope = Rope::from_str(text);
             let text_len = rope.len_chars();
             let end_idx = idx + text_len;
@@ -237,11 +185,10 @@ impl Document {
                 },
             )
         } else {
-            let removal_idx = self.text.try_line_to_char(range.end.line as usize).map_err(internal_error)? + (range.end.character as usize);
+            let removal_idx = self.text.try_line_to_char(range.end.line as usize)?
+                + (range.end.character as usize);
             let slice_size = removal_idx - start_idx;
-            self.text
-                .try_remove(start_idx..removal_idx)
-                .map_err(internal_error)?;
+            self.text.try_remove(start_idx..removal_idx)?;
             self.text.insert(start_idx, text);
             let rope = Rope::from_str(text);
             let text_len = rope.len_chars();
@@ -251,11 +198,8 @@ impl Document {
             } else {
                 removal_idx + character_difference as usize
             };
-            let row = self
-                .text
-                .try_char_to_line(new_end_idx)
-                .map_err(internal_error)?;
-            let line_start = self.text.try_line_to_char(row).map_err(internal_error)?;
+            let row = self.text.try_char_to_line(new_end_idx)?;
+            let line_start = self.text.try_line_to_char(row)?;
             let column = new_end_idx - line_start;
             (new_end_idx, Point { row, column })
         };
@@ -263,10 +207,7 @@ impl Document {
         let edit = InputEdit {
             start_byte,
             old_end_byte,
-            new_end_byte: self
-                .text
-                .try_char_to_byte(new_end_idx)
-                .map_err(internal_error)?,
+            new_end_byte: self.text.try_char_to_byte(new_end_idx)?,
             start_position,
             old_end_position,
             new_end_position,
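
For context on what the `Document` edit path above is feeding: this is tree-sitter's standard incremental flow, where an `InputEdit` describes the change, the old tree is notified, and the parser re-parses with the old tree as a hint. A standalone sketch with a stand-in source edit (offsets computed by hand for this example):

```rust
use tree_sitter::{InputEdit, Parser, Point};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut parser = Parser::new();
    // LanguageError is a std Error, so `?` works here too after this commit.
    parser.set_language(tree_sitter_rust::language())?;

    let old_src = "fn main() {}";
    let mut tree = parser.parse(old_src, None).ok_or("parse failed")?;

    // Replace the `{}` body (bytes 10..12) with `{ let x = 1; }`.
    let new_src = "fn main() { let x = 1; }";
    tree.edit(&InputEdit {
        start_byte: 10,
        old_end_byte: 12,
        new_end_byte: 24,
        start_position: Point { row: 0, column: 10 },
        old_end_position: Point { row: 0, column: 12 },
        new_end_position: Point { row: 0, column: 24 },
    });

    // Re-parse using the edited old tree as a hint.
    let new_tree = parser.parse(new_src, Some(&tree)).ok_or("parse failed")?;
    println!("{}", new_tree.root_node().to_sexp());
    Ok(())
}
```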

error.rs (new file)

@@ -0,0 +1,64 @@
use std::fmt::Display;

use tower_lsp::jsonrpc::Error as LspError;
use tracing::error;

pub fn internal_error<E: Display>(err: E) -> LspError {
let err_msg = err.to_string();
error!(err_msg);
LspError {
code: tower_lsp::jsonrpc::ErrorCode::InternalError,
message: err_msg.into(),
data: None,
}
}

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("http error: {0}")]
Http(#[from] reqwest::Error),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("inference api error: {0}")]
InferenceApi(crate::APIError),
#[error("You are attempting to parse a result in the API inference format when using the `tgi` backend")]
InvalidBackend,
#[error("invalid header value: {0}")]
InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
#[error("invalid repository id")]
InvalidRepositoryId,
#[error("invalid tokenizer path")]
InvalidTokenizerPath,
#[error("ollama error: {0}")]
Ollama(crate::APIError),
#[error("openai error: {0}")]
OpenAI(crate::backend::OpenAIError),
#[error("index out of bounds: {0}")]
OutOfBoundIndexing(usize),
#[error("line out of bounds: {0}")]
OutOfBoundLine(usize),
#[error("slice out of bounds: {0}..{1}")]
OutOfBoundSlice(usize, usize),
#[error("rope error: {0}")]
Rope(#[from] ropey::Error),
#[error("serde json error: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("tgi error: {0}")]
Tgi(crate::APIError),
#[error("tree-sitter language error: {0}")]
TreeSitterLanguage(#[from] tree_sitter::LanguageError),
#[error("tokenizer error: {0}")]
Tokenizer(#[from] tokenizers::Error),
#[error("tokio join error: {0}")]
TokioJoin(#[from] tokio::task::JoinError),
#[error("unknown backend: {0}")]
UnknownBackend(String),
}

pub type Result<T> = std::result::Result<T, Error>;

impl From<Error> for LspError {
fn from(err: Error) -> Self {
internal_error(err)
}
}
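
The last impl is what keeps `tower-lsp` happy: its trait methods still must return `jsonrpc::Result`, so a single `?` at the boundary converts a crate `Error` into an `LspError`. A self-contained sketch of that boundary (simplified enum, hypothetical handler):

```rust
use tower_lsp::jsonrpc::{Error as LspError, ErrorCode, Result as LspResult};

#[derive(Debug, thiserror::Error)]
enum Error {
    #[error("serde json error: {0}")]
    SerdeJson(#[from] serde_json::Error),
}

impl From<Error> for LspError {
    fn from(err: Error) -> Self {
        LspError {
            code: ErrorCode::InternalError,
            message: err.to_string().into(),
            data: None,
        }
    }
}

// Internals stay on the crate error type...
fn parse_n(raw: &str) -> Result<i64, Error> {
    let v: serde_json::Value = serde_json::from_str(raw)?;
    Ok(v["n"].as_i64().unwrap_or_default())
}

// ...and the conversion happens once, at the edge (the real handlers are
// async trait methods; a plain fn keeps the sketch runnable).
fn handler() -> LspResult<i64> {
    Ok(parse_n(r#"{"n": 42}"#)?)
}

fn main() {
    assert_eq!(handler().unwrap(), 42);
}
```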

main.rs

@@ -1,5 +1,3 @@
-use backend::{build_body, build_headers, parse_generations, Backend};
-use document::Document;
 use ropey::Rope;
 use serde::{Deserialize, Deserializer, Serialize};
 use serde_json::{Map, Value};
@@ -11,7 +9,7 @@ use std::time::{Duration, Instant};
 use tokenizers::Tokenizer;
 use tokio::io::AsyncWriteExt;
 use tokio::sync::RwLock;
-use tower_lsp::jsonrpc::{Error, Result};
+use tower_lsp::jsonrpc::Result as LspResult;
 use tower_lsp::lsp_types::*;
 use tower_lsp::{Client, LanguageServer, LspService, Server};
 use tracing::{debug, error, info, info_span, warn, Instrument};
@@ -19,8 +17,13 @@ use tracing_appender::rolling;
 use tracing_subscriber::EnvFilter;
 use uuid::Uuid;
 
+use crate::backend::{build_body, build_headers, parse_generations, Backend};
+use crate::document::Document;
+use crate::error::{internal_error, Error, Result};
+
 mod backend;
 mod document;
+mod error;
 mod language_id;
 
 const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
@@ -29,10 +32,10 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION");
 const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
 
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
-    Ok(rope.try_line_to_char(row).map_err(internal_error)?
+    Ok(rope.try_line_to_char(row)?
         + col.min(
             rope.get_line(row.min(rope.len_lines().saturating_sub(1)))
-                .ok_or_else(|| internal_error(format!("failed to find line at {row}")))?
+                .ok_or(Error::OutOfBoundLine(row))?
                 .len_chars()
                 .saturating_sub(1),
         ))
@@ -80,16 +83,12 @@ fn should_complete(document: &Document, position: Position) -> Result<CompletionType> {
     let mut end_offset = get_position_idx(&document.text, end.row, end.column)? - 1;
     let start_char = document
         .text
-        .get_char(start_offset.min(document.text.len_chars() - 1))
-        .ok_or_else(|| {
-            internal_error(format!("failed to find start char at {start_offset}"))
-        })?;
+        .get_char(start_offset.min(document.text.len_chars().saturating_sub(1)))
+        .ok_or(Error::OutOfBoundIndexing(start_offset))?;
     let end_char = document
         .text
-        .get_char(end_offset.min(document.text.len_chars() - 1))
-        .ok_or_else(|| {
-            internal_error(format!("failed to find end char at {end_offset}"))
-        })?;
+        .get_char(end_offset.min(document.text.len_chars().saturating_sub(1)))
+        .ok_or(Error::OutOfBoundIndexing(end_offset))?;
     if !start_char.is_whitespace() {
         start_offset += 1;
     }
@@ -102,20 +101,13 @@ fn should_complete(document: &Document, position: Position) -> Result<CompletionType> {
             let slice = document
                 .text
                 .get_slice(start_offset..end_offset)
-                .ok_or_else(|| {
-                    internal_error(format!(
-                        "failed to find slice at {start_offset}..{end_offset}"
-                    ))
-                })?;
+                .ok_or(Error::OutOfBoundSlice(start_offset, end_offset))?;
             if slice.to_string().trim().is_empty() {
                 return Ok(CompletionType::MultiLine);
             }
         }
     }
-    let start_idx = document
-        .text
-        .try_line_to_char(row)
-        .map_err(internal_error)?;
+    let start_idx = document.text.try_line_to_char(row)?;
     // XXX: We treat the end of a document as a newline
     let next_char = document.text.get_char(start_idx + column).unwrap_or('\n');
     if next_char.is_whitespace() {
@@ -200,6 +192,12 @@ pub struct APIError {
     error: String,
 }
 
+impl std::error::Error for APIError {
+    fn description(&self) -> &str {
+        &self.error
+    }
+}
+
 impl Display for APIError {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{}", self.error)
@@ -297,16 +295,6 @@ struct CompletionResult {
     completions: Vec<Completion>,
 }
 
-pub fn internal_error<E: Display>(err: E) -> Error {
-    let err_msg = err.to_string();
-    error!(err_msg);
-    Error {
-        code: tower_lsp::jsonrpc::ErrorCode::InternalError,
-        message: err_msg.into(),
-        data: None,
-    }
-}
-
 fn build_prompt(
     pos: Position,
     text: &Rope,
@@ -335,10 +323,7 @@ fn build_prompt(
         if let Some(before_line) = before_line {
             let before_line = before_line.to_string();
             let tokens = if let Some(tokenizer) = tokenizer.clone() {
-                tokenizer
-                    .encode(before_line.clone(), false)
-                    .map_err(internal_error)?
-                    .len()
+                tokenizer.encode(before_line.clone(), false)?.len()
             } else {
                 before_line.len()
             };
@@ -351,10 +336,7 @@ fn build_prompt(
         if let Some(after_line) = after_line {
             let after_line = after_line.to_string();
             let tokens = if let Some(tokenizer) = tokenizer.clone() {
-                tokenizer
-                    .encode(after_line.clone(), false)
-                    .map_err(internal_error)?
-                    .len()
+                tokenizer.encode(after_line.clone(), false)?.len()
             } else {
                 after_line.len()
            };
@@ -390,10 +372,7 @@ fn build_prompt(
             }
             let line = line.to_string();
             let tokens = if let Some(tokenizer) = tokenizer.clone() {
-                tokenizer
-                    .encode(line.clone(), false)
-                    .map_err(internal_error)?
-                    .len()
+                tokenizer.encode(line.clone(), false)?.len()
             } else {
                 line.len()
             };
@@ -424,22 +403,18 @@ async fn request_completion(
         .json(&json)
         .headers(headers)
         .send()
-        .await
-        .map_err(internal_error)?;
+        .await?;
 
     let model = &params.model;
-    let generations = parse_generations(
-        &params.backend,
-        res.text().await.map_err(internal_error)?.as_str(),
-    );
+    let generations = parse_generations(&params.backend, res.text().await?.as_str())?;
     let time = t.elapsed().as_millis();
     info!(
         model,
         compute_generations_ms = time,
-        generations = serde_json::to_string(&generations).map_err(internal_error)?,
+        generations = serde_json::to_string(&generations)?,
         "{model} computed generations in {time} ms"
     );
 
-    generations
+    Ok(generations)
 }
 
 fn format_generations(
@@ -482,22 +457,19 @@ async fn download_tokenizer_file(
     if to.as_ref().exists() {
         return Ok(());
     }
-    tokio::fs::create_dir_all(
-        to.as_ref()
-            .parent()
-            .ok_or_else(|| internal_error("invalid tokenizer path"))?,
-    )
-    .await
-    .map_err(internal_error)?;
+    tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
     let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
         .open(to)
-        .await
-        .map_err(internal_error)?;
+        .await?;
     let http_client = http_client.clone();
     let url = url.to_owned();
+    // TODO:
+    // - create oneshot channel to send result of tokenizer download to display error message
+    // to user?
+    // - retry logic?
     tokio::spawn(async move {
         let res = match http_client.get(url).headers(headers).send().await {
             Ok(res) => res,
@@ -527,8 +499,7 @@ async fn download_tokenizer_file(
             }
         };
     })
-    .await
-    .map_err(internal_error)?;
+    .await?;
 
     Ok(())
 }
@@ -556,7 +527,14 @@ async fn get_tokenizer(
             repository,
             api_token,
         } => {
-            let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
+            let (org, repo) = repository
+                .split_once('/')
+                .ok_or(Error::InvalidRepositoryId)?;
+            let path = cache_dir
+                .as_ref()
+                .join(org)
+                .join(repo)
+                .join("tokenizer.json");
             let url =
                 format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
             download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?;
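
The `split_once('/')` change above does two things: it rejects malformed repository ids early (`Error::InvalidRepositoryId`) and makes `org` and `repo` separate path components of the cache layout. A small illustration (hypothetical helper; `bigcode/starcoder` is just an example id):

```rust
use std::path::PathBuf;

fn tokenizer_path(cache_dir: &str, repository: &str) -> Option<PathBuf> {
    // Ids without a '/' are rejected instead of silently producing odd paths.
    let (org, repo) = repository.split_once('/')?;
    Some(
        PathBuf::from(cache_dir)
            .join(org)
            .join(repo)
            .join("tokenizer.json"),
    )
}

fn main() {
    assert_eq!(
        tokenizer_path("/cache", "bigcode/starcoder").unwrap(),
        PathBuf::from("/cache/bigcode/starcoder/tokenizer.json")
    );
    assert!(tokenizer_path("/cache", "no-slash").is_none());
}
```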
@@ -597,7 +575,7 @@ fn build_url(model: &str) -> String {
 }
 
 impl LlmService {
-    async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
+    async fn get_completions(&self, params: CompletionParams) -> LspResult<CompletionResult> {
         let request_id = Uuid::new_v4();
         let span = info_span!("completion_request", %request_id);
         async move {
@ -669,7 +647,7 @@ impl LlmService {
}.instrument(span).await }.instrument(span).await
} }
async fn accept_completion(&self, accepted: AcceptedCompletion) -> Result<()> { async fn accept_completion(&self, accepted: AcceptedCompletion) -> LspResult<()> {
info!( info!(
request_id = %accepted.request_id, request_id = %accepted.request_id,
accepted_position = accepted.accepted_completion, accepted_position = accepted.accepted_completion,
@@ -679,7 +657,7 @@ impl LlmService {
         Ok(())
     }
 
-    async fn reject_completion(&self, rejected: RejectedCompletion) -> Result<()> {
+    async fn reject_completion(&self, rejected: RejectedCompletion) -> LspResult<()> {
         info!(
             request_id = %rejected.request_id,
             shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?,
@@ -691,7 +669,7 @@ impl LlmService {
 
 #[tower_lsp::async_trait]
 impl LanguageServer for LlmService {
-    async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
+    async fn initialize(&self, params: InitializeParams) -> LspResult<InitializeResult> {
         *self.workspace_folders.write().await = params.workspace_folders;
         Ok(InitializeResult {
             server_info: Some(ServerInfo {
@@ -743,7 +721,7 @@ impl LanguageServer for LlmService {
         }
 
         // ignore the output scheme
-        if uri.starts_with("output:") {
+        if params.text_document.uri.scheme() == "output" {
             return;
         }
@@ -786,7 +764,7 @@ impl LanguageServer for LlmService {
         info!("{uri} closed");
     }
 
-    async fn shutdown(&self) -> Result<()> {
+    async fn shutdown(&self) -> LspResult<()> {
         debug!("shutdown");
         Ok(())
     }
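
One behavioral fix rides along with the rename to `LspResult`: the guard that skips editor output windows now compares the parsed URI scheme instead of a string prefix. A quick illustration (hypothetical helper; `Url` is the `url` crate type that `lsp_types` re-exports):

```rust
use url::Url;

fn is_output_scheme(uri: &Url) -> bool {
    uri.scheme() == "output"
}

fn main() {
    let output_pane = Url::parse("output:extension-output-pane").unwrap();
    let source_file = Url::parse("file:///home/user/project/main.rs").unwrap();
    assert!(is_output_scheme(&output_pane));
    assert!(!is_output_scheme(&source_file));
}
```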