diff --git a/server/images_test.go b/server/images_test.go index 0f63a19b..4c2a7cac 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool { if len(a.Prompts) != len(b.Prompts) { return false } - if len(a.CurrentImages) != len(b.CurrentImages) { - return false - } for i, v := range a.Prompts { - if v != b.Prompts[i] { + + if v.First != b.Prompts[i].First { return false } - } - for i, v := range a.CurrentImages { - if !bytes.Equal(v, b.CurrentImages[i]) { + + if v.Response != b.Prompts[i].Response { return false } + + if v.Prompt != b.Prompts[i].Prompt { + return false + } + + if v.System != b.Prompts[i].System { + return false + } + + if len(v.Images) != len(b.Prompts[i].Images) { + return false + } + + for j, img := range v.Images { + if img.ID != b.Prompts[i].Images[j].ID { + return false + } + + if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) { + return false + } + } } return a.LastSystem == b.LastSystem } diff --git a/server/routes_test.go b/server/routes_test.go index 9c53dc20..2a0308b8 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) { NumCtx: tt.numCtx, }, } - got, err := trimmedPrompt(context.Background(), tt.chat, m) + // TODO: add tests for trimming images + got, _, err := trimmedPrompt(context.Background(), tt.chat, m) if tt.wantErr != "" { if err == nil { t.Errorf("ChatPrompt() expected error, got nil")