fix: stop tokens issue (#9)

This commit is contained in:
Luc Georges 2023-09-07 16:29:44 +02:00 committed by GitHub
parent f5e6911932
commit c774ec74fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -24,7 +24,7 @@ struct RequestParams {
temperature: f32,
do_sample: bool,
top_p: f32,
stop_token: String,
stop_tokens: Option<Vec<String>>,
}
#[derive(Debug, Deserialize, Serialize)]
@ -41,7 +41,8 @@ struct APIParams {
temperature: f32,
do_sample: bool,
top_p: f32,
stop: Vec<String>,
#[serde(skip_serializing)]
stop: Option<Vec<String>>,
return_full_text: bool,
}
@ -52,7 +53,7 @@ impl From<RequestParams> for APIParams {
temperature: params.temperature,
do_sample: params.do_sample,
top_p: params.top_p,
stop: vec![params.stop_token.clone()],
stop: params.stop_tokens,
return_full_text: false,
}
}
@ -123,6 +124,7 @@ struct CompletionParams {
fim: FimParams,
api_token: Option<String>,
model: String,
model_eos: String,
tokenizer_path: Option<String>,
context_window: usize,
}
@ -243,11 +245,11 @@ async fn request_completion(
}
}
// NOTE(review): this hunk shows the pre-change and post-change lines back to back
// (the diff's +/- markers are missing), so the signature and the `replace` call
// each appear twice: first using `stop_token`, then using `eos`.
//
// Turns every `Generation` into a `Completion`, consuming the input vector.
// The only transformation applied is removing all occurrences of the given
// token string from `generated_text`; the results are collected in input order.
fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> {
fn parse_generations(generations: Vec<Generation>, eos: &str) -> Vec<Completion> {
generations
.into_iter()
.map(|g| Completion {
// Pre-change form: the caller passed `params.request_params.stop_token` here.
generated_text: g.generated_text.replace(stop_token, ""),
// Post-change form: the caller passes `params.model_eos` instead.
// `replace` with an empty replacement deletes every match, not just a trailing one.
generated_text: g.generated_text.replace(eos, ""),
})
.collect()
}
@ -352,7 +354,6 @@ impl Backend {
tokenizer,
params.context_window,
)?;
let stop_token = params.request_params.stop_token.clone();
let result = request_completion(
&self.http_client,
&params.model,
@ -362,7 +363,7 @@ impl Backend {
)
.await?;
Ok(parse_generations(result, &stop_token))
Ok(parse_generations(result, &params.model_eos))
}
}