diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 9ba6505..b5970b6 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -24,7 +24,7 @@ struct RequestParams {
     temperature: f32,
     do_sample: bool,
     top_p: f32,
-    stop_token: String,
+    stop_tokens: Option<Vec<String>>,
 }
 
 #[derive(Debug, Deserialize, Serialize)]
@@ -41,7 +41,8 @@ struct APIParams {
     temperature: f32,
     do_sample: bool,
     top_p: f32,
-    stop: Vec<String>,
+    #[serde(skip_serializing)]
+    stop: Option<Vec<String>>,
     return_full_text: bool,
 }
 
@@ -52,7 +53,7 @@ impl From<RequestParams> for APIParams {
             temperature: params.temperature,
             do_sample: params.do_sample,
             top_p: params.top_p,
-            stop: vec![params.stop_token.clone()],
+            stop: params.stop_tokens,
             return_full_text: false,
         }
     }
@@ -123,6 +124,7 @@ struct CompletionParams {
     fim: FimParams,
     api_token: Option<String>,
     model: String,
+    model_eos: String,
     tokenizer_path: Option<String>,
     context_window: usize,
 }
@@ -243,11 +245,11 @@ async fn request_completion(
     }
 }
 
-fn parse_generations(generations: Vec<Generation>, stop_token: &str) -> Vec<Completion> {
+fn parse_generations(generations: Vec<Generation>, eos: &str) -> Vec<Completion> {
     generations
         .into_iter()
         .map(|g| Completion {
-            generated_text: g.generated_text.replace(stop_token, ""),
+            generated_text: g.generated_text.replace(eos, ""),
         })
         .collect()
 }
@@ -352,7 +354,6 @@ impl Backend {
             tokenizer,
             params.context_window,
         )?;
-        let stop_token = params.request_params.stop_token.clone();
         let result = request_completion(
             &self.http_client,
             &params.model,
@@ -362,7 +363,7 @@
         )
         .await?;
 
-        Ok(parse_generations(result, &stop_token))
+        Ok(parse_generations(result, &params.model_eos))
     }
 }
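
Note (not part of the patch): the new `APIParams.stop` field is annotated `#[serde(skip_serializing)]`, so it is never included in the JSON body sent to the inference API, which suggests the configured stop tokens are kept client-side; EOS removal from generations is instead driven by the new `model_eos` parameter in `parse_generations`. A minimal, self-contained sketch of that serde behavior, assuming `serde` (with the `derive` feature) and `serde_json` as dependencies and using illustrative field values:

use serde::Serialize;

#[derive(Debug, Serialize)]
struct APIParams {
    max_new_tokens: u32,
    temperature: f32,
    do_sample: bool,
    top_p: f32,
    // Skipped during serialization: `stop` never reaches the request body.
    #[serde(skip_serializing)]
    stop: Option<Vec<String>>,
    return_full_text: bool,
}

fn main() {
    let params = APIParams {
        max_new_tokens: 60,
        temperature: 0.2,
        do_sample: true,
        top_p: 0.95,
        stop: Some(vec!["<|endoftext|>".to_string()]),
        return_full_text: false,
    };
    // Prints max_new_tokens, temperature, do_sample, top_p, and
    // return_full_text only; `stop` is absent even though it is `Some(...)`.
    println!("{}", serde_json::to_string(&params).unwrap());
}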