fix: stop tokens issue (#9)

This commit is contained in:
Luc Georges 2023-09-07 16:29:44 +02:00 committed by GitHub
parent f5e6911932
commit c774ec74fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -24,7 +24,7 @@ struct RequestParams {
temperature: f32, temperature: f32,
do_sample: bool, do_sample: bool,
top_p: f32, top_p: f32,
stop_token: String, stop_tokens: Option<Vec<String>>,
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
@ -41,7 +41,8 @@ struct APIParams {
temperature: f32, temperature: f32,
do_sample: bool, do_sample: bool,
top_p: f32, top_p: f32,
stop: Vec<String>, #[serde(skip_serializing)]
stop: Option<Vec<String>>,
return_full_text: bool, return_full_text: bool,
} }
@ -52,7 +53,7 @@ impl From<RequestParams> for APIParams {
temperature: params.temperature, temperature: params.temperature,
do_sample: params.do_sample, do_sample: params.do_sample,
top_p: params.top_p, top_p: params.top_p,
stop: vec![params.stop_token.clone()], stop: params.stop_tokens,
return_full_text: false, return_full_text: false,
} }
} }
@ -123,6 +124,7 @@ struct CompletionParams {
fim: FimParams, fim: FimParams,
api_token: Option<String>, api_token: Option<String>,
model: String, model: String,
model_eos: String,
tokenizer_path: Option<String>, tokenizer_path: Option<String>,
context_window: usize, context_window: usize,
} }
@ -243,11 +245,11 @@ async fn request_completion(
} }
} }
/// Converts raw `Generation` values into `Completion` values.
///
/// Consumes `generations` and, for each element, builds a `Completion` whose
/// `generated_text` is the generation's text with every occurrence of the
/// marker string removed (`str::replace` strips all matches, not only a
/// trailing one).
///
/// This is a side-by-side diff hunk: the left half of each line is the
/// pre-commit text (marker parameter named `stop_token`), the right half is the
/// post-commit text (parameter renamed to `eos`). The body is otherwise the
/// same on both sides.
fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> { fn parse_generations(generations: Vec<Generation>, eos: &str) -> Vec<Completion> {
generations generations
.into_iter() .into_iter()
.map(|g| Completion { .map(|g| Completion {
generated_text: g.generated_text.replace(stop_token, ""), generated_text: g.generated_text.replace(eos, ""),
}) })
.collect() .collect()
} }
@ -352,7 +354,6 @@ impl Backend {
tokenizer, tokenizer,
params.context_window, params.context_window,
)?; )?;
let stop_token = params.request_params.stop_token.clone();
let result = request_completion( let result = request_completion(
&self.http_client, &self.http_client,
&params.model, &params.model,
@ -362,7 +363,7 @@ impl Backend {
) )
.await?; .await?;
Ok(parse_generations(result, &stop_token)) Ok(parse_generations(result, &params.model_eos))
} }
} }