refactor: use return full text (#3)

* feat: add completion context

* feat: remove unnecessary completion context
This commit is contained in:
Luc Georges 2023-08-26 12:34:17 +02:00 committed by GitHub
parent b3b7bb2b4e
commit 6f455eca18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -38,10 +38,33 @@ struct FimParams {
suffix: String,
}
#[derive(Debug, Serialize)]
struct APIParams {
max_new_tokens: u32,
temperature: f32,
do_sample: bool,
top_p: f32,
stop: Vec<String>,
return_full_text: bool,
}
impl From<RequestParams> for APIParams {
fn from(params: RequestParams) -> Self {
Self {
max_new_tokens: params.max_new_tokens,
temperature: params.temperature,
do_sample: params.do_sample,
top_p: params.top_p,
stop: vec![params.stop_token.clone()],
return_full_text: false,
}
}
}
#[derive(Serialize)]
struct APIRequest {
inputs: String,
parameters: RequestParams,
parameters: APIParams,
}
#[derive(Debug, Deserialize)]
@ -185,7 +208,7 @@ async fn request_completion(
) -> Result<Vec<Generation>> {
let mut req = http_client.post(model).json(&APIRequest {
inputs: prompt,
parameters: request_params,
parameters: request_params.into(),
});
if let Some(api_token) = api_token.clone() {
@ -200,15 +223,11 @@ async fn request_completion(
}
}
fn parse_generations(
generations: Vec<Generation>,
prompt: &str,
stop_token: &str,
) -> Vec<Completion> {
fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> {
generations
.into_iter()
.map(|g| Completion {
generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""),
generated_text: g.generated_text.replace(stop_token, ""),
})
.collect()
}
@ -243,7 +262,7 @@ impl Backend {
)
.await?;
Ok(parse_generations(result, &prompt, &stop_token))
Ok(parse_generations(result, &stop_token))
}
}