refactor: adaptor -> backend (#70)

This commit is contained in:
Luc Georges 2024-02-06 21:26:53 +01:00 committed by GitHub
parent 1499fd6cbf
commit a9831d5720
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 71 additions and 139 deletions

View file

@ -1,26 +1,12 @@
use super::{ use super::{
internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, RequestParams, NAME, internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
VERSION,
}; };
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::{Map, Value};
use std::fmt::Display; use std::fmt::Display;
use tower_lsp::jsonrpc; use tower_lsp::jsonrpc;
fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
serde_json::json!({
"inputs": prompt,
"parameters": {
"max_new_tokens": params.max_new_tokens,
"temperature": params.temperature,
"do_sample": params.do_sample,
"top_p": params.top_p,
"stop_tokens": params.stop_tokens.clone()
},
})
}
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> { fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}"); let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
@ -45,7 +31,7 @@ fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
APIResponse::Generation(gen) => vec![gen], APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(_) => { APIResponse::Generations(_) => {
return Err(internal_error( return Err(internal_error(
"You are attempting to parse a result in the API inference format when using the `tgi` adaptor", "You are attempting to parse a result in the API inference format when using the `tgi` backend",
)) ))
} }
APIResponse::Error(err) => return Err(internal_error(err)), APIResponse::Error(err) => return Err(internal_error(err)),
@ -53,10 +39,6 @@ fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
Ok(generations) Ok(generations)
} }
fn build_api_body(prompt: String, params: &RequestParams) -> Value {
build_tgi_body(prompt, params)
}
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> { fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
build_tgi_headers(api_token, ide) build_tgi_headers(api_token, ide)
} }
@ -70,20 +52,6 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
Ok(generations) Ok(generations)
} }
fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
serde_json::json!({
"prompt": prompt,
"model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for ollama").get("model"),
"stream": false,
// As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
"options": {
"num_predict": params.request_params.max_new_tokens,
"temperature": params.request_params.temperature,
"top_p": params.request_params.top_p,
"stop": params.request_params.stop_tokens.clone(),
}
})
}
fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> { fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
Ok(HeaderMap::new()) Ok(HeaderMap::new())
} }
@ -116,17 +84,6 @@ fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
Ok(generations) Ok(generations)
} }
fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
serde_json::json!({
"prompt": prompt,
"model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for openai").get("model"),
"max_tokens": params.request_params.max_new_tokens,
"temperature": params.request_params.temperature,
"top_p": params.request_params.top_p,
"stop": params.request_params.stop_tokens.clone(),
})
}
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> { fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
build_api_headers(api_token, ide) build_api_headers(api_token, ide)
} }
@ -206,51 +163,47 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
} }
} }
pub(crate) const TGI: &str = "tgi"; #[derive(Debug, Default, Deserialize, Serialize)]
pub(crate) const HUGGING_FACE: &str = "huggingface"; #[serde(rename_all = "lowercase")]
pub(crate) const OLLAMA: &str = "ollama"; pub(crate) enum Backend {
pub(crate) const OPENAI: &str = "openai"; #[default]
pub(crate) const DEFAULT_ADAPTOR: &str = HUGGING_FACE; HuggingFace,
Ollama,
fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error { OpenAi,
internal_error(format!("Unknown adaptor {:?}", adaptor)) Tgi,
} }
pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, jsonrpc::Error> { pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Value> {
match params let mut body = params.request_body.clone();
.adaptor match params.backend {
.as_ref() Backend::HuggingFace | Backend::Tgi => {
.unwrap_or(&DEFAULT_ADAPTOR.to_string()) body.insert("inputs".to_string(), Value::String(prompt))
.as_str() }
{ Backend::Ollama | Backend::OpenAi => {
TGI => Ok(build_tgi_body(prompt, &params.request_params)), body.insert("prompt".to_string(), Value::String(prompt))
HUGGING_FACE => Ok(build_api_body(prompt, &params.request_params)), }
OLLAMA => Ok(build_ollama_body(prompt, params)), };
OPENAI => Ok(build_openai_body(prompt, params)), body
_ => Err(unknown_adaptor_error(params.adaptor.as_ref())),
}
} }
pub fn adapt_headers( pub fn build_headers(
adaptor: Option<&String>, backend: &Backend,
api_token: Option<&String>, api_token: Option<&String>,
ide: Ide, ide: Ide,
) -> Result<HeaderMap, jsonrpc::Error> { ) -> Result<HeaderMap, jsonrpc::Error> {
match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() { match backend {
TGI => build_tgi_headers(api_token, ide), Backend::HuggingFace => build_api_headers(api_token, ide),
HUGGING_FACE => build_api_headers(api_token, ide), Backend::Ollama => build_ollama_headers(),
OLLAMA => build_ollama_headers(), Backend::OpenAi => build_openai_headers(api_token, ide),
OPENAI => build_openai_headers(api_token, ide), Backend::Tgi => build_tgi_headers(api_token, ide),
_ => Err(unknown_adaptor_error(adaptor)),
} }
} }
pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result<Vec<Generation>> { pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result<Vec<Generation>> {
match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() { match backend {
TGI => parse_tgi_text(text), Backend::HuggingFace => parse_api_text(text),
HUGGING_FACE => parse_api_text(text), Backend::Ollama => parse_ollama_text(text),
OLLAMA => parse_ollama_text(text), Backend::OpenAi => parse_openai_text(text),
OPENAI => parse_openai_text(text), Backend::Tgi => parse_tgi_text(text),
_ => Err(unknown_adaptor_error(adaptor)),
} }
} }

View file

@ -1,8 +1,8 @@
use adaptors::{adapt_body, adapt_headers, parse_generations}; use backend::{build_body, build_headers, parse_generations, Backend};
use document::Document; use document::Document;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use ropey::Rope; use ropey::Rope;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{Map, Value};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@ -19,7 +19,7 @@ use tracing_appender::rolling;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use uuid::Uuid; use uuid::Uuid;
mod adaptors; mod backend;
mod document; mod document;
mod language_id; mod language_id;
@ -117,10 +117,7 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
.try_line_to_char(row) .try_line_to_char(row)
.map_err(internal_error)?; .map_err(internal_error)?;
// XXX: We treat the end of a document as a newline // XXX: We treat the end of a document as a newline
let next_char = document let next_char = document.text.get_char(start_idx + column).unwrap_or('\n');
.text
.get_char(start_idx + column)
.unwrap_or('\n');
if next_char.is_whitespace() { if next_char.is_whitespace() {
Ok(CompletionType::SingleLine) Ok(CompletionType::SingleLine)
} else { } else {
@ -131,9 +128,17 @@ fn should_complete(document: &Document, position: Position) -> Result<Completion
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)] #[serde(untagged)]
enum TokenizerConfig { enum TokenizerConfig {
Local { path: PathBuf }, Local {
HuggingFace { repository: String }, path: PathBuf,
Download { url: String, to: PathBuf }, },
HuggingFace {
repository: String,
api_token: Option<String>,
},
Download {
url: String,
to: PathBuf,
},
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
@ -209,7 +214,7 @@ pub enum APIResponse {
Error(APIError), Error(APIError),
} }
struct Backend { struct LlmService {
cache_dir: PathBuf, cache_dir: PathBuf,
client: Client, client: Client,
document_map: Arc<RwLock<HashMap<String, Document>>>, document_map: Arc<RwLock<HashMap<String, Document>>>,
@ -272,19 +277,18 @@ struct RejectedCompletion {
pub struct CompletionParams { pub struct CompletionParams {
#[serde(flatten)] #[serde(flatten)]
text_document_position: TextDocumentPositionParams, text_document_position: TextDocumentPositionParams,
request_params: RequestParams,
#[serde(default)] #[serde(default)]
#[serde(deserialize_with = "parse_ide")] #[serde(deserialize_with = "parse_ide")]
ide: Ide, ide: Ide,
fim: FimParams, fim: FimParams,
api_token: Option<String>, api_token: Option<String>,
model: String, model: String,
adaptor: Option<String>, backend: Backend,
tokens_to_clear: Vec<String>, tokens_to_clear: Vec<String>,
tokenizer_config: Option<TokenizerConfig>, tokenizer_config: Option<TokenizerConfig>,
context_window: usize, context_window: usize,
tls_skip_verify_insecure: bool, tls_skip_verify_insecure: bool,
request_body: Option<serde_json::Map<String, serde_json::Value>>, request_body: Map<String, Value>,
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
@ -413,12 +417,8 @@ async fn request_completion(
) -> Result<Vec<Generation>> { ) -> Result<Vec<Generation>> {
let t = Instant::now(); let t = Instant::now();
let json = adapt_body(prompt, params).map_err(internal_error)?; let json = build_body(prompt, params);
let headers = adapt_headers( let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
params.adaptor.as_ref(),
params.api_token.as_ref(),
params.ide,
)?;
let res = http_client let res = http_client
.post(build_url(&params.model)) .post(build_url(&params.model))
.json(&json) .json(&json)
@ -429,7 +429,7 @@ async fn request_completion(
let model = &params.model; let model = &params.model;
let generations = parse_generations( let generations = parse_generations(
params.adaptor.as_ref(), &params.backend,
res.text().await.map_err(internal_error)?.as_str(), res.text().await.map_err(internal_error)?.as_str(),
); );
let time = t.elapsed().as_millis(); let time = t.elapsed().as_millis();
@ -489,7 +489,7 @@ async fn download_tokenizer_file(
) )
.await .await
.map_err(internal_error)?; .map_err(internal_error)?;
let headers = build_headers(api_token, ide)?; let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
let mut file = tokio::fs::OpenOptions::new() let mut file = tokio::fs::OpenOptions::new()
.write(true) .write(true)
.create(true) .create(true)
@ -538,7 +538,6 @@ async fn get_tokenizer(
tokenizer_config: Option<&TokenizerConfig>, tokenizer_config: Option<&TokenizerConfig>,
http_client: &reqwest::Client, http_client: &reqwest::Client,
cache_dir: impl AsRef<Path>, cache_dir: impl AsRef<Path>,
api_token: Option<&String>,
ide: Ide, ide: Ide,
) -> Result<Option<Arc<Tokenizer>>> { ) -> Result<Option<Arc<Tokenizer>>> {
if let Some(tokenizer) = tokenizer_map.get(model) { if let Some(tokenizer) = tokenizer_map.get(model) {
@ -553,11 +552,14 @@ async fn get_tokenizer(
None None
} }
}, },
TokenizerConfig::HuggingFace { repository } => { TokenizerConfig::HuggingFace {
repository,
api_token,
} => {
let path = cache_dir.as_ref().join(repository).join("tokenizer.json"); let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
let url = let url =
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json"); format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
download_tokenizer_file(http_client, &url, api_token, &path, ide).await?; download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?;
match Tokenizer::from_file(path) { match Tokenizer::from_file(path) {
Ok(tokenizer) => Some(Arc::new(tokenizer)), Ok(tokenizer) => Some(Arc::new(tokenizer)),
Err(err) => { Err(err) => {
@ -567,7 +569,7 @@ async fn get_tokenizer(
} }
} }
TokenizerConfig::Download { url, to } => { TokenizerConfig::Download { url, to } => {
download_tokenizer_file(http_client, url, api_token, &to, ide).await?; download_tokenizer_file(http_client, url, None, &to, ide).await?;
match Tokenizer::from_file(to) { match Tokenizer::from_file(to) {
Ok(tokenizer) => Some(Arc::new(tokenizer)), Ok(tokenizer) => Some(Arc::new(tokenizer)),
Err(err) => { Err(err) => {
@ -594,7 +596,7 @@ fn build_url(model: &str) -> String {
} }
} }
impl Backend { impl LlmService {
async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> { async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
let request_id = Uuid::new_v4(); let request_id = Uuid::new_v4();
let span = info_span!("completion_request", %request_id); let span = info_span!("completion_request", %request_id);
@ -611,15 +613,11 @@ impl Backend {
language_id = %document.language_id, language_id = %document.language_id,
model = params.model, model = params.model,
ide = %params.ide, ide = %params.ide,
max_new_tokens = params.request_params.max_new_tokens, request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
temperature = params.request_params.temperature,
do_sample = params.request_params.do_sample,
top_p = params.request_params.top_p,
stop_tokens = ?params.request_params.stop_tokens,
"received completion request for {}", "received completion request for {}",
params.text_document_position.text_document.uri params.text_document_position.text_document.uri
); );
let is_using_inference_api = params.adaptor.as_ref().unwrap_or(&adaptors::DEFAULT_ADAPTOR.to_owned()).as_str() == adaptors::HUGGING_FACE; let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
if params.api_token.is_none() && is_using_inference_api { if params.api_token.is_none() && is_using_inference_api {
let now = Instant::now(); let now = Instant::now();
let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await; let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
@ -642,7 +640,6 @@ impl Backend {
params.tokenizer_config.as_ref(), params.tokenizer_config.as_ref(),
&self.http_client, &self.http_client,
&self.cache_dir, &self.cache_dir,
params.api_token.as_ref(),
params.ide, params.ide,
) )
.await?; .await?;
@ -693,7 +690,7 @@ impl Backend {
} }
#[tower_lsp::async_trait] #[tower_lsp::async_trait]
impl LanguageServer for Backend { impl LanguageServer for LlmService {
async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> { async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
*self.workspace_folders.write().await = params.workspace_folders; *self.workspace_folders.write().await = params.workspace_folders;
Ok(InitializeResult { Ok(InitializeResult {
@ -795,24 +792,6 @@ impl LanguageServer for Backend {
} }
} }
fn build_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
headers.insert(
USER_AGENT,
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
);
if let Some(api_token) = api_token {
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
);
}
Ok(headers)
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let stdin = tokio::io::stdin(); let stdin = tokio::io::stdin();
@ -846,7 +825,7 @@ async fn main() {
.build() .build()
.expect("failed to build reqwest unsafe client"); .expect("failed to build reqwest unsafe client");
let (service, socket) = LspService::build(|client| Backend { let (service, socket) = LspService::build(|client| LlmService {
cache_dir, cache_dir,
client, client,
document_map: Arc::new(RwLock::new(HashMap::new())), document_map: Arc::new(RwLock::new(HashMap::new())),
@ -860,9 +839,9 @@ async fn main() {
.expect("instant to be in bounds"), .expect("instant to be in bounds"),
)), )),
}) })
.custom_method("llm-ls/getCompletions", Backend::get_completions) .custom_method("llm-ls/getCompletions", LlmService::get_completions)
.custom_method("llm-ls/acceptCompletion", Backend::accept_completion) .custom_method("llm-ls/acceptCompletion", LlmService::accept_completion)
.custom_method("llm-ls/rejectCompletion", Backend::reject_completion) .custom_method("llm-ls/rejectCompletion", LlmService::reject_completion)
.finish(); .finish();
Server::new(stdin, stdout, socket).serve(service).await; Server::new(stdin, stdout, socket).serve(service).await;