fix: helix editor build crash (#73)

This commit is contained in:
Luc Georges 2024-02-08 22:43:56 +01:00 committed by GitHub
parent 54b25a8731
commit 92fc885503
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 240 additions and 237 deletions

12
Cargo.lock generated
View file

@ -427,6 +427,16 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "custom-types"
version = "0.1.0"
dependencies = [
"lsp-types",
"serde",
"serde_json",
"uuid",
]
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.14.4" version = "0.14.4"
@ -975,6 +985,7 @@ name = "llm-ls"
version = "0.4.0" version = "0.4.0"
dependencies = [ dependencies = [
"clap", "clap",
"custom-types",
"home", "home",
"reqwest", "reqwest",
"ropey", "ropey",
@ -1968,6 +1979,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"clap", "clap",
"custom-types",
"futures", "futures",
"futures-util", "futures-util",
"home", "home",

View file

@ -0,0 +1,14 @@
# Manifest for the `custom-types` crate: LSP request/response types shared
# between `llm-ls` and `testbed` so both sides agree on the wire format.
[package]
name = "custom-types"
version = "0.1.0"
# edition/license/authors are inherited from the workspace root.
edition.workspace = true
license.workspace = true
authors.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
lsp-types = "0.94"
serde = "1"
serde_json = "1"
uuid = "1"

View file

@ -0,0 +1,2 @@
//! Shared type definitions for the llm-ls workspace.

/// Parameter and result types for the custom `llm-ls/*` LSP methods.
pub mod llm_ls;
/// `lsp_types::request::Request` marker types for those methods.
pub mod request;

View file

@ -0,0 +1,118 @@
use std::{fmt::Display, path::PathBuf};
use lsp_types::TextDocumentPositionParams;
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{Map, Value};
use uuid::Uuid;
/// Parameters for the `llm-ls/acceptCompletion` request, sent by the editor
/// when the user accepts one of the completions that were shown.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AcceptCompletionParams {
    /// Id of the originating `llm-ls/getCompletions` request.
    pub request_id: Uuid,
    /// Position of the accepted completion (logged as `accepted_position`).
    pub accepted_completion: u32,
    /// Identifiers of the completions that were displayed to the user.
    pub shown_completions: Vec<u32>,
}
/// Parameters for the `llm-ls/rejectCompletion` request, sent by the editor
/// when the user dismisses the shown completions without accepting any.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RejectCompletionParams {
    /// Id of the originating `llm-ls/getCompletions` request.
    pub request_id: Uuid,
    /// Identifiers of the completions that were displayed to the user.
    pub shown_completions: Vec<u32>,
}
/// Editor/IDE issuing the completion request.
///
/// Serialized in lowercase on the wire (e.g. `"neovim"`, `"vscode"`).
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Ide {
    Neovim,
    VSCode,
    JetBrains,
    Emacs,
    Jupyter,
    Sublime,
    VisualStudio,
    /// Fallback when the client does not identify itself (see `parse_ide`).
    #[default]
    Unknown,
}
impl Display for Ide {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `&mut fmt::Formatter` implements `serde::Serializer`, so this
        // reuses the derived `Serialize` impl (including the lowercase
        // renaming) to produce the display text.
        self.serialize(f)
    }
}
/// Deserializes an optional `ide` field, falling back to [`Ide::Unknown`]
/// when the value is absent or `null`.
fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
where
    D: Deserializer<'de>,
{
    let ide: Option<Ide> = Deserialize::deserialize(d)?;
    Ok(ide.unwrap_or(Ide::Unknown))
}
/// Inference backend the server forwards completion requests to.
///
/// Serialized in lowercase on the wire (e.g. `"huggingface"`, `"ollama"`).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Backend {
    /// Hugging Face Inference API — the default backend.
    #[default]
    HuggingFace,
    Ollama,
    OpenAi,
    Tgi,
}
impl Display for Backend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same trick as `Display for Ide`: serialize through the Formatter
        // so the display text matches the wire format exactly.
        self.serialize(f)
    }
}
/// Fill-in-the-middle (FIM) prompting configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimParams {
    /// Whether FIM prompting is enabled.
    pub enabled: bool,
    /// FIM sentinel token for the prefix section (e.g. `<fim_prefix>`).
    pub prefix: String,
    /// FIM sentinel token for the middle section (e.g. `<fim_middle>`).
    pub middle: String,
    /// FIM sentinel token for the suffix section (e.g. `<fim_suffix>`).
    pub suffix: String,
}
/// Where the server should obtain a tokenizer.
///
/// `untagged`: the variant is inferred from which fields are present in the
/// JSON object rather than from an explicit tag.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum TokenizerConfig {
    /// Load a tokenizer file from a local path.
    Local {
        path: PathBuf,
    },
    /// Fetch the tokenizer from a Hugging Face repository,
    /// optionally authenticating with `api_token`.
    HuggingFace {
        repository: String,
        api_token: Option<String>,
    },
    /// Download the tokenizer file from `url` and store it at `to`.
    Download {
        url: String,
        to: PathBuf,
    },
}
/// Parameters for the `llm-ls/getCompletions` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetCompletionsParams {
    /// Document URI and cursor position, flattened into the top-level object.
    #[serde(flatten)]
    pub text_document_position: TextDocumentPositionParams,
    /// Requesting editor; absent/`null` deserializes to `Ide::Unknown`.
    #[serde(default)]
    #[serde(deserialize_with = "parse_ide")]
    pub ide: Ide,
    /// Fill-in-the-middle configuration.
    pub fim: FimParams,
    /// Backend API token, when authentication is required.
    pub api_token: Option<String>,
    /// Model identifier to request completions from.
    pub model: String,
    /// Inference backend to target.
    pub backend: Backend,
    /// Tokens the server strips from the generated text.
    pub tokens_to_clear: Vec<String>,
    /// Tokenizer source, if one should be used when building the prompt.
    pub tokenizer_config: Option<TokenizerConfig>,
    /// Context window size used when building the prompt
    /// (presumably measured in tokens — confirm against llm-ls).
    pub context_window: usize,
    /// Skip TLS certificate verification. Insecure; off by default in configs.
    pub tls_skip_verify_insecure: bool,
    /// Extra key/value pairs used as the base of the backend request body.
    pub request_body: Map<String, Value>,
}
/// A single generated completion.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Completion {
    /// Text produced by the model for insertion at the cursor.
    pub generated_text: String,
}
/// Result of the `llm-ls/getCompletions` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct GetCompletionsResult {
    /// Server-generated id; clients echo it back in accept/reject requests.
    pub request_id: Uuid,
    /// Completions produced for the request (may be empty).
    pub completions: Vec<Completion>,
}

View file

@ -0,0 +1,32 @@
use lsp_types::request::Request;
use crate::llm_ls::{
AcceptCompletionParams, GetCompletionsParams, GetCompletionsResult, RejectCompletionParams,
};
/// Marker type for the custom `llm-ls/getCompletions` LSP request.
#[derive(Debug)]
pub enum GetCompletions {}
impl Request for GetCompletions {
    type Params = GetCompletionsParams;
    type Result = GetCompletionsResult;
    const METHOD: &'static str = "llm-ls/getCompletions";
}
/// Marker type for the custom `llm-ls/acceptCompletion` LSP request.
/// The server only records telemetry, so the result type is `()`.
#[derive(Debug)]
pub enum AcceptCompletion {}
impl Request for AcceptCompletion {
    type Params = AcceptCompletionParams;
    type Result = ();
    const METHOD: &'static str = "llm-ls/acceptCompletion";
}
/// Marker type for the custom `llm-ls/rejectCompletion` LSP request.
/// The server only records telemetry, so the result type is `()`.
#[derive(Debug)]
pub enum RejectCompletion {}
impl Request for RejectCompletion {
    type Params = RejectCompletionParams;
    type Result = ();
    const METHOD: &'static str = "llm-ls/rejectCompletion";
}

View file

@ -8,6 +8,7 @@ name = "llm-ls"
[dependencies] [dependencies]
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
custom-types = { path = "../custom-types" }
home = "0.5" home = "0.5"
ropey = { version = "1.6", default-features = false, features = [ ropey = { version = "1.6", default-features = false, features = [
"simd", "simd",

View file

@ -1,4 +1,5 @@
use super::{APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION}; use super::{APIError, APIResponse, Generation, NAME, VERSION};
use custom_types::llm_ls::{Backend, GetCompletionsParams, Ide};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
@ -150,17 +151,7 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
} }
} }
#[derive(Debug, Default, Deserialize, Serialize)] pub fn build_body(prompt: String, params: &GetCompletionsParams) -> Map<String, Value> {
#[serde(rename_all = "lowercase")]
pub(crate) enum Backend {
#[default]
HuggingFace,
Ollama,
OpenAi,
Tgi,
}
pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Value> {
let mut body = params.request_body.clone(); let mut body = params.request_body.clone();
match params.backend { match params.backend {
Backend::HuggingFace | Backend::Tgi => { Backend::HuggingFace | Backend::Tgi => {

View file

@ -1,7 +1,10 @@
use clap::Parser; use clap::Parser;
use custom_types::llm_ls::{
AcceptCompletionParams, Backend, Completion, FimParams, GetCompletionsParams,
GetCompletionsResult, Ide, RejectCompletionParams, TokenizerConfig,
};
use ropey::Rope; use ropey::Rope;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@ -19,7 +22,7 @@ use tracing_appender::rolling;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use uuid::Uuid; use uuid::Uuid;
use crate::backend::{build_body, build_headers, parse_generations, Backend}; use crate::backend::{build_body, build_headers, parse_generations};
use crate::document::Document; use crate::document::Document;
use crate::error::{internal_error, Error, Result}; use crate::error::{internal_error, Error, Result};
@ -119,22 +122,6 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
} }
} }
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
enum TokenizerConfig {
Local {
path: PathBuf,
},
HuggingFace {
repository: String,
api_token: Option<String>,
},
Download {
url: String,
to: PathBuf,
},
}
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct RequestParams { pub struct RequestParams {
@ -145,14 +132,6 @@ pub struct RequestParams {
stop_tokens: Option<Vec<String>>, stop_tokens: Option<Vec<String>>,
} }
#[derive(Debug, Deserialize, Serialize)]
struct FimParams {
enabled: bool,
prefix: String,
middle: String,
suffix: String,
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct APIParams { struct APIParams {
max_new_tokens: u32, max_new_tokens: u32,
@ -225,78 +204,6 @@ struct LlmService {
unauthenticated_warn_at: Arc<RwLock<Instant>>, unauthenticated_warn_at: Arc<RwLock<Instant>>,
} }
#[derive(Debug, Deserialize, Serialize)]
struct Completion {
generated_text: String,
}
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Ide {
Neovim,
VSCode,
JetBrains,
Emacs,
Jupyter,
Sublime,
VisualStudio,
#[default]
Unknown,
}
impl Display for Ide {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.serialize(f)
}
}
fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
where
D: Deserializer<'de>,
{
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
struct AcceptedCompletion {
request_id: Uuid,
accepted_completion: u32,
shown_completions: Vec<u32>,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
struct RejectedCompletion {
request_id: Uuid,
shown_completions: Vec<u32>,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionParams {
#[serde(flatten)]
text_document_position: TextDocumentPositionParams,
#[serde(default)]
#[serde(deserialize_with = "parse_ide")]
ide: Ide,
fim: FimParams,
api_token: Option<String>,
model: String,
backend: Backend,
tokens_to_clear: Vec<String>,
tokenizer_config: Option<TokenizerConfig>,
context_window: usize,
tls_skip_verify_insecure: bool,
request_body: Map<String, Value>,
}
#[derive(Debug, Deserialize, Serialize)]
struct CompletionResult {
request_id: Uuid,
completions: Vec<Completion>,
}
fn build_prompt( fn build_prompt(
pos: Position, pos: Position,
text: &Rope, text: &Rope,
@ -394,7 +301,7 @@ fn build_prompt(
async fn request_completion( async fn request_completion(
http_client: &reqwest::Client, http_client: &reqwest::Client,
prompt: String, prompt: String,
params: &CompletionParams, params: &GetCompletionsParams,
) -> Result<Vec<Generation>> { ) -> Result<Vec<Generation>> {
let t = Instant::now(); let t = Instant::now();
@ -577,7 +484,10 @@ fn build_url(model: &str) -> String {
} }
impl LlmService { impl LlmService {
async fn get_completions(&self, params: CompletionParams) -> LspResult<CompletionResult> { async fn get_completions(
&self,
params: GetCompletionsParams,
) -> LspResult<GetCompletionsResult> {
let request_id = Uuid::new_v4(); let request_id = Uuid::new_v4();
let span = info_span!("completion_request", %request_id); let span = info_span!("completion_request", %request_id);
async move { async move {
@ -592,6 +502,7 @@ impl LlmService {
cursor_character = ?params.text_document_position.position.character, cursor_character = ?params.text_document_position.position.character,
language_id = %document.language_id, language_id = %document.language_id,
model = params.model, model = params.model,
backend = %params.backend,
ide = %params.ide, ide = %params.ide,
request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?, request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
"received completion request for {}", "received completion request for {}",
@ -611,7 +522,7 @@ impl LlmService {
let completion_type = should_complete(document, params.text_document_position.position)?; let completion_type = should_complete(document, params.text_document_position.position)?;
info!(%completion_type, "completion type: {completion_type:?}"); info!(%completion_type, "completion type: {completion_type:?}");
if completion_type == CompletionType::Empty { if completion_type == CompletionType::Empty {
return Ok(CompletionResult { request_id, completions: vec![]}); return Ok(GetCompletionsResult { request_id, completions: vec![]});
} }
let tokenizer = get_tokenizer( let tokenizer = get_tokenizer(
@ -645,11 +556,11 @@ impl LlmService {
.await?; .await?;
let completions = format_generations(result, &params.tokens_to_clear, completion_type); let completions = format_generations(result, &params.tokens_to_clear, completion_type);
Ok(CompletionResult { request_id, completions }) Ok(GetCompletionsResult { request_id, completions })
}.instrument(span).await }.instrument(span).await
} }
async fn accept_completion(&self, accepted: AcceptedCompletion) -> LspResult<()> { async fn accept_completion(&self, accepted: AcceptCompletionParams) -> LspResult<()> {
info!( info!(
request_id = %accepted.request_id, request_id = %accepted.request_id,
accepted_position = accepted.accepted_completion, accepted_position = accepted.accepted_completion,
@ -659,7 +570,7 @@ impl LlmService {
Ok(()) Ok(())
} }
async fn reject_completion(&self, rejected: RejectedCompletion) -> LspResult<()> { async fn reject_completion(&self, rejected: RejectCompletionParams) -> LspResult<()> {
info!( info!(
request_id = %rejected.request_id, request_id = %rejected.request_id,
shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?, shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?,

View file

@ -13,4 +13,3 @@ serde = "1"
serde_json = "1" serde_json = "1"
tokio = { version = "1", features = ["io-util", "process"] } tokio = { version = "1", features = ["io-util", "process"] }
tracing = "0.1" tracing = "0.1"

View file

@ -58,7 +58,7 @@ impl LspClient {
pub async fn send_request<R: lsp_types::request::Request>( pub async fn send_request<R: lsp_types::request::Request>(
&self, &self,
params: R::Params, params: R::Params,
) -> Result<Response> { ) -> Result<R::Result> {
let (sender, receiver) = oneshot::channel::<Response>(); let (sender, receiver) = oneshot::channel::<Response>();
let request = let request =
self.res_queue self.res_queue
@ -68,7 +68,8 @@ impl LspClient {
.register(R::METHOD.to_string(), params, sender); .register(R::METHOD.to_string(), params, sender);
self.send(request.into()); self.send(request.into());
Ok(receiver.await?) let (_, result) = receiver.await?.extract::<R::Result>()?;
Ok(result)
} }
async fn complete_request( async fn complete_request(

View file

@ -65,7 +65,9 @@ impl fmt::Display for ExtractError {
pub enum Error { pub enum Error {
ChannelClosed(RecvError), ChannelClosed(RecvError),
Io(io::Error), Io(io::Error),
Extract(ExtractError),
MissingBinaryPath, MissingBinaryPath,
Parse(String),
} }
impl std::error::Error for Error {} impl std::error::Error for Error {}
@ -76,6 +78,8 @@ impl fmt::Display for Error {
Error::ChannelClosed(e) => write!(f, "Channel closed: {}", e), Error::ChannelClosed(e) => write!(f, "Channel closed: {}", e),
Error::Io(e) => write!(f, "IO error: {}", e), Error::Io(e) => write!(f, "IO error: {}", e),
Error::MissingBinaryPath => write!(f, "Missing binary path"), Error::MissingBinaryPath => write!(f, "Missing binary path"),
Error::Parse(e) => write!(f, "parse error: {}", e),
Error::Extract(e) => write!(f, "extract error: {}", e),
} }
} }
} }
@ -92,4 +96,10 @@ impl From<io::Error> for Error {
} }
} }
impl From<ExtractError> for Error {
fn from(value: ExtractError) -> Self {
Self::Extract(value)
}
}
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;

View file

@ -11,6 +11,7 @@ authors.workspace = true
[dependencies] [dependencies]
anyhow = "1" anyhow = "1"
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
custom-types = { path = "../custom-types" }
futures = "0.3" futures = "0.3"
futures-util = "0.3" futures-util = "0.3"
home = "0.5" home = "0.5"

View file

@ -1 +1 @@
[{"cursor":{"line":73,"character":10},"file":"helix-core/src/chars.rs"},{"cursor":{"line":257,"character":11},"file":"helix-dap/src/types.rs"},{"cursor":{"line":39,"character":14},"file":"helix-view/src/info.rs"},{"cursor":{"line":116,"character":12},"file":"helix-term/src/ui/mod.rs"},{"cursor":{"line":1,"character":14},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":2,"character":5},"file":"helix-core/src/config.rs"},{"cursor":{"line":151,"character":14},"file":"helix-view/src/gutter.rs"},{"cursor":{"line":11,"character":10},"file":"helix-term/src/ui/lsp.rs"},{"cursor":{"line":18,"character":0},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":230,"character":3},"file":"helix-term/src/ui/markdown.rs"}] [{"cursor":{"line":9,"character":0},"file":"helix-core/src/increment/mod.rs"},{"cursor":{"line":47,"character":5},"file":"helix-stdx/src/env.rs"},{"cursor":{"line":444,"character":4},"file":"helix-term/src/ui/editor.rs"},{"cursor":{"line":939,"character":8},"file":"helix-tui/src/buffer.rs"},{"cursor":{"line":30,"character":6},"file":"helix-view/src/handlers.rs"},{"cursor":{"line":332,"character":0},"file":"helix-term/src/health.rs"},{"cursor":{"line":15,"character":2},"file":"helix-term/src/events.rs"},{"cursor":{"line":415,"character":2},"file":"helix-tui/src/widgets/reflow.rs"},{"cursor":{"line":316,"character":2},"file":"helix-core/src/shellwords.rs"},{"cursor":{"line":218,"character":2},"file":"helix-tui/src/backend/crossterm.rs"}]

File diff suppressed because one or more lines are too long

View file

@ -6,11 +6,12 @@ fim:
middle: <fim_middle> middle: <fim_middle>
suffix: <fim_suffix> suffix: <fim_suffix>
model: bigcode/starcoder model: bigcode/starcoder
request_params: backend: huggingface
maxNewTokens: 150 request_body:
max_new_tokens: 150
temperature: 0.2 temperature: 0.2
doSample: true do_sample: true
topP: 0.95 top_p: 0.95
tls_skip_verify_insecure: false tls_skip_verify_insecure: false
tokenizer_config: tokenizer_config:
repository: bigcode/starcoder repository: bigcode/starcoder
@ -202,7 +203,7 @@ repositories:
type: github type: github
owner: helix-editor owner: helix-editor
name: helix name: helix
revision: ae6a0a9cfd377fbfa494760282498cf2ca322782 revision: a1272bdb17a63361342a318982e46129d558743c
exclude_paths: exclude_paths:
- .cargo - .cargo
- .github - .github

View file

@ -6,11 +6,12 @@ fim:
middle: <fim_middle> middle: <fim_middle>
suffix: <fim_suffix> suffix: <fim_suffix>
model: bigcode/starcoder model: bigcode/starcoder
request_params: backend: huggingface
maxNewTokens: 150 request_body:
max_new_tokens: 150
temperature: 0.2 temperature: 0.2
doSample: true do_sample: true
topP: 0.95 top_p: 0.95
tls_skip_verify_insecure: false tls_skip_verify_insecure: false
tokenizer_config: tokenizer_config:
repository: bigcode/starcoder repository: bigcode/starcoder
@ -202,7 +203,7 @@ repositories:
type: github type: github
owner: helix-editor owner: helix-editor
name: helix name: helix
revision: ae6a0a9cfd377fbfa494760282498cf2ca322782 revision: a1272bdb17a63361342a318982e46129d558743c
exclude_paths: exclude_paths:
- .cargo - .cargo
- .github - .github

View file

@ -10,9 +10,13 @@ use std::{
use anyhow::anyhow; use anyhow::anyhow;
use clap::Parser; use clap::Parser;
use custom_types::{
llm_ls::{Backend, FimParams, GetCompletionsParams, Ide, TokenizerConfig},
request::GetCompletions,
};
use futures_util::{stream::FuturesUnordered, StreamExt, TryStreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt, TryStreamExt};
use lang::Language; use lang::Language;
use lsp_client::{client::LspClient, error::ExtractError, msg::RequestId, server::Server}; use lsp_client::{client::LspClient, error::ExtractError, server::Server};
use lsp_types::{ use lsp_types::{
DidOpenTextDocumentParams, InitializeParams, TextDocumentIdentifier, TextDocumentItem, DidOpenTextDocumentParams, InitializeParams, TextDocumentIdentifier, TextDocumentItem,
TextDocumentPositionParams, TextDocumentPositionParams,
@ -20,6 +24,7 @@ use lsp_types::{
use ropey::Rope; use ropey::Rope;
use runner::Runner; use runner::Runner;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use tempfile::TempDir; use tempfile::TempDir;
use tokio::{ use tokio::{
fs::{self, read_to_string, File, OpenOptions}, fs::{self, read_to_string, File, OpenOptions},
@ -32,19 +37,11 @@ use tracing::{debug, error, info, info_span, warn, Instrument};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use url::Url; use url::Url;
use crate::{ use crate::{holes_generator::generate_holes, runner::run_test};
holes_generator::generate_holes,
runner::run_test,
types::{
FimParams, GetCompletions, GetCompletionsParams, GetCompletionsResult, Ide, RequestParams,
TokenizerConfig,
},
};
mod holes_generator; mod holes_generator;
mod lang; mod lang;
mod runner; mod runner;
mod types;
/// Testbed runs llm-ls' code completion to measure its performance /// Testbed runs llm-ls' code completion to measure its performance
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -201,11 +198,12 @@ struct RepositoriesConfig {
context_window: usize, context_window: usize,
fim: FimParams, fim: FimParams,
model: String, model: String,
request_params: RequestParams, backend: Backend,
repositories: Vec<Repository>, repositories: Vec<Repository>,
tls_skip_verify_insecure: bool, tls_skip_verify_insecure: bool,
tokenizer_config: Option<TokenizerConfig>, tokenizer_config: Option<TokenizerConfig>,
tokens_to_clear: Vec<String>, tokens_to_clear: Vec<String>,
request_body: Map<String, Value>,
} }
struct HoleCompletionResult { struct HoleCompletionResult {
@ -463,10 +461,11 @@ async fn complete_holes(
context_window, context_window,
fim, fim,
model, model,
request_params, backend,
tls_skip_verify_insecure, tls_skip_verify_insecure,
tokenizer_config, tokenizer_config,
tokens_to_clear, tokens_to_clear,
request_body,
.. ..
} = repos_config; } = repos_config;
async move { async move {
@ -516,14 +515,14 @@ async fn complete_holes(
}, },
}, },
); );
let response = client let result = client
.send_request::<GetCompletions>(GetCompletionsParams { .send_request::<GetCompletions>(GetCompletionsParams {
api_token: api_token.clone(), api_token: api_token.clone(),
context_window, context_window,
fim: fim.clone(), fim: fim.clone(),
ide: Ide::default(), ide: Ide::default(),
model: model.clone(), model: model.clone(),
request_params: request_params.clone(), backend,
text_document_position: TextDocumentPositionParams { text_document_position: TextDocumentPositionParams {
position: hole.cursor, position: hole.cursor,
text_document: TextDocumentIdentifier { uri }, text_document: TextDocumentIdentifier { uri },
@ -531,9 +530,9 @@ async fn complete_holes(
tls_skip_verify_insecure, tls_skip_verify_insecure,
tokens_to_clear: tokens_to_clear.clone(), tokens_to_clear: tokens_to_clear.clone(),
tokenizer_config: tokenizer_config.clone(), tokenizer_config: tokenizer_config.clone(),
request_body: request_body.clone(),
}) })
.await?; .await?;
let (_, result): (RequestId, GetCompletionsResult) = response.extract()?;
file_content.insert(hole_start, &result.completions[0].generated_text); file_content.insert(hole_start, &result.completions[0].generated_text);
let mut file = OpenOptions::new() let mut file = OpenOptions::new()

View file

@ -1,90 +0,0 @@
use std::path::PathBuf;
use lsp_types::{request::Request, TextDocumentPositionParams};
use serde::{Deserialize, Deserializer, Serialize};
use uuid::Uuid;
/// Marker type for the custom `llm-ls/getCompletions` LSP request.
#[derive(Debug)]
pub(crate) enum GetCompletions {}
impl Request for GetCompletions {
    type Params = GetCompletionsParams;
    type Result = GetCompletionsResult;
    const METHOD: &'static str = "llm-ls/getCompletions";
}
/// Generation parameters forwarded to the model API (camelCase on the wire).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct RequestParams {
    /// Maximum number of tokens to generate.
    pub(crate) max_new_tokens: u32,
    /// Sampling temperature.
    pub(crate) temperature: f32,
    /// Whether to sample instead of greedy decoding.
    pub(crate) do_sample: bool,
    /// Nucleus-sampling probability mass.
    pub(crate) top_p: f32,
    /// Optional stop sequences that end generation.
    pub(crate) stop_tokens: Option<Vec<String>>,
}
/// Editor/IDE issuing the completion request; lowercase on the wire.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum Ide {
    Neovim,
    VSCode,
    JetBrains,
    Emacs,
    Jupyter,
    Sublime,
    VisualStudio,
    /// Fallback when the client does not identify itself.
    #[default]
    Unknown,
}
/// Deserializes an optional `ide` field, defaulting to `Ide::Unknown`
/// when the value is absent or `null`.
fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
where
    D: Deserializer<'de>,
{
    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
}
/// Fill-in-the-middle (FIM) prompting configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct FimParams {
    /// Whether FIM prompting is enabled.
    pub(crate) enabled: bool,
    /// FIM sentinel token for the prefix section.
    pub(crate) prefix: String,
    /// FIM sentinel token for the middle section.
    pub(crate) middle: String,
    /// FIM sentinel token for the suffix section.
    pub(crate) suffix: String,
}
/// Where the server should obtain a tokenizer; `untagged` means the variant
/// is inferred from which fields are present in the JSON object.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub(crate) enum TokenizerConfig {
    /// Load a tokenizer file from a local path.
    Local { path: PathBuf },
    /// Fetch the tokenizer from a Hugging Face repository.
    HuggingFace { repository: String },
    /// Download the tokenizer file from `url` and store it at `to`.
    Download { url: String, to: PathBuf },
}
/// Parameters for the `llm-ls/getCompletions` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GetCompletionsParams {
    /// Document URI and cursor position, flattened into the top-level object.
    #[serde(flatten)]
    pub(crate) text_document_position: TextDocumentPositionParams,
    /// Generation parameters forwarded to the model API.
    pub(crate) request_params: RequestParams,
    /// Requesting editor; absent/`null` deserializes to `Ide::Unknown`.
    #[serde(default)]
    #[serde(deserialize_with = "parse_ide")]
    pub(crate) ide: Ide,
    /// Fill-in-the-middle configuration.
    pub(crate) fim: FimParams,
    /// Backend API token, when authentication is required.
    pub(crate) api_token: Option<String>,
    /// Model identifier to request completions from.
    pub(crate) model: String,
    /// Tokens the server strips from the generated text.
    pub(crate) tokens_to_clear: Vec<String>,
    /// Tokenizer source, if one should be used when building the prompt.
    pub(crate) tokenizer_config: Option<TokenizerConfig>,
    /// Context window size used when building the prompt.
    pub(crate) context_window: usize,
    /// Skip TLS certificate verification (insecure).
    pub(crate) tls_skip_verify_insecure: bool,
}
/// A single generated completion.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct Completion {
    /// Text produced by the model for insertion at the cursor.
    pub(crate) generated_text: String,
}
/// Result of the `llm-ls/getCompletions` request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub(crate) struct GetCompletionsResult {
    // NOTE(review): private while every other field in this module is
    // pub(crate) — confirm this asymmetry is intentional.
    request_id: Uuid,
    /// Completions produced for the request (may be empty).
    pub(crate) completions: Vec<Completion>,
}