From 6f455eca1848ee3c921101dd879ff72a69f8dffc Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Sat, 26 Aug 2023 12:34:17 +0200 Subject: [PATCH] refactor: use return full text (#3) * feat: add completion context * feat: remove unnecessary completion context --- src/main.rs | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/src/main.rs b/src/main.rs index 3841ab0..5a02a69 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,10 +38,33 @@ struct FimParams { suffix: String, } +#[derive(Debug, Serialize)] +struct APIParams { + max_new_tokens: u32, + temperature: f32, + do_sample: bool, + top_p: f32, + stop: Vec, + return_full_text: bool, +} + +impl From for APIParams { + fn from(params: RequestParams) -> Self { + Self { + max_new_tokens: params.max_new_tokens, + temperature: params.temperature, + do_sample: params.do_sample, + top_p: params.top_p, + stop: vec![params.stop_token.clone()], + return_full_text: false, + } + } +} + #[derive(Serialize)] struct APIRequest { inputs: String, - parameters: RequestParams, + parameters: APIParams, } #[derive(Debug, Deserialize)] @@ -185,7 +208,7 @@ async fn request_completion( ) -> Result> { let mut req = http_client.post(model).json(&APIRequest { inputs: prompt, - parameters: request_params, + parameters: request_params.into(), }); if let Some(api_token) = api_token.clone() { @@ -200,15 +223,11 @@ async fn request_completion( } } -fn parse_generations( - generations: Vec, - prompt: &str, - stop_token: &str, -) -> Vec { +fn parse_generations(generations: Vec, stop_token: &str) -> Vec { generations .into_iter() .map(|g| Completion { - generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""), + generated_text: g.generated_text.replace(stop_token, ""), }) .collect() } @@ -243,7 +262,7 @@ impl Backend { ) .await?; - Ok(parse_generations(result, &prompt, &stop_token)) + Ok(parse_generations(result, &stop_token)) } }