refactor: use return full text (#3)
* feat: add completion context * feat: remove unnecessary completion context
This commit is contained in:
parent
b3b7bb2b4e
commit
6f455eca18
37
src/main.rs
37
src/main.rs
|
@ -38,10 +38,33 @@ struct FimParams {
|
||||||
suffix: String,
|
suffix: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct APIParams {
|
||||||
|
max_new_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
|
do_sample: bool,
|
||||||
|
top_p: f32,
|
||||||
|
stop: Vec<String>,
|
||||||
|
return_full_text: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RequestParams> for APIParams {
|
||||||
|
fn from(params: RequestParams) -> Self {
|
||||||
|
Self {
|
||||||
|
max_new_tokens: params.max_new_tokens,
|
||||||
|
temperature: params.temperature,
|
||||||
|
do_sample: params.do_sample,
|
||||||
|
top_p: params.top_p,
|
||||||
|
stop: vec![params.stop_token.clone()],
|
||||||
|
return_full_text: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct APIRequest {
|
struct APIRequest {
|
||||||
inputs: String,
|
inputs: String,
|
||||||
parameters: RequestParams,
|
parameters: APIParams,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
@ -185,7 +208,7 @@ async fn request_completion(
|
||||||
) -> Result<Vec<Generation>> {
|
) -> Result<Vec<Generation>> {
|
||||||
let mut req = http_client.post(model).json(&APIRequest {
|
let mut req = http_client.post(model).json(&APIRequest {
|
||||||
inputs: prompt,
|
inputs: prompt,
|
||||||
parameters: request_params,
|
parameters: request_params.into(),
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(api_token) = api_token.clone() {
|
if let Some(api_token) = api_token.clone() {
|
||||||
|
@ -200,15 +223,11 @@ async fn request_completion(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_generations(
|
fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> {
|
||||||
generations: Vec<Generation>,
|
|
||||||
prompt: &str,
|
|
||||||
stop_token: &str,
|
|
||||||
) -> Vec<Completion> {
|
|
||||||
generations
|
generations
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|g| Completion {
|
.map(|g| Completion {
|
||||||
generated_text: g.generated_text.replace(prompt, "").replace(stop_token, ""),
|
generated_text: g.generated_text.replace(stop_token, ""),
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
@ -243,7 +262,7 @@ impl Backend {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(parse_generations(result, &prompt, &stop_token))
|
Ok(parse_generations(result, &stop_token))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue