diff --git a/server/prompt.go b/server/prompt.go index 6b684963..88da5b6b 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -121,13 +121,15 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str p = prompt{} } - p.Prompt = msg.Content - + var sb strings.Builder for range msg.Images { - p.Prompt += fmt.Sprintf(" [img-%d]", imgId) + fmt.Fprintf(&sb, "[img-%d] ", imgId) p.images = append(p.images, imgId) imgId += 1 } + + sb.WriteString(msg.Content) + p.Prompt = sb.String() case "assistant": if p.Response != "" { prompts = append(prompts, p) diff --git a/server/prompt_test.go b/server/prompt_test.go index 75c02d7b..500ee522 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -155,7 +155,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, }, window: 1024, - want: "You are a Wizard. Hello [img-0]", + want: "You are a Wizard. [img-0] Hello", }, { name: "images truncated", @@ -165,7 +165,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, }, window: 1024, - want: "You are a Wizard. Hello [img-1]", + want: "You are a Wizard. [img-0] [img-1] Hello", }, { name: "empty list", @@ -198,7 +198,7 @@ func TestChatPrompt(t *testing.T) { } if got != tc.want { - t.Errorf("got = %v, want %v", got, tc.want) + t.Errorf("got: %q, want: %q", got, tc.want) } }) } diff --git a/server/routes.go b/server/routes.go index dd14d4f8..cd53d103 100644 --- a/server/routes.go +++ b/server/routes.go @@ -250,6 +250,19 @@ func GenerateHandler(c *gin.Context) { slog.Debug("generate handler", "system", req.System) var sb strings.Builder + for i := range req.Images { + fmt.Fprintf(&sb, "[img-%d] ", i) + } + + sb.WriteString(req.Prompt) + + p, err := Prompt(req.Template, req.System, sb.String(), "", true) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + sb.Reset() if req.Context != nil { prev, err := loaded.runner.Decode(c.Request.Context(), req.Context) if err != nil { @@ -260,18 +273,6 @@ func GenerateHandler(c *gin.Context) { sb.WriteString(prev) } - // write image tags - // TODO: limit the number of images to fit in the context similar to the chat endpoint - for i := range req.Images { - req.Prompt += fmt.Sprintf(" [img-%d]", i) - } - - p, err := Prompt(req.Template, req.System, req.Prompt, "", true) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - sb.WriteString(p) prompt = sb.String()