diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go
index 782fd382..f7e19a7b 100644
--- a/llm/dyn_ext_server.go
+++ b/llm/dyn_ext_server.go
@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
 func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	var imageData []ImageData
+
 	if len(predict.Images) > 0 {
-		for cnt, i := range predict.Images {
-			imageData = append(imageData, ImageData{Data: i, ID: cnt})
-		}
+		slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
 	}
-	slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
 
 	request := map[string]any{
 		"prompt":            predict.Prompt,
@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
 		"penalize_nl":       predict.Options.PenalizeNewline,
 		"seed":              predict.Options.Seed,
 		"stop":              predict.Options.Stop,
-		"image_data":        imageData,
+		"image_data":        predict.Images,
 		"cache_prompt":      true,
 	}
 
diff --git a/llm/llama.go b/llm/llama.go
index 80b4f75b..a5d2036a 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -62,7 +62,7 @@ const maxRetries = 3
 type PredictOpts struct {
 	Prompt  string
 	Format  string
-	Images  []api.ImageData
+	Images  []ImageData
 	Options api.Options
 }
 
diff --git a/server/images.go b/server/images.go
index 503dd8e2..6f59d72d 100644
--- a/server/images.go
+++ b/server/images.go
@@ -63,6 +63,7 @@ type PromptVars struct {
 	Prompt   string
 	Response string
 	First    bool
+	Images   []llm.ImageData
 }
 
 // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -147,15 +148,13 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
 }
 
 type ChatHistory struct {
-	Prompts       []PromptVars
-	CurrentImages []api.ImageData
-	LastSystem    string
+	Prompts    []PromptVars
+	LastSystem string
 }
 
 // ChatPrompts returns a list of formatted chat prompts from a list of messages
 func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	// build the prompt from the list of messages
-	var currentImages []api.ImageData
 	lastSystem := m.System
 	currentVars := PromptVars{
 		First: true,
@@ -163,6 +162,7 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	prompts := []PromptVars{}
+	var images []llm.ImageData
 
 	for _, msg := range msgs {
 		switch strings.ToLower(msg.Role) {
@@ -179,8 +179,18 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 				prompts = append(prompts, currentVars)
 				currentVars = PromptVars{}
 			}
+
 			currentVars.Prompt = msg.Content
-			currentImages = msg.Images
+			for i := range msg.Images {
+				id := len(images) + i
+				currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
+				currentVars.Images = append(currentVars.Images, llm.ImageData{
+					ID:   id,
+					Data: msg.Images[i],
+				})
+			}
+
+			images = append(images, currentVars.Images...)
 		case "assistant":
 			currentVars.Response = msg.Content
 			prompts = append(prompts, currentVars)
@@ -196,9 +206,8 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
 	}
 
 	return &ChatHistory{
-		Prompts:       prompts,
-		CurrentImages: currentImages,
-		LastSystem:    lastSystem,
+		Prompts:    prompts,
+		LastSystem: lastSystem,
 	}, nil
 }
 
diff --git a/server/images_test.go b/server/images_test.go
index 0f63a19b..4c2a7cac 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
 	if len(a.Prompts) != len(b.Prompts) {
 		return false
 	}
-	if len(a.CurrentImages) != len(b.CurrentImages) {
-		return false
-	}
 	for i, v := range a.Prompts {
-		if v != b.Prompts[i] {
+
+		if v.First != b.Prompts[i].First {
 			return false
 		}
-	}
-	for i, v := range a.CurrentImages {
-		if !bytes.Equal(v, b.CurrentImages[i]) {
+
+		if v.Response != b.Prompts[i].Response {
 			return false
 		}
+
+		if v.Prompt != b.Prompts[i].Prompt {
+			return false
+		}
+
+		if v.System != b.Prompts[i].System {
+			return false
+		}
+
+		if len(v.Images) != len(b.Prompts[i].Images) {
+			return false
+		}
+
+		for j, img := range v.Images {
+			if img.ID != b.Prompts[i].Images[j].ID {
+				return false
+			}
+
+			if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
+				return false
+			}
+		}
 	}
 	return a.LastSystem == b.LastSystem
 }
diff --git a/server/routes.go b/server/routes.go
index 1a4726c1..7d1f9dfb 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) {
 			promptVars.System = model.System
 		}
 
+		for i := range req.Images {
+			promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
+		}
+
 		p, err := model.PreResponsePrompt(promptVars)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -308,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
 			ch <- resp
 		}
 
+		var images []llm.ImageData
+		for i := range req.Images {
+			images = append(images, llm.ImageData{
+				ID:   i,
+				Data: req.Images[i],
+			})
+		}
+
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  req.Images,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1139,7 +1151,8 @@ func ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+
+	prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1182,7 +1195,7 @@ func ChatHandler(c *gin.Context) {
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
-			Images:  chat.CurrentImages,
+			Images:  images,
 			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1229,34 +1242,47 @@ type promptInfo struct {
 
 // trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
 // while preserving the most recent system message.
-func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
 	if len(chat.Prompts) == 0 {
-		return "", nil
+		return "", nil, nil
 	}
 
 	var promptsToAdd []promptInfo
 	var totalTokenLength int
 	var systemPromptIncluded bool
+	var images []llm.ImageData
 
 	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
 	for i := len(chat.Prompts) - 1; i >= 0; i-- {
-		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		prompt := chat.Prompts[i]
+		promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 
 		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
			break // reached max context length, stop adding more prompts
 		}
 
+		for j := range prompt.Images {
+			if totalTokenLength+768 > loaded.NumCtx {
+				// this decreases the token length but overestimating is fine
+				prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
+				continue
+			}
+
+			totalTokenLength += 768
+			images = append(images, prompt.Images[j])
+		}
+
 		totalTokenLength += len(encodedTokens)
-		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
-		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+		systemPromptIncluded = systemPromptIncluded || prompt.System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
 	}
 
 	// ensure the system prompt is included, if not already
@@ -1264,7 +1290,7 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 		var err error
 		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
 
@@ -1275,11 +1301,12 @@ func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string
 	for i, prompt := range promptsToAdd {
 		promptText, err := promptString(model, prompt.vars, i == 0)
 		if err != nil {
-			return "", err
+			return "", nil, err
 		}
 		result = promptText + result
 	}
-	return result, nil
+
+	return result, images, nil
 }
 
 // promptString applies the model template to the prompt
diff --git a/server/routes_test.go b/server/routes_test.go
index 9c53dc20..2a0308b8 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
 				NumCtx: tt.numCtx,
 			},
 		}
-		got, err := trimmedPrompt(context.Background(), tt.chat, m)
+		// TODO: add tests for trimming images
+		got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
 		if tt.wantErr != "" {
 			if err == nil {
 				t.Errorf("ChatPrompt() expected error, got nil")
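
Reviewer note: the ChatPrompts change assigns each image a conversation-global ID (id := len(images) + i) and appends a matching " [img-N]" tag to the prompt text, so tags stay unique across messages. A minimal standalone sketch of that scheme, under assumed names (tagImages and the local ImageData type are illustrative, not part of this PR):

package main

import "fmt"

// ImageData mirrors llm.ImageData from the PR: raw image bytes plus the ID
// that the " [img-N]" prompt tags refer to.
type ImageData struct {
	ID   int
	Data []byte
}

// tagImages (hypothetical helper) appends one " [img-N]" tag per image to the
// prompt and assigns IDs that continue from images seen in earlier messages,
// matching id := len(images) + i in ChatPrompts.
func tagImages(prompt string, imgs [][]byte, offset int) (string, []ImageData) {
	var tagged []ImageData
	for i, data := range imgs {
		id := offset + i
		prompt += fmt.Sprintf(" [img-%d]", id)
		tagged = append(tagged, ImageData{ID: id, Data: data})
	}
	return prompt, tagged
}

func main() {
	// First user message carries one image, the second carries two.
	p1, imgs1 := tagImages("describe this", [][]byte{{0x01}}, 0)
	p2, imgs2 := tagImages("and compare with these", [][]byte{{0x02}, {0x03}}, len(imgs1))

	fmt.Println(p1) // describe this [img-0]
	fmt.Println(p2) // and compare with these [img-1] [img-2]
	fmt.Println(len(append(imgs1, imgs2...))) // 3 images, IDs 0..2
}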
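Reviewer note: trimmedPrompt charges a flat 768 tokens per image (a deliberate overestimate, which only makes trimming more conservative) and, when an image no longer fits in NumCtx, strips its " [img-N]" tag so the prompt never references an image that is not sent to the runner. A standalone sketch of that budgeting rule, with fitImages as an assumed helper name:

package main

import (
	"fmt"
	"strings"
)

// ImageData mirrors llm.ImageData from the PR.
type ImageData struct {
	ID   int
	Data []byte
}

// Flat per-image token charge used by trimmedPrompt.
const imageTokens = 768

// fitImages (hypothetical helper) applies the PR's trimming rule in
// isolation: keep an image while the running token count stays within
// numCtx; otherwise drop it and remove its " [img-N]" tag from the prompt.
func fitImages(prompt string, imgs []ImageData, used, numCtx int) (string, []ImageData, int) {
	var kept []ImageData
	for _, img := range imgs {
		if used+imageTokens > numCtx {
			prompt = strings.ReplaceAll(prompt, fmt.Sprintf(" [img-%d]", img.ID), "")
			continue
		}
		used += imageTokens
		kept = append(kept, img)
	}
	return prompt, kept, used
}

func main() {
	prompt := "compare [img-0] with [img-1]"
	imgs := []ImageData{{ID: 0}, {ID: 1}}

	// With a 1000-token context there is room for only one 768-token image,
	// so the second tag is stripped and only image 0 is kept.
	p, kept, used := fitImages(prompt, imgs, 0, 1000)
	fmt.Println(p)         // compare [img-0] with
	fmt.Println(len(kept)) // 1
	fmt.Println(used)      // 768
}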