fix: stop tokens issue (#9)
This commit is contained in:
parent
f5e6911932
commit
c774ec74fd
|
@ -24,7 +24,7 @@ struct RequestParams {
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
top_p: f32,
|
top_p: f32,
|
||||||
stop_token: String,
|
stop_tokens: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
@ -41,7 +41,8 @@ struct APIParams {
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
top_p: f32,
|
top_p: f32,
|
||||||
stop: Vec<String>,
|
#[serde(skip_serializing)]
|
||||||
|
stop: Option<Vec<String>>,
|
||||||
return_full_text: bool,
|
return_full_text: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ impl From<RequestParams> for APIParams {
|
||||||
temperature: params.temperature,
|
temperature: params.temperature,
|
||||||
do_sample: params.do_sample,
|
do_sample: params.do_sample,
|
||||||
top_p: params.top_p,
|
top_p: params.top_p,
|
||||||
stop: vec![params.stop_token.clone()],
|
stop: params.stop_tokens,
|
||||||
return_full_text: false,
|
return_full_text: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,6 +124,7 @@ struct CompletionParams {
|
||||||
fim: FimParams,
|
fim: FimParams,
|
||||||
api_token: Option<String>,
|
api_token: Option<String>,
|
||||||
model: String,
|
model: String,
|
||||||
|
model_eos: String,
|
||||||
tokenizer_path: Option<String>,
|
tokenizer_path: Option<String>,
|
||||||
context_window: usize,
|
context_window: usize,
|
||||||
}
|
}
|
||||||
|
@ -243,11 +245,11 @@ async fn request_completion(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> {
|
fn parse_generations(generations: Vec<Generation>, eos: &str) -> Vec<Completion> {
|
||||||
generations
|
generations
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|g| Completion {
|
.map(|g| Completion {
|
||||||
generated_text: g.generated_text.replace(stop_token, ""),
|
generated_text: g.generated_text.replace(eos, ""),
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
@ -352,7 +354,6 @@ impl Backend {
|
||||||
tokenizer,
|
tokenizer,
|
||||||
params.context_window,
|
params.context_window,
|
||||||
)?;
|
)?;
|
||||||
let stop_token = params.request_params.stop_token.clone();
|
|
||||||
let result = request_completion(
|
let result = request_completion(
|
||||||
&self.http_client,
|
&self.http_client,
|
||||||
¶ms.model,
|
¶ms.model,
|
||||||
|
@ -362,7 +363,7 @@ impl Backend {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(parse_generations(result, &stop_token))
|
Ok(parse_generations(result, ¶ms.model_eos))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue