refactor: error handling (#71)
This commit is contained in:
parent
a9831d5720
commit
455b085c96
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -979,6 +979,7 @@ dependencies = [
|
||||||
"ropey",
|
"ropey",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"thiserror",
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower-lsp",
|
"tower-lsp",
|
||||||
|
|
|
@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = [
|
||||||
] }
|
] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
|
thiserror = "1"
|
||||||
tokenizers = { version = "0.14", default-features = false, features = ["onig"] }
|
tokenizers = { version = "0.14", default-features = false, features = ["onig"] }
|
||||||
tokio = { version = "1", features = [
|
tokio = { version = "1", features = [
|
||||||
"fs",
|
"fs",
|
||||||
|
|
|
@ -1,59 +1,44 @@
|
||||||
use super::{
|
use super::{APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION};
|
||||||
internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
|
|
||||||
};
|
|
||||||
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
|
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{Map, Value};
|
use serde_json::{Map, Value};
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use tower_lsp::jsonrpc;
|
|
||||||
|
|
||||||
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
|
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
|
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
|
||||||
headers.insert(
|
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
|
||||||
USER_AGENT,
|
|
||||||
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(api_token) = api_token {
|
if let Some(api_token) = api_token {
|
||||||
headers.insert(
|
headers.insert(
|
||||||
AUTHORIZATION,
|
AUTHORIZATION,
|
||||||
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
|
HeaderValue::from_str(&format!("Bearer {api_token}"))?,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(headers)
|
Ok(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>> {
|
||||||
let generations =
|
match serde_json::from_str(text)? {
|
||||||
match serde_json::from_str(text).map_err(internal_error)? {
|
APIResponse::Generation(gen) => Ok(vec![gen]),
|
||||||
APIResponse::Generation(gen) => vec![gen],
|
APIResponse::Generations(_) => Err(Error::InvalidBackend),
|
||||||
APIResponse::Generations(_) => {
|
APIResponse::Error(err) => Err(Error::Tgi(err)),
|
||||||
return Err(internal_error(
|
|
||||||
"You are attempting to parse a result in the API inference format when using the `tgi` backend",
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
APIResponse::Error(err) => return Err(internal_error(err)),
|
|
||||||
};
|
|
||||||
Ok(generations)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
|
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
|
||||||
build_tgi_headers(api_token, ide)
|
build_tgi_headers(api_token, ide)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
|
||||||
let generations = match serde_json::from_str(text).map_err(internal_error)? {
|
match serde_json::from_str(text)? {
|
||||||
APIResponse::Generation(gen) => vec![gen],
|
APIResponse::Generation(gen) => Ok(vec![gen]),
|
||||||
APIResponse::Generations(gens) => gens,
|
APIResponse::Generations(gens) => Ok(gens),
|
||||||
APIResponse::Error(err) => return Err(internal_error(err)),
|
APIResponse::Error(err) => Err(Error::InferenceApi(err)),
|
||||||
};
|
|
||||||
Ok(generations)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
|
|
||||||
Ok(HeaderMap::new())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
@ -76,16 +61,15 @@ enum OllamaAPIResponse {
|
||||||
Error(APIError),
|
Error(APIError),
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
fn build_ollama_headers() -> HeaderMap {
|
||||||
let generations = match serde_json::from_str(text).map_err(internal_error)? {
|
HeaderMap::new()
|
||||||
OllamaAPIResponse::Generation(gen) => vec![gen.into()],
|
|
||||||
OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
|
|
||||||
};
|
|
||||||
Ok(generations)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
|
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>> {
|
||||||
build_api_headers(api_token, ide)
|
match serde_json::from_str(text)? {
|
||||||
|
OllamaAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
|
||||||
|
OllamaAPIResponse::Error(err) => Err(Error::Ollama(err)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
@ -130,7 +114,7 @@ struct OpenAIErrorDetail {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct OpenAIError {
|
pub struct OpenAIError {
|
||||||
detail: Vec<OpenAIErrorDetail>,
|
detail: Vec<OpenAIErrorDetail>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,13 +137,16 @@ enum OpenAIAPIResponse {
|
||||||
Error(OpenAIError),
|
Error(OpenAIError),
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
|
||||||
match serde_json::from_str(text).map_err(internal_error) {
|
build_api_headers(api_token, ide)
|
||||||
Ok(OpenAIAPIResponse::Generation(completion)) => {
|
}
|
||||||
|
|
||||||
|
fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
|
||||||
|
match serde_json::from_str(text)? {
|
||||||
|
OpenAIAPIResponse::Generation(completion) => {
|
||||||
Ok(completion.choices.into_iter().map(|x| x.into()).collect())
|
Ok(completion.choices.into_iter().map(|x| x.into()).collect())
|
||||||
}
|
}
|
||||||
Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
|
OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)),
|
||||||
Err(err) => Err(internal_error(err)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,20 +173,16 @@ pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Valu
|
||||||
body
|
body
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build_headers(
|
pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
|
||||||
backend: &Backend,
|
|
||||||
api_token: Option<&String>,
|
|
||||||
ide: Ide,
|
|
||||||
) -> Result<HeaderMap, jsonrpc::Error> {
|
|
||||||
match backend {
|
match backend {
|
||||||
Backend::HuggingFace => build_api_headers(api_token, ide),
|
Backend::HuggingFace => build_api_headers(api_token, ide),
|
||||||
Backend::Ollama => build_ollama_headers(),
|
Backend::Ollama => Ok(build_ollama_headers()),
|
||||||
Backend::OpenAi => build_openai_headers(api_token, ide),
|
Backend::OpenAi => build_openai_headers(api_token, ide),
|
||||||
Backend::Tgi => build_tgi_headers(api_token, ide),
|
Backend::Tgi => build_tgi_headers(api_token, ide),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result<Vec<Generation>> {
|
pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
|
||||||
match backend {
|
match backend {
|
||||||
Backend::HuggingFace => parse_api_text(text),
|
Backend::HuggingFace => parse_api_text(text),
|
||||||
Backend::Ollama => parse_ollama_text(text),
|
Backend::Ollama => parse_ollama_text(text),
|
||||||
|
|
|
@ -1,172 +1,126 @@
|
||||||
use ropey::Rope;
|
use ropey::Rope;
|
||||||
use tower_lsp::jsonrpc::Result;
|
|
||||||
use tower_lsp::lsp_types::Range;
|
use tower_lsp::lsp_types::Range;
|
||||||
use tree_sitter::{InputEdit, Parser, Point, Tree};
|
use tree_sitter::{InputEdit, Parser, Point, Tree};
|
||||||
|
|
||||||
|
use crate::error::Result;
|
||||||
|
use crate::get_position_idx;
|
||||||
use crate::language_id::LanguageId;
|
use crate::language_id::LanguageId;
|
||||||
use crate::{get_position_idx, internal_error};
|
|
||||||
|
|
||||||
fn get_parser(language_id: LanguageId) -> Result<Parser> {
|
fn get_parser(language_id: LanguageId) -> Result<Parser> {
|
||||||
match language_id {
|
match language_id {
|
||||||
LanguageId::Bash => {
|
LanguageId::Bash => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_bash::language())?;
|
||||||
.set_language(tree_sitter_bash::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::C => {
|
LanguageId::C => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_c::language())?;
|
||||||
.set_language(tree_sitter_c::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Cpp => {
|
LanguageId::Cpp => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_cpp::language())?;
|
||||||
.set_language(tree_sitter_cpp::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::CSharp => {
|
LanguageId::CSharp => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_c_sharp::language())?;
|
||||||
.set_language(tree_sitter_c_sharp::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Elixir => {
|
LanguageId::Elixir => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_elixir::language())?;
|
||||||
.set_language(tree_sitter_elixir::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Erlang => {
|
LanguageId::Erlang => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_erlang::language())?;
|
||||||
.set_language(tree_sitter_erlang::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Go => {
|
LanguageId::Go => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_go::language())?;
|
||||||
.set_language(tree_sitter_go::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Html => {
|
LanguageId::Html => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_html::language())?;
|
||||||
.set_language(tree_sitter_html::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Java => {
|
LanguageId::Java => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_java::language())?;
|
||||||
.set_language(tree_sitter_java::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::JavaScript | LanguageId::JavaScriptReact => {
|
LanguageId::JavaScript | LanguageId::JavaScriptReact => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_javascript::language())?;
|
||||||
.set_language(tree_sitter_javascript::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Json => {
|
LanguageId::Json => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_json::language())?;
|
||||||
.set_language(tree_sitter_json::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Kotlin => {
|
LanguageId::Kotlin => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_kotlin::language())?;
|
||||||
.set_language(tree_sitter_kotlin::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Lua => {
|
LanguageId::Lua => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_lua::language())?;
|
||||||
.set_language(tree_sitter_lua::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Markdown => {
|
LanguageId::Markdown => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_md::language())?;
|
||||||
.set_language(tree_sitter_md::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::ObjectiveC => {
|
LanguageId::ObjectiveC => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_objc::language())?;
|
||||||
.set_language(tree_sitter_objc::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Python => {
|
LanguageId::Python => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_python::language())?;
|
||||||
.set_language(tree_sitter_python::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::R => {
|
LanguageId::R => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_r::language())?;
|
||||||
.set_language(tree_sitter_r::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Ruby => {
|
LanguageId::Ruby => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_ruby::language())?;
|
||||||
.set_language(tree_sitter_ruby::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Rust => {
|
LanguageId::Rust => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_rust::language())?;
|
||||||
.set_language(tree_sitter_rust::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Scala => {
|
LanguageId::Scala => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_scala::language())?;
|
||||||
.set_language(tree_sitter_scala::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Swift => {
|
LanguageId::Swift => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_swift::language())?;
|
||||||
.set_language(tree_sitter_swift::language())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::TypeScript => {
|
LanguageId::TypeScript => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_typescript::language_typescript())?;
|
||||||
.set_language(tree_sitter_typescript::language_typescript())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::TypeScriptReact => {
|
LanguageId::TypeScriptReact => {
|
||||||
let mut parser = Parser::new();
|
let mut parser = Parser::new();
|
||||||
parser
|
parser.set_language(tree_sitter_typescript::language_tsx())?;
|
||||||
.set_language(tree_sitter_typescript::language_tsx())
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(parser)
|
Ok(parser)
|
||||||
}
|
}
|
||||||
LanguageId::Unknown => Ok(Parser::new()),
|
LanguageId::Unknown => Ok(Parser::new()),
|
||||||
|
@ -200,19 +154,13 @@ impl Document {
|
||||||
range.start.line as usize,
|
range.start.line as usize,
|
||||||
range.start.character as usize,
|
range.start.character as usize,
|
||||||
)?;
|
)?;
|
||||||
let start_byte = self
|
let start_byte = self.text.try_char_to_byte(start_idx)?;
|
||||||
.text
|
|
||||||
.try_char_to_byte(start_idx)
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
let old_end_idx = get_position_idx(
|
let old_end_idx = get_position_idx(
|
||||||
&self.text,
|
&self.text,
|
||||||
range.end.line as usize,
|
range.end.line as usize,
|
||||||
range.end.character as usize,
|
range.end.character as usize,
|
||||||
)?;
|
)?;
|
||||||
let old_end_byte = self
|
let old_end_byte = self.text.try_char_to_byte(old_end_idx)?;
|
||||||
.text
|
|
||||||
.try_char_to_byte(old_end_idx)
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
let start_position = Point {
|
let start_position = Point {
|
||||||
row: range.start.line as usize,
|
row: range.start.line as usize,
|
||||||
column: range.start.character as usize,
|
column: range.start.character as usize,
|
||||||
|
@ -224,7 +172,7 @@ impl Document {
|
||||||
let (new_end_idx, new_end_position) = if range.start == range.end {
|
let (new_end_idx, new_end_position) = if range.start == range.end {
|
||||||
let row = range.start.line as usize;
|
let row = range.start.line as usize;
|
||||||
let column = range.start.character as usize;
|
let column = range.start.character as usize;
|
||||||
let idx = self.text.try_line_to_char(row).map_err(internal_error)? + column;
|
let idx = self.text.try_line_to_char(row)? + column;
|
||||||
let rope = Rope::from_str(text);
|
let rope = Rope::from_str(text);
|
||||||
let text_len = rope.len_chars();
|
let text_len = rope.len_chars();
|
||||||
let end_idx = idx + text_len;
|
let end_idx = idx + text_len;
|
||||||
|
@ -237,11 +185,10 @@ impl Document {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
let removal_idx = self.text.try_line_to_char(range.end.line as usize).map_err(internal_error)? + (range.end.character as usize);
|
let removal_idx = self.text.try_line_to_char(range.end.line as usize)?
|
||||||
|
+ (range.end.character as usize);
|
||||||
let slice_size = removal_idx - start_idx;
|
let slice_size = removal_idx - start_idx;
|
||||||
self.text
|
self.text.try_remove(start_idx..removal_idx)?;
|
||||||
.try_remove(start_idx..removal_idx)
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
self.text.insert(start_idx, text);
|
self.text.insert(start_idx, text);
|
||||||
let rope = Rope::from_str(text);
|
let rope = Rope::from_str(text);
|
||||||
let text_len = rope.len_chars();
|
let text_len = rope.len_chars();
|
||||||
|
@ -251,11 +198,8 @@ impl Document {
|
||||||
} else {
|
} else {
|
||||||
removal_idx + character_difference as usize
|
removal_idx + character_difference as usize
|
||||||
};
|
};
|
||||||
let row = self
|
let row = self.text.try_char_to_line(new_end_idx)?;
|
||||||
.text
|
let line_start = self.text.try_line_to_char(row)?;
|
||||||
.try_char_to_line(new_end_idx)
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
let line_start = self.text.try_line_to_char(row).map_err(internal_error)?;
|
|
||||||
let column = new_end_idx - line_start;
|
let column = new_end_idx - line_start;
|
||||||
(new_end_idx, Point { row, column })
|
(new_end_idx, Point { row, column })
|
||||||
};
|
};
|
||||||
|
@ -263,10 +207,7 @@ impl Document {
|
||||||
let edit = InputEdit {
|
let edit = InputEdit {
|
||||||
start_byte,
|
start_byte,
|
||||||
old_end_byte,
|
old_end_byte,
|
||||||
new_end_byte: self
|
new_end_byte: self.text.try_char_to_byte(new_end_idx)?,
|
||||||
.text
|
|
||||||
.try_char_to_byte(new_end_idx)
|
|
||||||
.map_err(internal_error)?,
|
|
||||||
start_position,
|
start_position,
|
||||||
old_end_position,
|
old_end_position,
|
||||||
new_end_position,
|
new_end_position,
|
||||||
|
|
64
crates/llm-ls/src/error.rs
Normal file
64
crates/llm-ls/src/error.rs
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
use tower_lsp::jsonrpc::Error as LspError;
|
||||||
|
use tracing::error;
|
||||||
|
|
||||||
|
pub fn internal_error<E: Display>(err: E) -> LspError {
|
||||||
|
let err_msg = err.to_string();
|
||||||
|
error!(err_msg);
|
||||||
|
LspError {
|
||||||
|
code: tower_lsp::jsonrpc::ErrorCode::InternalError,
|
||||||
|
message: err_msg.into(),
|
||||||
|
data: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("http error: {0}")]
|
||||||
|
Http(#[from] reqwest::Error),
|
||||||
|
#[error("io error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
#[error("inference api error: {0}")]
|
||||||
|
InferenceApi(crate::APIError),
|
||||||
|
#[error("You are attempting to parse a result in the API inference format when using the `tgi` backend")]
|
||||||
|
InvalidBackend,
|
||||||
|
#[error("invalid header value: {0}")]
|
||||||
|
InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
|
||||||
|
#[error("invalid repository id")]
|
||||||
|
InvalidRepositoryId,
|
||||||
|
#[error("invalid tokenizer path")]
|
||||||
|
InvalidTokenizerPath,
|
||||||
|
#[error("ollama error: {0}")]
|
||||||
|
Ollama(crate::APIError),
|
||||||
|
#[error("openai error: {0}")]
|
||||||
|
OpenAI(crate::backend::OpenAIError),
|
||||||
|
#[error("index out of bounds: {0}")]
|
||||||
|
OutOfBoundIndexing(usize),
|
||||||
|
#[error("line out of bounds: {0}")]
|
||||||
|
OutOfBoundLine(usize),
|
||||||
|
#[error("slice out of bounds: {0}..{1}")]
|
||||||
|
OutOfBoundSlice(usize, usize),
|
||||||
|
#[error("rope error: {0}")]
|
||||||
|
Rope(#[from] ropey::Error),
|
||||||
|
#[error("serde json error: {0}")]
|
||||||
|
SerdeJson(#[from] serde_json::Error),
|
||||||
|
#[error("tgi error: {0}")]
|
||||||
|
Tgi(crate::APIError),
|
||||||
|
#[error("tree-sitter language error: {0}")]
|
||||||
|
TreeSitterLanguage(#[from] tree_sitter::LanguageError),
|
||||||
|
#[error("tokenizer error: {0}")]
|
||||||
|
Tokenizer(#[from] tokenizers::Error),
|
||||||
|
#[error("tokio join error: {0}")]
|
||||||
|
TokioJoin(#[from] tokio::task::JoinError),
|
||||||
|
#[error("unknown backend: {0}")]
|
||||||
|
UnknownBackend(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
impl From<Error> for LspError {
|
||||||
|
fn from(err: Error) -> Self {
|
||||||
|
internal_error(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,3 @@
|
||||||
use backend::{build_body, build_headers, parse_generations, Backend};
|
|
||||||
use document::Document;
|
|
||||||
use ropey::Rope;
|
use ropey::Rope;
|
||||||
use serde::{Deserialize, Deserializer, Serialize};
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
use serde_json::{Map, Value};
|
use serde_json::{Map, Value};
|
||||||
|
@ -11,7 +9,7 @@ use std::time::{Duration, Instant};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tower_lsp::jsonrpc::{Error, Result};
|
use tower_lsp::jsonrpc::Result as LspResult;
|
||||||
use tower_lsp::lsp_types::*;
|
use tower_lsp::lsp_types::*;
|
||||||
use tower_lsp::{Client, LanguageServer, LspService, Server};
|
use tower_lsp::{Client, LanguageServer, LspService, Server};
|
||||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||||
|
@ -19,8 +17,13 @@ use tracing_appender::rolling;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::backend::{build_body, build_headers, parse_generations, Backend};
|
||||||
|
use crate::document::Document;
|
||||||
|
use crate::error::{internal_error, Error, Result};
|
||||||
|
|
||||||
mod backend;
|
mod backend;
|
||||||
mod document;
|
mod document;
|
||||||
|
mod error;
|
||||||
mod language_id;
|
mod language_id;
|
||||||
|
|
||||||
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
|
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
|
||||||
|
@ -29,10 +32,10 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
|
const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
|
||||||
|
|
||||||
fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
|
fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
|
||||||
Ok(rope.try_line_to_char(row).map_err(internal_error)?
|
Ok(rope.try_line_to_char(row)?
|
||||||
+ col.min(
|
+ col.min(
|
||||||
rope.get_line(row.min(rope.len_lines().saturating_sub(1)))
|
rope.get_line(row.min(rope.len_lines().saturating_sub(1)))
|
||||||
.ok_or_else(|| internal_error(format!("failed to find line at {row}")))?
|
.ok_or(Error::OutOfBoundLine(row))?
|
||||||
.len_chars()
|
.len_chars()
|
||||||
.saturating_sub(1),
|
.saturating_sub(1),
|
||||||
))
|
))
|
||||||
|
@ -80,16 +83,12 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
|
||||||
let mut end_offset = get_position_idx(&document.text, end.row, end.column)? - 1;
|
let mut end_offset = get_position_idx(&document.text, end.row, end.column)? - 1;
|
||||||
let start_char = document
|
let start_char = document
|
||||||
.text
|
.text
|
||||||
.get_char(start_offset.min(document.text.len_chars() - 1))
|
.get_char(start_offset.min(document.text.len_chars().saturating_sub(1)))
|
||||||
.ok_or_else(|| {
|
.ok_or(Error::OutOfBoundIndexing(start_offset))?;
|
||||||
internal_error(format!("failed to find start char at {start_offset}"))
|
|
||||||
})?;
|
|
||||||
let end_char = document
|
let end_char = document
|
||||||
.text
|
.text
|
||||||
.get_char(end_offset.min(document.text.len_chars() - 1))
|
.get_char(end_offset.min(document.text.len_chars().saturating_sub(1)))
|
||||||
.ok_or_else(|| {
|
.ok_or(Error::OutOfBoundIndexing(end_offset))?;
|
||||||
internal_error(format!("failed to find end char at {end_offset}"))
|
|
||||||
})?;
|
|
||||||
if !start_char.is_whitespace() {
|
if !start_char.is_whitespace() {
|
||||||
start_offset += 1;
|
start_offset += 1;
|
||||||
}
|
}
|
||||||
|
@ -102,20 +101,13 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
|
||||||
let slice = document
|
let slice = document
|
||||||
.text
|
.text
|
||||||
.get_slice(start_offset..end_offset)
|
.get_slice(start_offset..end_offset)
|
||||||
.ok_or_else(|| {
|
.ok_or(Error::OutOfBoundSlice(start_offset, end_offset))?;
|
||||||
internal_error(format!(
|
|
||||||
"failed to find slice at {start_offset}..{end_offset}"
|
|
||||||
))
|
|
||||||
})?;
|
|
||||||
if slice.to_string().trim().is_empty() {
|
if slice.to_string().trim().is_empty() {
|
||||||
return Ok(CompletionType::MultiLine);
|
return Ok(CompletionType::MultiLine);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let start_idx = document
|
let start_idx = document.text.try_line_to_char(row)?;
|
||||||
.text
|
|
||||||
.try_line_to_char(row)
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
// XXX: We treat the end of a document as a newline
|
// XXX: We treat the end of a document as a newline
|
||||||
let next_char = document.text.get_char(start_idx + column).unwrap_or('\n');
|
let next_char = document.text.get_char(start_idx + column).unwrap_or('\n');
|
||||||
if next_char.is_whitespace() {
|
if next_char.is_whitespace() {
|
||||||
|
@ -200,6 +192,12 @@ pub struct APIError {
|
||||||
error: String,
|
error: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for APIError {
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
&self.error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Display for APIError {
|
impl Display for APIError {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "{}", self.error)
|
write!(f, "{}", self.error)
|
||||||
|
@ -297,16 +295,6 @@ struct CompletionResult {
|
||||||
completions: Vec<Completion>,
|
completions: Vec<Completion>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn internal_error<E: Display>(err: E) -> Error {
|
|
||||||
let err_msg = err.to_string();
|
|
||||||
error!(err_msg);
|
|
||||||
Error {
|
|
||||||
code: tower_lsp::jsonrpc::ErrorCode::InternalError,
|
|
||||||
message: err_msg.into(),
|
|
||||||
data: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_prompt(
|
fn build_prompt(
|
||||||
pos: Position,
|
pos: Position,
|
||||||
text: &Rope,
|
text: &Rope,
|
||||||
|
@ -335,10 +323,7 @@ fn build_prompt(
|
||||||
if let Some(before_line) = before_line {
|
if let Some(before_line) = before_line {
|
||||||
let before_line = before_line.to_string();
|
let before_line = before_line.to_string();
|
||||||
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
||||||
tokenizer
|
tokenizer.encode(before_line.clone(), false)?.len()
|
||||||
.encode(before_line.clone(), false)
|
|
||||||
.map_err(internal_error)?
|
|
||||||
.len()
|
|
||||||
} else {
|
} else {
|
||||||
before_line.len()
|
before_line.len()
|
||||||
};
|
};
|
||||||
|
@ -351,10 +336,7 @@ fn build_prompt(
|
||||||
if let Some(after_line) = after_line {
|
if let Some(after_line) = after_line {
|
||||||
let after_line = after_line.to_string();
|
let after_line = after_line.to_string();
|
||||||
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
||||||
tokenizer
|
tokenizer.encode(after_line.clone(), false)?.len()
|
||||||
.encode(after_line.clone(), false)
|
|
||||||
.map_err(internal_error)?
|
|
||||||
.len()
|
|
||||||
} else {
|
} else {
|
||||||
after_line.len()
|
after_line.len()
|
||||||
};
|
};
|
||||||
|
@ -390,10 +372,7 @@ fn build_prompt(
|
||||||
}
|
}
|
||||||
let line = line.to_string();
|
let line = line.to_string();
|
||||||
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
let tokens = if let Some(tokenizer) = tokenizer.clone() {
|
||||||
tokenizer
|
tokenizer.encode(line.clone(), false)?.len()
|
||||||
.encode(line.clone(), false)
|
|
||||||
.map_err(internal_error)?
|
|
||||||
.len()
|
|
||||||
} else {
|
} else {
|
||||||
line.len()
|
line.len()
|
||||||
};
|
};
|
||||||
|
@ -424,22 +403,18 @@ async fn request_completion(
|
||||||
.json(&json)
|
.json(&json)
|
||||||
.headers(headers)
|
.headers(headers)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await?;
|
||||||
.map_err(internal_error)?;
|
|
||||||
|
|
||||||
let model = ¶ms.model;
|
let model = ¶ms.model;
|
||||||
let generations = parse_generations(
|
let generations = parse_generations(¶ms.backend, res.text().await?.as_str())?;
|
||||||
¶ms.backend,
|
|
||||||
res.text().await.map_err(internal_error)?.as_str(),
|
|
||||||
);
|
|
||||||
let time = t.elapsed().as_millis();
|
let time = t.elapsed().as_millis();
|
||||||
info!(
|
info!(
|
||||||
model,
|
model,
|
||||||
compute_generations_ms = time,
|
compute_generations_ms = time,
|
||||||
generations = serde_json::to_string(&generations).map_err(internal_error)?,
|
generations = serde_json::to_string(&generations)?,
|
||||||
"{model} computed generations in {time} ms"
|
"{model} computed generations in {time} ms"
|
||||||
);
|
);
|
||||||
generations
|
Ok(generations)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format_generations(
|
fn format_generations(
|
||||||
|
@ -482,22 +457,19 @@ async fn download_tokenizer_file(
|
||||||
if to.as_ref().exists() {
|
if to.as_ref().exists() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
tokio::fs::create_dir_all(
|
tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
|
||||||
to.as_ref()
|
|
||||||
.parent()
|
|
||||||
.ok_or_else(|| internal_error("invalid tokenizer path"))?,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map_err(internal_error)?;
|
|
||||||
let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
|
let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
|
||||||
let mut file = tokio::fs::OpenOptions::new()
|
let mut file = tokio::fs::OpenOptions::new()
|
||||||
.write(true)
|
.write(true)
|
||||||
.create(true)
|
.create(true)
|
||||||
.open(to)
|
.open(to)
|
||||||
.await
|
.await?;
|
||||||
.map_err(internal_error)?;
|
|
||||||
let http_client = http_client.clone();
|
let http_client = http_client.clone();
|
||||||
let url = url.to_owned();
|
let url = url.to_owned();
|
||||||
|
// TODO:
|
||||||
|
// - create oneshot channel to send result of tokenizer download to display error message
|
||||||
|
// to user?
|
||||||
|
// - retry logic?
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let res = match http_client.get(url).headers(headers).send().await {
|
let res = match http_client.get(url).headers(headers).send().await {
|
||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
|
@ -527,8 +499,7 @@ async fn download_tokenizer_file(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
})
|
})
|
||||||
.await
|
.await?;
|
||||||
.map_err(internal_error)?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -556,7 +527,14 @@ async fn get_tokenizer(
|
||||||
repository,
|
repository,
|
||||||
api_token,
|
api_token,
|
||||||
} => {
|
} => {
|
||||||
let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
|
let (org, repo) = repository
|
||||||
|
.split_once('/')
|
||||||
|
.ok_or(Error::InvalidRepositoryId)?;
|
||||||
|
let path = cache_dir
|
||||||
|
.as_ref()
|
||||||
|
.join(org)
|
||||||
|
.join(repo)
|
||||||
|
.join("tokenizer.json");
|
||||||
let url =
|
let url =
|
||||||
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
|
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
|
||||||
download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?;
|
download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?;
|
||||||
|
@ -597,7 +575,7 @@ fn build_url(model: &str) -> String {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmService {
|
impl LlmService {
|
||||||
async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
|
async fn get_completions(&self, params: CompletionParams) -> LspResult<CompletionResult> {
|
||||||
let request_id = Uuid::new_v4();
|
let request_id = Uuid::new_v4();
|
||||||
let span = info_span!("completion_request", %request_id);
|
let span = info_span!("completion_request", %request_id);
|
||||||
async move {
|
async move {
|
||||||
|
@ -669,7 +647,7 @@ impl LlmService {
|
||||||
}.instrument(span).await
|
}.instrument(span).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn accept_completion(&self, accepted: AcceptedCompletion) -> Result<()> {
|
async fn accept_completion(&self, accepted: AcceptedCompletion) -> LspResult<()> {
|
||||||
info!(
|
info!(
|
||||||
request_id = %accepted.request_id,
|
request_id = %accepted.request_id,
|
||||||
accepted_position = accepted.accepted_completion,
|
accepted_position = accepted.accepted_completion,
|
||||||
|
@ -679,7 +657,7 @@ impl LlmService {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn reject_completion(&self, rejected: RejectedCompletion) -> Result<()> {
|
async fn reject_completion(&self, rejected: RejectedCompletion) -> LspResult<()> {
|
||||||
info!(
|
info!(
|
||||||
request_id = %rejected.request_id,
|
request_id = %rejected.request_id,
|
||||||
shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?,
|
shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?,
|
||||||
|
@ -691,7 +669,7 @@ impl LlmService {
|
||||||
|
|
||||||
#[tower_lsp::async_trait]
|
#[tower_lsp::async_trait]
|
||||||
impl LanguageServer for LlmService {
|
impl LanguageServer for LlmService {
|
||||||
async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
|
async fn initialize(&self, params: InitializeParams) -> LspResult<InitializeResult> {
|
||||||
*self.workspace_folders.write().await = params.workspace_folders;
|
*self.workspace_folders.write().await = params.workspace_folders;
|
||||||
Ok(InitializeResult {
|
Ok(InitializeResult {
|
||||||
server_info: Some(ServerInfo {
|
server_info: Some(ServerInfo {
|
||||||
|
@ -743,7 +721,7 @@ impl LanguageServer for LlmService {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ignore the output scheme
|
// ignore the output scheme
|
||||||
if uri.starts_with("output:") {
|
if params.text_document.uri.scheme() == "output" {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -786,7 +764,7 @@ impl LanguageServer for LlmService {
|
||||||
info!("{uri} closed");
|
info!("{uri} closed");
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shutdown(&self) -> Result<()> {
|
async fn shutdown(&self) -> LspResult<()> {
|
||||||
debug!("shutdown");
|
debug!("shutdown");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue