diff --git a/server/images.go b/server/images.go index 503dd8e2..26b59e0d 100644 --- a/server/images.go +++ b/server/images.go @@ -179,8 +179,13 @@ func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { prompts = append(prompts, currentVars) currentVars = PromptVars{} } + currentVars.Prompt = msg.Content - currentImages = msg.Images + for i := range msg.Images { + currentVars.Prompt += fmt.Sprintf(" [img-%d]", len(currentImages)+i) + } + + currentImages = append(currentImages, msg.Images...) case "assistant": currentVars.Response = msg.Content prompts = append(prompts, currentVars) diff --git a/server/routes.go b/server/routes.go index 01a898a8..4dc1be5b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -244,6 +244,10 @@ func GenerateHandler(c *gin.Context) { promptVars.System = model.System } + for i := range req.Images { + promptVars.Prompt += fmt.Sprintf(" [img-%d]", i) + } + p, err := model.PreResponsePrompt(promptVars) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})