upgrade ollama backend support to use the chat completions api

Gered 2024-11-10 18:52:20 -05:00
parent 29cec6445c
commit 8cde5ce43f
2 changed files with 28 additions and 14 deletions
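
At a high level, the change swaps Ollama's native /api/generate endpoint for its OpenAI-compatible /v1/chat/completions endpoint: the request body carries a messages array instead of a bare prompt, and the completion text comes back under choices[*].message.content. A rough sketch of that round trip is below; the reqwest usage, model name, and localhost URL are illustrative only, not taken from this commit.

// Illustrative round trip against Ollama's OpenAI-compatible endpoint.
// Assumes reqwest (blocking + json features) and serde_json; the model name
// and URL are placeholders, not values from this commit.
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "model": "qwen2.5-coder",                                // placeholder model
        "messages": [ { "role": "user", "content": "fn add(" } ],
        "stream": false
    });

    let resp: serde_json::Value = reqwest::blocking::Client::new()
        .post("http://localhost:11434/v1/chat/completions")     // default Ollama address
        .json(&body)
        .send()?
        .json()?;

    // The completion text now lives under choices[0].message.content rather
    // than the old /api/generate "response" field.
    println!("{}", resp["choices"][0]["message"]["content"]);
    Ok(())
}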

@@ -99,14 +99,26 @@ fn parse_llamacpp_text(text: &str) -> Result<Vec<Generation>> {
 }
 
 #[derive(Debug, Serialize, Deserialize)]
-struct OllamaGeneration {
-    response: String,
+struct OllamaGenerationChoiceMessage {
+    role: String,
+    content: String,
 }
 
-impl From<OllamaGeneration> for Generation {
-    fn from(value: OllamaGeneration) -> Self {
+#[derive(Debug, Serialize, Deserialize)]
+struct OllamaGenerationChoice {
+    index: i32,
+    message: OllamaGenerationChoiceMessage,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct OllamaGeneration {
+    choices: Vec<OllamaGenerationChoice>
+}
+
+impl From<OllamaGenerationChoice> for Generation {
+    fn from(value: OllamaGenerationChoice) -> Self {
         Generation {
-            generated_text: value.response,
+            generated_text: value.message.content,
         }
     }
 }
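
For reference, the new structs line up with a response shaped roughly like the sample below. This is a standalone sketch: the structs are copied from the hunk above so it compiles on its own, and the JSON payload is invented for illustration.

// Standalone sketch: deserializing a chat-completions style response with the
// structs introduced above. The JSON sample is invented for illustration.
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct OllamaGenerationChoiceMessage {
    role: String,
    content: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct OllamaGenerationChoice {
    index: i32,
    message: OllamaGenerationChoiceMessage,
}

#[derive(Debug, Serialize, Deserialize)]
struct OllamaGeneration {
    choices: Vec<OllamaGenerationChoice>,
}

fn main() -> serde_json::Result<()> {
    let body = r#"{
        "choices": [
            { "index": 0, "message": { "role": "assistant", "content": "a + b }" } }
        ]
    }"#;
    let generation: OllamaGeneration = serde_json::from_str(body)?;
    // Mirrors what parse_ollama_text does after this change: one Generation per choice.
    for choice in generation.choices {
        println!("{}", choice.message.content);
    }
    Ok(())
}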
@@ -124,7 +136,7 @@ fn build_ollama_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
 
 fn parse_ollama_text(text: &str) -> Result<Vec<Generation>> {
     match serde_json::from_str(text)? {
-        OllamaAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
+        OllamaAPIResponse::Generation(completion) => Ok(completion.choices.into_iter().map(|x| x.into()).collect()),
         OllamaAPIResponse::Error(err) => Err(Error::Ollama(err)),
     }
 }
@@ -227,7 +239,9 @@ pub(crate) fn build_body(
             request_body.insert("prompt".to_owned(), Value::String(prompt));
         }
         Backend::Ollama { .. } | Backend::OpenAi { .. } => {
-            request_body.insert("prompt".to_owned(), Value::String(prompt));
+            request_body.insert("messages".to_owned(), json!([
+                { "role": "user", "content": prompt }
+            ]));
             request_body.insert("model".to_owned(), Value::String(model));
             request_body.insert("stream".to_owned(), Value::Bool(false));
         }
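
The Ollama/OpenAI arm now produces a body with a messages array rather than a top-level prompt field. A minimal sketch of the resulting JSON, built here with a plain HashMap (the actual map type in build_body may differ) and placeholder prompt and model values:

// Minimal sketch of the request body built by the Ollama/OpenAI arm after this
// change. The HashMap, prompt, and model are stand-ins for illustration.
use serde_json::{json, Value};
use std::collections::HashMap;

fn main() {
    let prompt = "fn fibonacci(n: u32) -> u32 {".to_owned();   // placeholder prompt
    let model = "qwen2.5-coder".to_owned();                    // placeholder model

    let mut request_body: HashMap<String, Value> = HashMap::new();
    request_body.insert("messages".to_owned(), json!([
        { "role": "user", "content": prompt }
    ]));
    request_body.insert("model".to_owned(), Value::String(model));
    request_body.insert("stream".to_owned(), Value::Bool(false));

    // Roughly: {"messages":[{"role":"user","content":"..."}],"model":"...","stream":false}
    println!("{}", serde_json::to_string_pretty(&request_body).unwrap());
}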

@@ -441,19 +441,19 @@ fn build_url(backend: Backend, model: &str, disable_url_path_completion: bool) -> String {
            }
        }
        Backend::Ollama { mut url } => {
-            if url.ends_with("/api/generate") {
+            if url.ends_with("/v1/chat/completions") {
                 url
-            } else if url.ends_with("/api/") {
-                url.push_str("generate");
+            } else if url.ends_with("/v1/") {
+                url.push_str("chat/completions");
                 url
-            } else if url.ends_with("/api") {
-                url.push_str("/generate");
+            } else if url.ends_with("/v1") {
+                url.push_str("/chat/completions");
                 url
             } else if url.ends_with('/') {
-                url.push_str("api/generate");
+                url.push_str("v1/chat/completions");
                 url
             } else {
-                url.push_str("/api/generate");
+                url.push_str("/v1/chat/completions");
                 url
             }
         }
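
The URL handling mirrors the old /api/generate logic, just pointed at the new path, so the usual base URL spellings all end up at .../v1/chat/completions. A standalone sketch of that branch follows; the helper name and example base URLs are made up for illustration.

// Standalone sketch of the Ollama URL normalization above; the helper name and
// example base URLs are illustrative only.
fn ollama_chat_completions_url(mut url: String) -> String {
    if url.ends_with("/v1/chat/completions") {
        url
    } else if url.ends_with("/v1/") {
        url.push_str("chat/completions");
        url
    } else if url.ends_with("/v1") {
        url.push_str("/chat/completions");
        url
    } else if url.ends_with('/') {
        url.push_str("v1/chat/completions");
        url
    } else {
        url.push_str("/v1/chat/completions");
        url
    }
}

fn main() {
    for base in ["http://localhost:11434", "http://localhost:11434/", "http://localhost:11434/v1/"] {
        // Each of these normalizes to http://localhost:11434/v1/chat/completions.
        println!("{}", ollama_chat_completions_url(base.to_owned()));
    }
}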