refactor: error handling (#71)

This commit is contained in:
Luc Georges 2024-02-07 12:34:17 +01:00 committed by GitHub
parent a9831d5720
commit 455b085c96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 186 additions and 218 deletions

1
Cargo.lock generated
View file

@ -979,6 +979,7 @@ dependencies = [
"ropey",
"serde",
"serde_json",
"thiserror",
"tokenizers",
"tokio",
"tower-lsp",

View file

@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = [
] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
tokenizers = { version = "0.14", default-features = false, features = ["onig"] }
tokio = { version = "1", features = [
"fs",

View file

@ -1,59 +1,44 @@
use super::{
internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
};
use super::{APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::fmt::Display;
use tower_lsp::jsonrpc;
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
use crate::error::{Error, Result};
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
headers.insert(
USER_AGENT,
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
);
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
if let Some(api_token) = api_token {
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
HeaderValue::from_str(&format!("Bearer {api_token}"))?,
);
}
Ok(headers)
}
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations =
match serde_json::from_str(text).map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(_) => {
return Err(internal_error(
"You are attempting to parse a result in the API inference format when using the `tgi` backend",
))
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
APIResponse::Generation(gen) => Ok(vec![gen]),
APIResponse::Generations(_) => Err(Error::InvalidBackend),
APIResponse::Error(err) => Err(Error::Tgi(err)),
}
APIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
}
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
build_tgi_headers(api_token, ide)
}
fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations = match serde_json::from_str(text).map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(gens) => gens,
APIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
}
fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
Ok(HeaderMap::new())
fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
APIResponse::Generation(gen) => Ok(vec![gen]),
APIResponse::Generations(gens) => Ok(gens),
APIResponse::Error(err) => Err(Error::InferenceApi(err)),
}
}
#[derive(Debug, Serialize, Deserialize)]
@ -76,16 +61,15 @@ enum OllamaAPIResponse {
Error(APIError),
}
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations = match serde_json::from_str(text).map_err(internal_error)? {
OllamaAPIResponse::Generation(gen) => vec![gen.into()],
OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
fn build_ollama_headers() -> HeaderMap {
HeaderMap::new()
}
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
build_api_headers(api_token, ide)
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
OllamaAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
OllamaAPIResponse::Error(err) => Err(Error::Ollama(err)),
}
}
#[derive(Debug, Deserialize)]
@ -130,7 +114,7 @@ struct OpenAIErrorDetail {
}
#[derive(Debug, Deserialize)]
struct OpenAIError {
pub struct OpenAIError {
detail: Vec<OpenAIErrorDetail>,
}
@ -153,13 +137,16 @@ enum OpenAIAPIResponse {
Error(OpenAIError),
}
fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
match serde_json::from_str(text).map_err(internal_error) {
Ok(OpenAIAPIResponse::Generation(completion)) => {
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
build_api_headers(api_token, ide)
}
fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
OpenAIAPIResponse::Generation(completion) => {
Ok(completion.choices.into_iter().map(|x| x.into()).collect())
}
Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
Err(err) => Err(internal_error(err)),
OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)),
}
}
@ -186,20 +173,16 @@ pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Valu
body
}
pub fn build_headers(
backend: &Backend,
api_token: Option<&String>,
ide: Ide,
) -> Result<HeaderMap, jsonrpc::Error> {
pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
match backend {
Backend::HuggingFace => build_api_headers(api_token, ide),
Backend::Ollama => build_ollama_headers(),
Backend::Ollama => Ok(build_ollama_headers()),
Backend::OpenAi => build_openai_headers(api_token, ide),
Backend::Tgi => build_tgi_headers(api_token, ide),
}
}
pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result<Vec<Generation>> {
pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
match backend {
Backend::HuggingFace => parse_api_text(text),
Backend::Ollama => parse_ollama_text(text),

View file

@ -1,172 +1,126 @@
use ropey::Rope;
use tower_lsp::jsonrpc::Result;
use tower_lsp::lsp_types::Range;
use tree_sitter::{InputEdit, Parser, Point, Tree};
use crate::error::Result;
use crate::get_position_idx;
use crate::language_id::LanguageId;
use crate::{get_position_idx, internal_error};
fn get_parser(language_id: LanguageId) -> Result<Parser> {
match language_id {
LanguageId::Bash => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_bash::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_bash::language())?;
Ok(parser)
}
LanguageId::C => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_c::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_c::language())?;
Ok(parser)
}
LanguageId::Cpp => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_cpp::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_cpp::language())?;
Ok(parser)
}
LanguageId::CSharp => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_c_sharp::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_c_sharp::language())?;
Ok(parser)
}
LanguageId::Elixir => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_elixir::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_elixir::language())?;
Ok(parser)
}
LanguageId::Erlang => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_erlang::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_erlang::language())?;
Ok(parser)
}
LanguageId::Go => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_go::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_go::language())?;
Ok(parser)
}
LanguageId::Html => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_html::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_html::language())?;
Ok(parser)
}
LanguageId::Java => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_java::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_java::language())?;
Ok(parser)
}
LanguageId::JavaScript | LanguageId::JavaScriptReact => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_javascript::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_javascript::language())?;
Ok(parser)
}
LanguageId::Json => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_json::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_json::language())?;
Ok(parser)
}
LanguageId::Kotlin => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_kotlin::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_kotlin::language())?;
Ok(parser)
}
LanguageId::Lua => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_lua::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_lua::language())?;
Ok(parser)
}
LanguageId::Markdown => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_md::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_md::language())?;
Ok(parser)
}
LanguageId::ObjectiveC => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_objc::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_objc::language())?;
Ok(parser)
}
LanguageId::Python => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_python::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_python::language())?;
Ok(parser)
}
LanguageId::R => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_r::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_r::language())?;
Ok(parser)
}
LanguageId::Ruby => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_ruby::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_ruby::language())?;
Ok(parser)
}
LanguageId::Rust => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_rust::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_rust::language())?;
Ok(parser)
}
LanguageId::Scala => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_scala::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_scala::language())?;
Ok(parser)
}
LanguageId::Swift => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_swift::language())
.map_err(internal_error)?;
parser.set_language(tree_sitter_swift::language())?;
Ok(parser)
}
LanguageId::TypeScript => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_typescript::language_typescript())
.map_err(internal_error)?;
parser.set_language(tree_sitter_typescript::language_typescript())?;
Ok(parser)
}
LanguageId::TypeScriptReact => {
let mut parser = Parser::new();
parser
.set_language(tree_sitter_typescript::language_tsx())
.map_err(internal_error)?;
parser.set_language(tree_sitter_typescript::language_tsx())?;
Ok(parser)
}
LanguageId::Unknown => Ok(Parser::new()),
@ -200,19 +154,13 @@ impl Document {
range.start.line as usize,
range.start.character as usize,
)?;
let start_byte = self
.text
.try_char_to_byte(start_idx)
.map_err(internal_error)?;
let start_byte = self.text.try_char_to_byte(start_idx)?;
let old_end_idx = get_position_idx(
&self.text,
range.end.line as usize,
range.end.character as usize,
)?;
let old_end_byte = self
.text
.try_char_to_byte(old_end_idx)
.map_err(internal_error)?;
let old_end_byte = self.text.try_char_to_byte(old_end_idx)?;
let start_position = Point {
row: range.start.line as usize,
column: range.start.character as usize,
@ -224,7 +172,7 @@ impl Document {
let (new_end_idx, new_end_position) = if range.start == range.end {
let row = range.start.line as usize;
let column = range.start.character as usize;
let idx = self.text.try_line_to_char(row).map_err(internal_error)? + column;
let idx = self.text.try_line_to_char(row)? + column;
let rope = Rope::from_str(text);
let text_len = rope.len_chars();
let end_idx = idx + text_len;
@ -237,11 +185,10 @@ impl Document {
},
)
} else {
let removal_idx = self.text.try_line_to_char(range.end.line as usize).map_err(internal_error)? + (range.end.character as usize);
let removal_idx = self.text.try_line_to_char(range.end.line as usize)?
+ (range.end.character as usize);
let slice_size = removal_idx - start_idx;
self.text
.try_remove(start_idx..removal_idx)
.map_err(internal_error)?;
self.text.try_remove(start_idx..removal_idx)?;
self.text.insert(start_idx, text);
let rope = Rope::from_str(text);
let text_len = rope.len_chars();
@ -251,11 +198,8 @@ impl Document {
} else {
removal_idx + character_difference as usize
};
let row = self
.text
.try_char_to_line(new_end_idx)
.map_err(internal_error)?;
let line_start = self.text.try_line_to_char(row).map_err(internal_error)?;
let row = self.text.try_char_to_line(new_end_idx)?;
let line_start = self.text.try_line_to_char(row)?;
let column = new_end_idx - line_start;
(new_end_idx, Point { row, column })
};
@ -263,10 +207,7 @@ impl Document {
let edit = InputEdit {
start_byte,
old_end_byte,
new_end_byte: self
.text
.try_char_to_byte(new_end_idx)
.map_err(internal_error)?,
new_end_byte: self.text.try_char_to_byte(new_end_idx)?,
start_position,
old_end_position,
new_end_position,

View file

@ -0,0 +1,64 @@
use std::fmt::Display;
use tower_lsp::jsonrpc::Error as LspError;
use tracing::error;
pub fn internal_error<E: Display>(err: E) -> LspError {
let err_msg = err.to_string();
error!(err_msg);
LspError {
code: tower_lsp::jsonrpc::ErrorCode::InternalError,
message: err_msg.into(),
data: None,
}
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("http error: {0}")]
Http(#[from] reqwest::Error),
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("inference api error: {0}")]
InferenceApi(crate::APIError),
#[error("You are attempting to parse a result in the API inference format when using the `tgi` backend")]
InvalidBackend,
#[error("invalid header value: {0}")]
InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
#[error("invalid repository id")]
InvalidRepositoryId,
#[error("invalid tokenizer path")]
InvalidTokenizerPath,
#[error("ollama error: {0}")]
Ollama(crate::APIError),
#[error("openai error: {0}")]
OpenAI(crate::backend::OpenAIError),
#[error("index out of bounds: {0}")]
OutOfBoundIndexing(usize),
#[error("line out of bounds: {0}")]
OutOfBoundLine(usize),
#[error("slice out of bounds: {0}..{1}")]
OutOfBoundSlice(usize, usize),
#[error("rope error: {0}")]
Rope(#[from] ropey::Error),
#[error("serde json error: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("tgi error: {0}")]
Tgi(crate::APIError),
#[error("tree-sitter language error: {0}")]
TreeSitterLanguage(#[from] tree_sitter::LanguageError),
#[error("tokenizer error: {0}")]
Tokenizer(#[from] tokenizers::Error),
#[error("tokio join error: {0}")]
TokioJoin(#[from] tokio::task::JoinError),
#[error("unknown backend: {0}")]
UnknownBackend(String),
}
pub type Result<T> = std::result::Result<T, Error>;
impl From<Error> for LspError {
fn from(err: Error) -> Self {
internal_error(err)
}
}

View file

@ -1,5 +1,3 @@
use backend::{build_body, build_headers, parse_generations, Backend};
use document::Document;
use ropey::Rope;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{Map, Value};
@ -11,7 +9,7 @@ use std::time::{Duration, Instant};
use tokenizers::Tokenizer;
use tokio::io::AsyncWriteExt;
use tokio::sync::RwLock;
use tower_lsp::jsonrpc::{Error, Result};
use tower_lsp::jsonrpc::Result as LspResult;
use tower_lsp::lsp_types::*;
use tower_lsp::{Client, LanguageServer, LspService, Server};
use tracing::{debug, error, info, info_span, warn, Instrument};
@ -19,8 +17,13 @@ use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
use uuid::Uuid;
use crate::backend::{build_body, build_headers, parse_generations, Backend};
use crate::document::Document;
use crate::error::{internal_error, Error, Result};
mod backend;
mod document;
mod error;
mod language_id;
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
@ -29,10 +32,10 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION");
const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
Ok(rope.try_line_to_char(row).map_err(internal_error)?
Ok(rope.try_line_to_char(row)?
+ col.min(
rope.get_line(row.min(rope.len_lines().saturating_sub(1)))
.ok_or_else(|| internal_error(format!("failed to find line at {row}")))?
.ok_or(Error::OutOfBoundLine(row))?
.len_chars()
.saturating_sub(1),
))
@ -80,16 +83,12 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
let mut end_offset = get_position_idx(&document.text, end.row, end.column)? - 1;
let start_char = document
.text
.get_char(start_offset.min(document.text.len_chars() - 1))
.ok_or_else(|| {
internal_error(format!("failed to find start char at {start_offset}"))
})?;
.get_char(start_offset.min(document.text.len_chars().saturating_sub(1)))
.ok_or(Error::OutOfBoundIndexing(start_offset))?;
let end_char = document
.text
.get_char(end_offset.min(document.text.len_chars() - 1))
.ok_or_else(|| {
internal_error(format!("failed to find end char at {end_offset}"))
})?;
.get_char(end_offset.min(document.text.len_chars().saturating_sub(1)))
.ok_or(Error::OutOfBoundIndexing(end_offset))?;
if !start_char.is_whitespace() {
start_offset += 1;
}
@ -102,20 +101,13 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
let slice = document
.text
.get_slice(start_offset..end_offset)
.ok_or_else(|| {
internal_error(format!(
"failed to find slice at {start_offset}..{end_offset}"
))
})?;
.ok_or(Error::OutOfBoundSlice(start_offset, end_offset))?;
if slice.to_string().trim().is_empty() {
return Ok(CompletionType::MultiLine);
}
}
}
let start_idx = document
.text
.try_line_to_char(row)
.map_err(internal_error)?;
let start_idx = document.text.try_line_to_char(row)?;
// XXX: We treat the end of a document as a newline
let next_char = document.text.get_char(start_idx + column).unwrap_or('\n');
if next_char.is_whitespace() {
@ -200,6 +192,12 @@ pub struct APIError {
error: String,
}
impl std::error::Error for APIError {
fn description(&self) -> &str {
&self.error
}
}
impl Display for APIError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.error)
@ -297,16 +295,6 @@ struct CompletionResult {
completions: Vec<Completion>,
}
pub fn internal_error<E: Display>(err: E) -> Error {
let err_msg = err.to_string();
error!(err_msg);
Error {
code: tower_lsp::jsonrpc::ErrorCode::InternalError,
message: err_msg.into(),
data: None,
}
}
fn build_prompt(
pos: Position,
text: &Rope,
@ -335,10 +323,7 @@ fn build_prompt(
if let Some(before_line) = before_line {
let before_line = before_line.to_string();
let tokens = if let Some(tokenizer) = tokenizer.clone() {
tokenizer
.encode(before_line.clone(), false)
.map_err(internal_error)?
.len()
tokenizer.encode(before_line.clone(), false)?.len()
} else {
before_line.len()
};
@ -351,10 +336,7 @@ fn build_prompt(
if let Some(after_line) = after_line {
let after_line = after_line.to_string();
let tokens = if let Some(tokenizer) = tokenizer.clone() {
tokenizer
.encode(after_line.clone(), false)
.map_err(internal_error)?
.len()
tokenizer.encode(after_line.clone(), false)?.len()
} else {
after_line.len()
};
@ -390,10 +372,7 @@ fn build_prompt(
}
let line = line.to_string();
let tokens = if let Some(tokenizer) = tokenizer.clone() {
tokenizer
.encode(line.clone(), false)
.map_err(internal_error)?
.len()
tokenizer.encode(line.clone(), false)?.len()
} else {
line.len()
};
@ -424,22 +403,18 @@ async fn request_completion(
.json(&json)
.headers(headers)
.send()
.await
.map_err(internal_error)?;
.await?;
let model = &params.model;
let generations = parse_generations(
&params.backend,
res.text().await.map_err(internal_error)?.as_str(),
);
let generations = parse_generations(&params.backend, res.text().await?.as_str())?;
let time = t.elapsed().as_millis();
info!(
model,
compute_generations_ms = time,
generations = serde_json::to_string(&generations).map_err(internal_error)?,
generations = serde_json::to_string(&generations)?,
"{model} computed generations in {time} ms"
);
generations
Ok(generations)
}
fn format_generations(
@ -482,22 +457,19 @@ async fn download_tokenizer_file(
if to.as_ref().exists() {
return Ok(());
}
tokio::fs::create_dir_all(
to.as_ref()
.parent()
.ok_or_else(|| internal_error("invalid tokenizer path"))?,
)
.await
.map_err(internal_error)?;
tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.create(true)
.open(to)
.await
.map_err(internal_error)?;
.await?;
let http_client = http_client.clone();
let url = url.to_owned();
// TODO:
// - create oneshot channel to send result of tokenizer download to display error message
// to user?
// - retry logic?
tokio::spawn(async move {
let res = match http_client.get(url).headers(headers).send().await {
Ok(res) => res,
@ -527,8 +499,7 @@ async fn download_tokenizer_file(
}
};
})
.await
.map_err(internal_error)?;
.await?;
Ok(())
}
@ -556,7 +527,14 @@ async fn get_tokenizer(
repository,
api_token,
} => {
let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
let (org, repo) = repository
.split_once('/')
.ok_or(Error::InvalidRepositoryId)?;
let path = cache_dir
.as_ref()
.join(org)
.join(repo)
.join("tokenizer.json");
let url =
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?;
@ -597,7 +575,7 @@ fn build_url(model: &str) -> String {
}
impl LlmService {
async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
async fn get_completions(&self, params: CompletionParams) -> LspResult<CompletionResult> {
let request_id = Uuid::new_v4();
let span = info_span!("completion_request", %request_id);
async move {
@ -669,7 +647,7 @@ impl LlmService {
}.instrument(span).await
}
async fn accept_completion(&self, accepted: AcceptedCompletion) -> Result<()> {
async fn accept_completion(&self, accepted: AcceptedCompletion) -> LspResult<()> {
info!(
request_id = %accepted.request_id,
accepted_position = accepted.accepted_completion,
@ -679,7 +657,7 @@ impl LlmService {
Ok(())
}
async fn reject_completion(&self, rejected: RejectedCompletion) -> Result<()> {
async fn reject_completion(&self, rejected: RejectedCompletion) -> LspResult<()> {
info!(
request_id = %rejected.request_id,
shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?,
@ -691,7 +669,7 @@ impl LlmService {
#[tower_lsp::async_trait]
impl LanguageServer for LlmService {
async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
async fn initialize(&self, params: InitializeParams) -> LspResult<InitializeResult> {
*self.workspace_folders.write().await = params.workspace_folders;
Ok(InitializeResult {
server_info: Some(ServerInfo {
@ -743,7 +721,7 @@ impl LanguageServer for LlmService {
}
// ignore the output scheme
if uri.starts_with("output:") {
if params.text_document.uri.scheme() == "output" {
return;
}
@ -786,7 +764,7 @@ impl LanguageServer for LlmService {
info!("{uri} closed");
}
async fn shutdown(&self) -> Result<()> {
async fn shutdown(&self) -> LspResult<()> {
debug!("shutdown");
Ok(())
}