refactor: use return full text (#3)

* feat: add completion context

* feat: remove unnecessary completion context
This commit is contained in:
Luc Georges 2023-08-26 12:34:17 +02:00 committed by GitHub
parent b3b7bb2b4e
commit 6f455eca18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -38,10 +38,33 @@ struct FimParams {
suffix: String, suffix: String,
} }
/// Generation parameters in the shape that is serialized into the
/// `parameters` field of `APIRequest`.
///
/// Values are produced from a `RequestParams` via the
/// `From<RequestParams>` conversion rather than being filled in by hand.
#[derive(Debug, Serialize)]
struct APIParams {
/// Copied verbatim from `RequestParams::max_new_tokens`.
/// NOTE(review): presumably the upper bound on generated tokens — confirm against the backend API.
max_new_tokens: u32,
/// Copied verbatim from `RequestParams::temperature`.
temperature: f32,
/// Copied verbatim from `RequestParams::do_sample`.
do_sample: bool,
/// Copied verbatim from `RequestParams::top_p`.
top_p: f32,
/// Stop sequences; the conversion fills this with exactly one entry,
/// the request's `stop_token`.
stop: Vec<String>,
/// Always set to `false` by the `From<RequestParams>` conversion.
/// NOTE(review): presumably asks the backend to return only the generated
/// continuation instead of prompt + continuation — confirm against the API docs.
return_full_text: bool,
}
/// Converts caller-supplied `RequestParams` into the `APIParams` payload.
///
/// `max_new_tokens`, `temperature`, `do_sample` and `top_p` are carried over
/// unchanged, the single `stop_token` becomes a one-element `stop` list, and
/// `return_full_text` is always `false`.
impl From<RequestParams> for APIParams {
    /// Consumes `params` and builds the corresponding `APIParams`.
    fn from(params: RequestParams) -> Self {
        Self {
            max_new_tokens: params.max_new_tokens,
            temperature: params.temperature,
            do_sample: params.do_sample,
            top_p: params.top_p,
            // `params` is owned by this function, so the stop token is moved
            // into the list instead of being cloned and then dropped.
            stop: vec![params.stop_token],
            // NOTE(review): presumably set so the response holds only the
            // completion and not the echoed prompt — confirm against the API docs.
            return_full_text: false,
        }
    }
}
#[derive(Serialize)] #[derive(Serialize)]
struct APIRequest { struct APIRequest {
inputs: String, inputs: String,
parameters: RequestParams, parameters: APIParams,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -185,7 +208,7 @@ async fn request_completion(
) -> Result<Vec<Generation>> { ) -> Result<Vec<Generation>> {
let mut req = http_client.post(model).json(&APIRequest { let mut req = http_client.post(model).json(&APIRequest {
inputs: prompt, inputs: prompt,
parameters: request_params, parameters: request_params.into(),
}); });
if let Some(api_token) = api_token.clone() { if let Some(api_token) = api_token.clone() {
@ -200,15 +223,11 @@ async fn request_completion(
} }
} }
fn parse_generations( fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> {
generations: Vec<Generation>,
prompt: &str,
stop_token: &str,
) -> Vec<Completion> {
generations generations
.into_iter() .into_iter()
.map(|g| Completion { .map(|g| Completion {
generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""), generated_text: g.generated_text.replace(stop_token, ""),
}) })
.collect() .collect()
} }
@ -243,7 +262,7 @@ impl Backend {
) )
.await?; .await?;
Ok(parse_generations(result, &prompt, &stop_token)) Ok(parse_generations(result, &stop_token))
} }
} }