refactor: use return full text (#3)
* feat: add completion context * feat: remove unnecessary completion context
This commit is contained in:
parent
b3b7bb2b4e
commit
6f455eca18
37
src/main.rs
37
src/main.rs
|
@ -38,10 +38,33 @@ struct FimParams {
|
|||
suffix: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct APIParams {
|
||||
max_new_tokens: u32,
|
||||
temperature: f32,
|
||||
do_sample: bool,
|
||||
top_p: f32,
|
||||
stop: Vec<String>,
|
||||
return_full_text: bool,
|
||||
}
|
||||
|
||||
impl From<RequestParams> for APIParams {
|
||||
fn from(params: RequestParams) -> Self {
|
||||
Self {
|
||||
max_new_tokens: params.max_new_tokens,
|
||||
temperature: params.temperature,
|
||||
do_sample: params.do_sample,
|
||||
top_p: params.top_p,
|
||||
stop: vec![params.stop_token.clone()],
|
||||
return_full_text: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct APIRequest {
|
||||
inputs: String,
|
||||
parameters: RequestParams,
|
||||
parameters: APIParams,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
@ -185,7 +208,7 @@ async fn request_completion(
|
|||
) -> Result<Vec<Generation>> {
|
||||
let mut req = http_client.post(model).json(&APIRequest {
|
||||
inputs: prompt,
|
||||
parameters: request_params,
|
||||
parameters: request_params.into(),
|
||||
});
|
||||
|
||||
if let Some(api_token) = api_token.clone() {
|
||||
|
@ -200,15 +223,11 @@ async fn request_completion(
|
|||
}
|
||||
}
|
||||
|
||||
fn parse_generations(
|
||||
generations: Vec<Generation>,
|
||||
prompt: &str,
|
||||
stop_token: &str,
|
||||
) -> Vec<Completion> {
|
||||
fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> {
|
||||
generations
|
||||
.into_iter()
|
||||
.map(|g| Completion {
|
||||
generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""),
|
||||
generated_text: g.generated_text.replace(stop_token, ""),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
@ -243,7 +262,7 @@ impl Backend {
|
|||
)
|
||||
.await?;
|
||||
|
||||
Ok(parse_generations(result, &prompt, &stop_token))
|
||||
Ok(parse_generations(result, &stop_token))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue