refactor: adaptor -> backend (#70)
parent 1499fd6cbf
commit a9831d5720
adaptors.rs → backend.rs

@@ -1,26 +1,12 @@
 use super::{
-    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, RequestParams, NAME,
-    VERSION,
+    internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
 };
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{Map, Value};
 use std::fmt::Display;
 use tower_lsp::jsonrpc;
 
-fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
-    serde_json::json!({
-        "inputs": prompt,
-        "parameters": {
-            "max_new_tokens": params.max_new_tokens,
-            "temperature": params.temperature,
-            "do_sample": params.do_sample,
-            "top_p": params.top_p,
-            "stop_tokens": params.stop_tokens.clone()
-        },
-    })
-}
-
 fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
     let mut headers = HeaderMap::new();
     let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
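Note: with build_tgi_body gone, llm-ls no longer assembles TGI parameters from RequestParams; as the build_body hunk further down shows, it only injects the prompt into the client-supplied request_body. A minimal sketch, assuming a TGI-style server, of what a client would now place in request_body, mirroring the fields the deleted builder set (values are illustrative):

    // Sketch with hypothetical values: the client now owns these parameters.
    let request_body = serde_json::json!({
        "parameters": {
            "max_new_tokens": 60,
            "temperature": 0.2,
            "do_sample": true,
            "top_p": 0.95,
            "stop_tokens": ["\n\n"]
        }
    });
    // llm-ls itself only adds "inputs": <prompt> (see build_body below).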
@@ -45,7 +31,7 @@ fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
         APIResponse::Generation(gen) => vec![gen],
         APIResponse::Generations(_) => {
             return Err(internal_error(
-                "You are attempting to parse a result in the API inference format when using the `tgi` adaptor",
+                "You are attempting to parse a result in the API inference format when using the `tgi` backend",
             ))
         }
         APIResponse::Error(err) => return Err(internal_error(err)),
@@ -53,10 +39,6 @@ fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     Ok(generations)
 }
 
-fn build_api_body(prompt: String, params: &RequestParams) -> Value {
-    build_tgi_body(prompt, params)
-}
-
 fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
     build_tgi_headers(api_token, ide)
 }
@@ -70,20 +52,6 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     Ok(generations)
 }
 
-fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
-    serde_json::json!({
-        "prompt": prompt,
-        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for ollama").get("model"),
-        "stream": false,
-        // As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
-        "options": {
-            "num_predict": params.request_params.max_new_tokens,
-            "temperature": params.request_params.temperature,
-            "top_p": params.request_params.top_p,
-            "stop": params.request_params.stop_tokens.clone(),
-        }
-    })
-}
 fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
     Ok(HeaderMap::new())
 }
@@ -116,17 +84,6 @@ fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     Ok(generations)
 }
 
-fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
-    serde_json::json!({
-        "prompt": prompt,
-        "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for openai").get("model"),
-        "max_tokens": params.request_params.max_new_tokens,
-        "temperature": params.request_params.temperature,
-        "top_p": params.request_params.top_p,
-        "stop": params.request_params.stop_tokens.clone(),
-    })
-}
-
 fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
     build_api_headers(api_token, ide)
 }
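The deleted Ollama and OpenAI body builders follow the same pattern: "model", "stream", Ollama's "options" and OpenAI's "max_tokens" are no longer synthesized from RequestParams and must instead arrive in the client's request_body. A hedged sketch for the Ollama case, reusing the option names from the removed code (the model name is made up):

    // Sketch: a plausible request_body for the ollama backend after this
    // change; llm-ls only adds "prompt" on top of it.
    let request_body = serde_json::json!({
        "model": "codellama:7b",
        "stream": false,
        "options": {
            "num_predict": 60,
            "temperature": 0.2,
            "top_p": 0.95,
            "stop": ["\n\n"]
        }
    });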
@@ -206,51 +163,47 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
     }
 }
 
-pub(crate) const TGI: &str = "tgi";
-pub(crate) const HUGGING_FACE: &str = "huggingface";
-pub(crate) const OLLAMA: &str = "ollama";
-pub(crate) const OPENAI: &str = "openai";
-pub(crate) const DEFAULT_ADAPTOR: &str = HUGGING_FACE;
-
-fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error {
-    internal_error(format!("Unknown adaptor {:?}", adaptor))
+#[derive(Debug, Default, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase")]
+pub(crate) enum Backend {
+    #[default]
+    HuggingFace,
+    Ollama,
+    OpenAi,
+    Tgi,
 }
 
-pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, jsonrpc::Error> {
-    match params
-        .adaptor
-        .as_ref()
-        .unwrap_or(&DEFAULT_ADAPTOR.to_string())
-        .as_str()
-    {
-        TGI => Ok(build_tgi_body(prompt, &params.request_params)),
-        HUGGING_FACE => Ok(build_api_body(prompt, &params.request_params)),
-        OLLAMA => Ok(build_ollama_body(prompt, params)),
-        OPENAI => Ok(build_openai_body(prompt, params)),
-        _ => Err(unknown_adaptor_error(params.adaptor.as_ref())),
-    }
+pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Value> {
+    let mut body = params.request_body.clone();
+    match params.backend {
+        Backend::HuggingFace | Backend::Tgi => {
+            body.insert("inputs".to_string(), Value::String(prompt))
+        }
+        Backend::Ollama | Backend::OpenAi => {
+            body.insert("prompt".to_string(), Value::String(prompt))
+        }
+    };
+    body
 }
 
-pub fn adapt_headers(
-    adaptor: Option<&String>,
+pub fn build_headers(
+    backend: &Backend,
     api_token: Option<&String>,
     ide: Ide,
 ) -> Result<HeaderMap, jsonrpc::Error> {
-    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
-        TGI => build_tgi_headers(api_token, ide),
-        HUGGING_FACE => build_api_headers(api_token, ide),
-        OLLAMA => build_ollama_headers(),
-        OPENAI => build_openai_headers(api_token, ide),
-        _ => Err(unknown_adaptor_error(adaptor)),
+    match backend {
+        Backend::HuggingFace => build_api_headers(api_token, ide),
+        Backend::Ollama => build_ollama_headers(),
+        Backend::OpenAi => build_openai_headers(api_token, ide),
+        Backend::Tgi => build_tgi_headers(api_token, ide),
     }
 }
 
-pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result<Vec<Generation>> {
-    match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
-        TGI => parse_tgi_text(text),
-        HUGGING_FACE => parse_api_text(text),
-        OLLAMA => parse_ollama_text(text),
-        OPENAI => parse_openai_text(text),
-        _ => Err(unknown_adaptor_error(adaptor)),
+pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result<Vec<Generation>> {
+    match backend {
+        Backend::HuggingFace => parse_api_text(text),
+        Backend::Ollama => parse_ollama_text(text),
+        Backend::OpenAi => parse_openai_text(text),
+        Backend::Tgi => parse_tgi_text(text),
     }
 }
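The string constants and the unknown_adaptor_error fallback can go because Backend is now a closed enum: serde rejects unknown names at deserialization time and every match above is exhaustive. A quick sketch of the round-trip behavior implied by #[serde(rename_all = "lowercase")] (test harness assumed; not part of this diff):

    #[test]
    fn backend_names_roundtrip() {
        // "huggingface", "ollama", "openai" and "tgi" per rename_all = "lowercase".
        let b: Backend = serde_json::from_str("\"openai\"").unwrap();
        assert!(matches!(b, Backend::OpenAi));
        // Unknown names now fail during deserialization instead of reaching
        // the removed unknown_adaptor_error.
        assert!(serde_json::from_str::<Backend>("\"mystery\"").is_err());
    }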
main.rs

@@ -1,8 +1,8 @@
-use adaptors::{adapt_body, adapt_headers, parse_generations};
+use backend::{build_body, build_headers, parse_generations, Backend};
 use document::Document;
-use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use ropey::Rope;
 use serde::{Deserialize, Deserializer, Serialize};
+use serde_json::{Map, Value};
 use std::collections::HashMap;
 use std::fmt::Display;
 use std::path::{Path, PathBuf};
@@ -19,7 +19,7 @@ use tracing_appender::rolling;
 use tracing_subscriber::EnvFilter;
 use uuid::Uuid;
 
-mod adaptors;
+mod backend;
 mod document;
 mod language_id;
 
@@ -117,10 +117,7 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
         .try_line_to_char(row)
         .map_err(internal_error)?;
     // XXX: We treat the end of a document as a newline
-    let next_char = document
-        .text
-        .get_char(start_idx + column)
-        .unwrap_or('\n');
+    let next_char = document.text.get_char(start_idx + column).unwrap_or('\n');
     if next_char.is_whitespace() {
         Ok(CompletionType::SingleLine)
     } else {
@@ -131,9 +128,17 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(untagged)]
 enum TokenizerConfig {
-    Local { path: PathBuf },
-    HuggingFace { repository: String },
-    Download { url: String, to: PathBuf },
+    Local {
+        path: PathBuf,
+    },
+    HuggingFace {
+        repository: String,
+        api_token: Option<String>,
+    },
+    Download {
+        url: String,
+        to: PathBuf,
+    },
 }
 
 #[derive(Clone, Debug, Deserialize, Serialize)]
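TokenizerConfig::HuggingFace gains an api_token of its own, so a gated or private tokenizer repository no longer depends on the completion request's token (which get_tokenizer stops receiving below). A hedged sketch of the corresponding client-side setting (repository and token values are made up):

    // Illustrative tokenizer_config for a gated repo; field names follow the
    // enum variant above, which serde matches untagged by its fields.
    let tokenizer_config = serde_json::json!({
        "repository": "some-org/some-gated-model",
        "api_token": "hf_xxx"  // Option<String>: omit for public repositories
    });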
@@ -209,7 +214,7 @@ pub enum APIResponse {
     Error(APIError),
 }
 
-struct Backend {
+struct LlmService {
     cache_dir: PathBuf,
     client: Client,
     document_map: Arc<RwLock<HashMap<String, Document>>>,
@@ -272,19 +277,18 @@ struct RejectedCompletion {
 pub struct CompletionParams {
     #[serde(flatten)]
     text_document_position: TextDocumentPositionParams,
-    request_params: RequestParams,
     #[serde(default)]
     #[serde(deserialize_with = "parse_ide")]
     ide: Ide,
     fim: FimParams,
     api_token: Option<String>,
     model: String,
-    adaptor: Option<String>,
+    backend: Backend,
     tokens_to_clear: Vec<String>,
     tokenizer_config: Option<TokenizerConfig>,
     context_window: usize,
     tls_skip_verify_insecure: bool,
-    request_body: Option<serde_json::Map<String, serde_json::Value>>,
+    request_body: Map<String, Value>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
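For clients, this is the visible protocol change in CompletionParams: the optional adaptor string becomes a backend enum (defaulting to huggingface), request_params disappears, and request_body is now a plain map rather than an Option. A sketch of the changed fields only, assuming key names match the snake_case struct fields (model name is hypothetical):

    // Sketch: only the fields that changed; everything else is as before.
    let changed = serde_json::json!({
        "backend": "openai",      // was: "adaptor": Option<String>
        "request_body": {         // was: Option<Map<...>>; now a required map
            "model": "gpt-3.5-turbo-instruct",
            "max_tokens": 60,
            "temperature": 0.2
        }
    });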
@@ -413,12 +417,8 @@ async fn request_completion(
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
 
-    let json = adapt_body(prompt, params).map_err(internal_error)?;
-    let headers = adapt_headers(
-        params.adaptor.as_ref(),
-        params.api_token.as_ref(),
-        params.ide,
-    )?;
+    let json = build_body(prompt, params);
+    let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
     let res = http_client
         .post(build_url(&params.model))
         .json(&json)
@@ -429,7 +429,7 @@ async fn request_completion(
 
     let model = &params.model;
     let generations = parse_generations(
-        params.adaptor.as_ref(),
+        &params.backend,
         res.text().await.map_err(internal_error)?.as_str(),
     );
     let time = t.elapsed().as_millis();
@@ -489,7 +489,7 @@ async fn download_tokenizer_file(
     )
     .await
     .map_err(internal_error)?;
-    let headers = build_headers(api_token, ide)?;
+    let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
@@ -538,7 +538,6 @@ async fn get_tokenizer(
     tokenizer_config: Option<&TokenizerConfig>,
     http_client: &reqwest::Client,
     cache_dir: impl AsRef<Path>,
-    api_token: Option<&String>,
     ide: Ide,
 ) -> Result<Option<Arc<Tokenizer>>> {
     if let Some(tokenizer) = tokenizer_map.get(model) {
@@ -553,11 +552,14 @@ async fn get_tokenizer(
                 None
             }
         },
-        TokenizerConfig::HuggingFace { repository } => {
+        TokenizerConfig::HuggingFace {
+            repository,
+            api_token,
+        } => {
             let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
             let url =
                 format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
-            download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
+            download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?;
             match Tokenizer::from_file(path) {
                 Ok(tokenizer) => Some(Arc::new(tokenizer)),
                 Err(err) => {
@@ -567,7 +569,7 @@ async fn get_tokenizer(
             }
         }
         TokenizerConfig::Download { url, to } => {
-            download_tokenizer_file(http_client, url, api_token, &to, ide).await?;
+            download_tokenizer_file(http_client, url, None, &to, ide).await?;
             match Tokenizer::from_file(to) {
                 Ok(tokenizer) => Some(Arc::new(tokenizer)),
                 Err(err) => {
@@ -594,7 +596,7 @@ fn build_url(model: &str) -> String {
     }
 }
 
-impl Backend {
+impl LlmService {
     async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
         let request_id = Uuid::new_v4();
         let span = info_span!("completion_request", %request_id);
@@ -611,15 +613,11 @@ impl Backend {
             language_id = %document.language_id,
             model = params.model,
             ide = %params.ide,
-            max_new_tokens = params.request_params.max_new_tokens,
-            temperature = params.request_params.temperature,
-            do_sample = params.request_params.do_sample,
-            top_p = params.request_params.top_p,
-            stop_tokens = ?params.request_params.stop_tokens,
+            request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
             "received completion request for {}",
             params.text_document_position.text_document.uri
         );
-        let is_using_inference_api = params.adaptor.as_ref().unwrap_or(&adaptors::DEFAULT_ADAPTOR.to_owned()).as_str() == adaptors::HUGGING_FACE;
+        let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
         if params.api_token.is_none() && is_using_inference_api {
             let now = Instant::now();
             let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
@@ -642,7 +640,6 @@ impl Backend {
             params.tokenizer_config.as_ref(),
             &self.http_client,
             &self.cache_dir,
-            params.api_token.as_ref(),
             params.ide,
         )
         .await?;
@@ -693,7 +690,7 @@ impl Backend {
 }
 
 #[tower_lsp::async_trait]
-impl LanguageServer for Backend {
+impl LanguageServer for LlmService {
     async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
         *self.workspace_folders.write().await = params.workspace_folders;
         Ok(InitializeResult {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
|
|
||||||
headers.insert(
|
|
||||||
USER_AGENT,
|
|
||||||
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(api_token) = api_token {
|
|
||||||
headers.insert(
|
|
||||||
AUTHORIZATION,
|
|
||||||
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(headers)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let stdin = tokio::io::stdin();
|
let stdin = tokio::io::stdin();
|
||||||
|
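The local build_headers deleted here duplicated header construction that now lives in the backend module (User-Agent plus an optional Bearer token), which is why download_tokenizer_file switched to the shared version in the earlier hunk. An equivalent call for reference (sketch; header contents inferred from the deleted code):

    // Sketch: the shared helper with an explicit backend replaces the local fn.
    let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
    // Sets USER_AGENT to "{NAME}/{VERSION}; rust/unknown; ide/{ide:?}" and,
    // when api_token is Some, AUTHORIZATION to "Bearer <token>".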
@@ -846,7 +825,7 @@ async fn main() {
         .build()
         .expect("failed to build reqwest unsafe client");
 
-    let (service, socket) = LspService::build(|client| Backend {
+    let (service, socket) = LspService::build(|client| LlmService {
         cache_dir,
         client,
         document_map: Arc::new(RwLock::new(HashMap::new())),
|
||||||
.expect("instant to be in bounds"),
|
.expect("instant to be in bounds"),
|
||||||
)),
|
)),
|
||||||
})
|
})
|
||||||
.custom_method("llm-ls/getCompletions", Backend::get_completions)
|
.custom_method("llm-ls/getCompletions", LlmService::get_completions)
|
||||||
.custom_method("llm-ls/acceptCompletion", Backend::accept_completion)
|
.custom_method("llm-ls/acceptCompletion", LlmService::accept_completion)
|
||||||
.custom_method("llm-ls/rejectCompletion", Backend::reject_completion)
|
.custom_method("llm-ls/rejectCompletion", LlmService::reject_completion)
|
||||||
.finish();
|
.finish();
|
||||||
|
|
||||||
Server::new(stdin, stdout, socket).serve(service).await;
|
Server::new(stdin, stdout, socket).serve(service).await;
|
||||||
|
|