From 51b2fd299cd568093ce796aef3e7e37ae656b02a Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Fri, 19 Jul 2024 11:19:20 -0700 Subject: [PATCH] adjust openai chat msg processing (#5729) --- openai/openai.go | 7 +++---- openai/openai_test.go | 8 ++++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 01864e48..93b63296 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -351,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { case string: messages = append(messages, api.Message{Role: msg.Role, Content: content}) case []any: - message := api.Message{Role: msg.Role} for _, c := range content { data, ok := c.(map[string]any) if !ok { @@ -363,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { if !ok { return nil, fmt.Errorf("invalid message format") } - message.Content = text + messages = append(messages, api.Message{Role: msg.Role, Content: text}) case "image_url": var url string if urlMap, ok := data["image_url"].(map[string]any); ok { @@ -395,12 +394,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { if err != nil { return nil, fmt.Errorf("invalid message format") } - message.Images = append(message.Images, img) + + messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}}) default: return nil, fmt.Errorf("invalid message format") } } - messages = append(messages, message) default: if msg.ToolCalls == nil { return nil, fmt.Errorf("invalid message content type: %T", content) diff --git a/openai/openai_test.go b/openai/openai_test.go index 046ee69c..ad056e6d 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -161,8 +161,12 @@ func TestMiddlewareRequests(t *testing.T) { img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) - if !bytes.Equal(chatReq.Messages[0].Images[0], img) { - t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0]) + if chatReq.Messages[1].Role != "user" { + t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role) + } + + if !bytes.Equal(chatReq.Messages[1].Images[0], img) { + t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0]) } }, },