feat: add adaptors for various backends (#40)
* ollama * tgi * api-inference * OpenAI based APIs
This commit is contained in:
parent
2a433cdf75
commit
585ea3aae8
250
crates/llm-ls/src/adaptors.rs
Normal file
250
crates/llm-ls/src/adaptors.rs
Normal file
|
@ -0,0 +1,250 @@
|
||||||
|
use super::{
|
||||||
|
internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, RequestParams, NAME,
|
||||||
|
VERSION,
|
||||||
|
};
|
||||||
|
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::fmt::Display;
|
||||||
|
use tower_lsp::jsonrpc;
|
||||||
|
|
||||||
|
fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"inputs": prompt,
|
||||||
|
"parameters": params,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
|
||||||
|
headers.insert(
|
||||||
|
USER_AGENT,
|
||||||
|
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(api_token) = api_token {
|
||||||
|
headers.insert(
|
||||||
|
AUTHORIZATION,
|
||||||
|
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
||||||
|
let generations =
|
||||||
|
match serde_json::from_str(text).map_err(internal_error)? {
|
||||||
|
APIResponse::Generation(gen) => vec![gen],
|
||||||
|
APIResponse::Generations(_) => {
|
||||||
|
return Err(internal_error(
|
||||||
|
"You are attempting to parse a result in the API inference format when using the `tgi` adaptor",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
APIResponse::Error(err) => return Err(internal_error(err)),
|
||||||
|
};
|
||||||
|
Ok(generations)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_api_body(prompt: String, params: &RequestParams) -> Value {
|
||||||
|
build_tgi_body(prompt, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
|
||||||
|
build_tgi_headers(api_token, ide)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
||||||
|
let generations = match serde_json::from_str(text).map_err(internal_error)? {
|
||||||
|
APIResponse::Generation(gen) => vec![gen],
|
||||||
|
APIResponse::Generations(gens) => gens,
|
||||||
|
APIResponse::Error(err) => return Err(internal_error(err)),
|
||||||
|
};
|
||||||
|
Ok(generations)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"prompt": prompt,
|
||||||
|
"model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for ollama").get("model"),
|
||||||
|
"stream": false,
|
||||||
|
// As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
|
||||||
|
"options": {
|
||||||
|
"num_predict": params.request_params.max_new_tokens,
|
||||||
|
"temperature": params.request_params.temperature,
|
||||||
|
"top_p": params.request_params.top_p,
|
||||||
|
"stop": params.request_params.stop_tokens.clone(),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
|
||||||
|
Ok(HeaderMap::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One ollama generation; only the `response` field of the ollama
/// reply is consumed.
#[derive(Debug, Serialize, Deserialize)]
struct OllamaGeneration {
    response: String,
}

// Convert into the backend-agnostic `Generation` used by the rest of
// the server.
impl From<OllamaGeneration> for Generation {
    fn from(value: OllamaGeneration) -> Self {
        Generation {
            generated_text: value.response,
        }
    }
}
|
||||||
|
|
||||||
|
/// Shape of an ollama reply: either a generation or an error object.
/// `untagged` means serde tries the variants in declaration order, so
/// `Generation` must stay first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OllamaAPIResponse {
    Generation(OllamaGeneration),
    Error(APIError),
}
|
||||||
|
|
||||||
|
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
||||||
|
let generations = match serde_json::from_str(text).map_err(internal_error)? {
|
||||||
|
OllamaAPIResponse::Generation(gen) => vec![gen.into()],
|
||||||
|
OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
|
||||||
|
};
|
||||||
|
Ok(generations)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"prompt": prompt,
|
||||||
|
"model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for openai").get("model"),
|
||||||
|
"max_tokens": params.request_params.max_new_tokens,
|
||||||
|
"temperature": params.request_params.temperature,
|
||||||
|
"top_p": params.request_params.top_p,
|
||||||
|
"stop": params.request_params.stop_tokens.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
|
||||||
|
build_api_headers(api_token, ide)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// One choice inside an OpenAI completions response; only the `text`
/// field is consumed.
#[derive(Debug, Deserialize)]
struct OpenAIGenerationChoice {
    text: String,
}

// Convert into the backend-agnostic `Generation` used by the rest of
// the server.
impl From<OpenAIGenerationChoice> for Generation {
    fn from(value: OpenAIGenerationChoice) -> Self {
        Generation {
            generated_text: value.text,
        }
    }
}
|
||||||
|
|
||||||
|
/// Top-level OpenAI completions response: a list of choices.
#[derive(Debug, Deserialize)]
struct OpenAIGeneration {
    choices: Vec<OpenAIGenerationChoice>,
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum OpenAIErrorLoc {
|
||||||
|
String(String),
|
||||||
|
Int(u32),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for OpenAIErrorLoc {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
OpenAIErrorLoc::String(s) => s.fmt(f),
|
||||||
|
OpenAIErrorLoc::Int(i) => i.fmt(f),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single entry of an OpenAI error payload's `detail` array.
#[derive(Debug, Deserialize)]
struct OpenAIErrorDetail {
    loc: OpenAIErrorLoc,
    msg: String,
    // `type` is a Rust keyword, hence the raw identifier.
    r#type: String,
}
|
||||||
|
|
||||||
|
/// OpenAI error payload: a list of validation/error details.
#[derive(Debug, Deserialize)]
struct OpenAIError {
    detail: Vec<OpenAIErrorDetail>,
}
|
||||||
|
|
||||||
|
impl Display for OpenAIError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
for (i, item) in self.detail.iter().enumerate() {
|
||||||
|
if i != 0 {
|
||||||
|
writeln!(f)?;
|
||||||
|
}
|
||||||
|
write!(f, "{}: {} ({})", item.loc, item.msg, item.r#type)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shape of an OpenAI-compatible reply: a generation or an error.
/// `untagged` means serde tries the variants in declaration order, so
/// `Generation` must stay first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OpenAIAPIResponse {
    Generation(OpenAIGeneration),
    Error(OpenAIError),
}
|
||||||
|
|
||||||
|
fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
|
||||||
|
match serde_json::from_str(text).map_err(internal_error) {
|
||||||
|
Ok(OpenAIAPIResponse::Generation(completion)) => {
|
||||||
|
Ok(completion.choices.into_iter().map(|x| x.into()).collect())
|
||||||
|
}
|
||||||
|
Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
|
||||||
|
Err(err) => Err(internal_error(err)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adaptor identifiers accepted in `CompletionParams::adaptor`.
const TGI: &str = "tgi";
const HUGGING_FACE: &str = "huggingface";
const OLLAMA: &str = "ollama";
const OPENAI: &str = "openai";
// Used when the client does not specify an adaptor.
const DEFAULT_ADAPTOR: &str = HUGGING_FACE;
|
||||||
|
|
||||||
|
fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error {
|
||||||
|
internal_error(format!("Unknown adaptor {:?}", adaptor))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, jsonrpc::Error> {
|
||||||
|
match params
|
||||||
|
.adaptor
|
||||||
|
.as_ref()
|
||||||
|
.unwrap_or(&DEFAULT_ADAPTOR.to_string())
|
||||||
|
.as_str()
|
||||||
|
{
|
||||||
|
TGI => Ok(build_tgi_body(prompt, ¶ms.request_params)),
|
||||||
|
HUGGING_FACE => Ok(build_api_body(prompt, ¶ms.request_params)),
|
||||||
|
OLLAMA => Ok(build_ollama_body(prompt, params)),
|
||||||
|
OPENAI => Ok(build_openai_body(prompt, params)),
|
||||||
|
_ => Err(unknown_adaptor_error(params.adaptor.as_ref())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn adapt_headers(
|
||||||
|
adaptor: Option<&String>,
|
||||||
|
api_token: Option<&String>,
|
||||||
|
ide: Ide,
|
||||||
|
) -> Result<HeaderMap, jsonrpc::Error> {
|
||||||
|
match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
|
||||||
|
TGI => build_tgi_headers(api_token, ide),
|
||||||
|
HUGGING_FACE => build_api_headers(api_token, ide),
|
||||||
|
OLLAMA => build_ollama_headers(),
|
||||||
|
OPENAI => build_openai_headers(api_token, ide),
|
||||||
|
_ => Err(unknown_adaptor_error(adaptor)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result<Vec<Generation>> {
|
||||||
|
match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
|
||||||
|
TGI => parse_tgi_text(text),
|
||||||
|
HUGGING_FACE => parse_api_text(text),
|
||||||
|
OLLAMA => parse_ollama_text(text),
|
||||||
|
OPENAI => parse_openai_text(text),
|
||||||
|
_ => Err(unknown_adaptor_error(adaptor)),
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
|
use adaptors::{adapt_body, adapt_headers, parse_generations};
|
||||||
use document::Document;
|
use document::Document;
|
||||||
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
|
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
|
||||||
use ropey::Rope;
|
use ropey::Rope;
|
||||||
|
@ -18,20 +19,21 @@ use tracing_appender::rolling;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
mod adaptors;
|
||||||
mod document;
|
mod document;
|
||||||
mod language_id;
|
mod language_id;
|
||||||
|
|
||||||
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
|
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
|
||||||
const NAME: &str = "llm-ls";
|
pub const NAME: &str = "llm-ls";
|
||||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
|
|
||||||
// Convert an (row, col) LSP position into a character index into the
// rope, clamping both the row (to the last line) and the column (to
// the line's last character) with `saturating_sub` so empty ropes or
// out-of-range positions cannot underflow.
// NOTE(review): reconstructed from the post-change side of the diff —
// confirm against the committed main.rs.
fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
    Ok(rope.try_line_to_char(row).map_err(internal_error)?
        + col.min(
            rope.get_line(row.min(rope.len_lines().saturating_sub(1)))
                .ok_or_else(|| internal_error(format!("failed to find line at {row}")))?
                .len_chars()
                .saturating_sub(1),
        ))
}
|
||||||
|
|
||||||
|
@ -130,7 +132,7 @@ enum TokenizerConfig {
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct RequestParams {
|
pub struct RequestParams {
|
||||||
max_new_tokens: u32,
|
max_new_tokens: u32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
|
@ -178,12 +180,12 @@ struct APIRequest {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct Generation {
|
pub struct Generation {
|
||||||
generated_text: String,
|
generated_text: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct APIError {
|
pub struct APIError {
|
||||||
error: String,
|
error: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,7 +197,7 @@ impl Display for APIError {
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
enum APIResponse {
|
pub enum APIResponse {
|
||||||
Generation(Generation),
|
Generation(Generation),
|
||||||
Generations(Vec<Generation>),
|
Generations(Vec<Generation>),
|
||||||
Error(APIError),
|
Error(APIError),
|
||||||
|
@ -219,7 +221,7 @@ struct Completion {
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
|
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
enum Ide {
|
pub enum Ide {
|
||||||
Neovim,
|
Neovim,
|
||||||
VSCode,
|
VSCode,
|
||||||
JetBrains,
|
JetBrains,
|
||||||
|
@ -261,7 +263,7 @@ struct RejectedCompletion {
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct CompletionParams {
|
pub struct CompletionParams {
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
text_document_position: TextDocumentPositionParams,
|
text_document_position: TextDocumentPositionParams,
|
||||||
request_params: RequestParams,
|
request_params: RequestParams,
|
||||||
|
@ -271,10 +273,12 @@ struct CompletionParams {
|
||||||
fim: FimParams,
|
fim: FimParams,
|
||||||
api_token: Option<String>,
|
api_token: Option<String>,
|
||||||
model: String,
|
model: String,
|
||||||
|
adaptor: Option<String>,
|
||||||
tokens_to_clear: Vec<String>,
|
tokens_to_clear: Vec<String>,
|
||||||
tokenizer_config: Option<TokenizerConfig>,
|
tokenizer_config: Option<TokenizerConfig>,
|
||||||
context_window: usize,
|
context_window: usize,
|
||||||
tls_skip_verify_insecure: bool,
|
tls_skip_verify_insecure: bool,
|
||||||
|
request_body: Option<serde_json::Map<String, serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
@ -283,7 +287,7 @@ struct CompletionResult {
|
||||||
completions: Vec<Completion>,
|
completions: Vec<Completion>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn internal_error<E: Display>(err: E) -> Error {
|
pub fn internal_error<E: Display>(err: E) -> Error {
|
||||||
let err_msg = err.to_string();
|
let err_msg = err.to_string();
|
||||||
error!(err_msg);
|
error!(err_msg);
|
||||||
Error {
|
Error {
|
||||||
|
@ -398,29 +402,30 @@ fn build_prompt(
|
||||||
|
|
||||||
async fn request_completion(
|
async fn request_completion(
|
||||||
http_client: &reqwest::Client,
|
http_client: &reqwest::Client,
|
||||||
ide: Ide,
|
|
||||||
model: &str,
|
|
||||||
request_params: RequestParams,
|
|
||||||
api_token: Option<&String>,
|
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
params: &CompletionParams,
|
||||||
) -> Result<Vec<Generation>> {
|
) -> Result<Vec<Generation>> {
|
||||||
let t = Instant::now();
|
let t = Instant::now();
|
||||||
|
|
||||||
|
let json = adapt_body(prompt, params).map_err(internal_error)?;
|
||||||
|
let headers = adapt_headers(
|
||||||
|
params.adaptor.as_ref(),
|
||||||
|
params.api_token.as_ref(),
|
||||||
|
params.ide,
|
||||||
|
)?;
|
||||||
let res = http_client
|
let res = http_client
|
||||||
.post(build_url(model))
|
.post(build_url(¶ms.model))
|
||||||
.json(&APIRequest {
|
.json(&json)
|
||||||
inputs: prompt,
|
.headers(headers)
|
||||||
parameters: request_params.into(),
|
|
||||||
})
|
|
||||||
.headers(build_headers(api_token, ide)?)
|
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(internal_error)?;
|
.map_err(internal_error)?;
|
||||||
|
|
||||||
let generations = match res.json().await.map_err(internal_error)? {
|
let model = ¶ms.model;
|
||||||
APIResponse::Generation(gen) => vec![gen],
|
let generations = parse_generations(
|
||||||
APIResponse::Generations(gens) => gens,
|
params.adaptor.as_ref(),
|
||||||
APIResponse::Error(err) => return Err(internal_error(err)),
|
res.text().await.map_err(internal_error)?.as_str(),
|
||||||
};
|
);
|
||||||
let time = t.elapsed().as_millis();
|
let time = t.elapsed().as_millis();
|
||||||
info!(
|
info!(
|
||||||
model,
|
model,
|
||||||
|
@ -428,10 +433,10 @@ async fn request_completion(
|
||||||
generations = serde_json::to_string(&generations).map_err(internal_error)?,
|
generations = serde_json::to_string(&generations).map_err(internal_error)?,
|
||||||
"{model} computed generations in {time} ms"
|
"{model} computed generations in {time} ms"
|
||||||
);
|
);
|
||||||
Ok(generations)
|
generations
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_generations(
|
fn format_generations(
|
||||||
generations: Vec<Generation>,
|
generations: Vec<Generation>,
|
||||||
tokens_to_clear: &[String],
|
tokens_to_clear: &[String],
|
||||||
completion_type: CompletionType,
|
completion_type: CompletionType,
|
||||||
|
@ -524,7 +529,7 @@ async fn download_tokenizer_file(
|
||||||
async fn get_tokenizer(
|
async fn get_tokenizer(
|
||||||
model: &str,
|
model: &str,
|
||||||
tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
|
tokenizer_map: &mut HashMap<String, Arc<Tokenizer>>,
|
||||||
tokenizer_config: Option<TokenizerConfig>,
|
tokenizer_config: Option<&TokenizerConfig>,
|
||||||
http_client: &reqwest::Client,
|
http_client: &reqwest::Client,
|
||||||
cache_dir: impl AsRef<Path>,
|
cache_dir: impl AsRef<Path>,
|
||||||
api_token: Option<&String>,
|
api_token: Option<&String>,
|
||||||
|
@ -543,7 +548,7 @@ async fn get_tokenizer(
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
TokenizerConfig::HuggingFace { repository } => {
|
TokenizerConfig::HuggingFace { repository } => {
|
||||||
let path = cache_dir.as_ref().join(model).join("tokenizer.json");
|
let path = cache_dir.as_ref().join(repository).join("tokenizer.json");
|
||||||
let url =
|
let url =
|
||||||
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
|
format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json");
|
||||||
download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
|
download_tokenizer_file(http_client, &url, api_token, &path, ide).await?;
|
||||||
|
@ -556,7 +561,7 @@ async fn get_tokenizer(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TokenizerConfig::Download { url, to } => {
|
TokenizerConfig::Download { url, to } => {
|
||||||
download_tokenizer_file(http_client, &url, api_token, &to, ide).await?;
|
download_tokenizer_file(http_client, url, api_token, &to, ide).await?;
|
||||||
match Tokenizer::from_file(to) {
|
match Tokenizer::from_file(to) {
|
||||||
Ok(tokenizer) => Some(Arc::new(tokenizer)),
|
Ok(tokenizer) => Some(Arc::new(tokenizer)),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
@ -627,7 +632,7 @@ impl Backend {
|
||||||
let tokenizer = get_tokenizer(
|
let tokenizer = get_tokenizer(
|
||||||
¶ms.model,
|
¶ms.model,
|
||||||
&mut *self.tokenizer_map.write().await,
|
&mut *self.tokenizer_map.write().await,
|
||||||
params.tokenizer_config,
|
params.tokenizer_config.as_ref(),
|
||||||
&self.http_client,
|
&self.http_client,
|
||||||
&self.cache_dir,
|
&self.cache_dir,
|
||||||
params.api_token.as_ref(),
|
params.api_token.as_ref(),
|
||||||
|
@ -650,15 +655,12 @@ impl Backend {
|
||||||
};
|
};
|
||||||
let result = request_completion(
|
let result = request_completion(
|
||||||
http_client,
|
http_client,
|
||||||
params.ide,
|
|
||||||
¶ms.model,
|
|
||||||
params.request_params,
|
|
||||||
params.api_token.as_ref(),
|
|
||||||
prompt,
|
prompt,
|
||||||
|
¶ms,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let completions = parse_generations(result, ¶ms.tokens_to_clear, completion_type);
|
let completions = format_generations(result, ¶ms.tokens_to_clear, completion_type);
|
||||||
Ok(CompletionResult { request_id, completions })
|
Ok(CompletionResult { request_id, completions })
|
||||||
}.instrument(span).await
|
}.instrument(span).await
|
||||||
}
|
}
|
||||||
|
@ -849,3 +851,4 @@ async fn main() {
|
||||||
|
|
||||||
Server::new(stdin, stdout, socket).serve(service).await;
|
Server::new(stdin, stdout, socket).serve(service).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue