diff --git a/api/types.go b/api/types.go
index 340c91c4..2217f02e 100644
--- a/api/types.go
+++ b/api/types.go
@@ -57,8 +57,9 @@ type ChatRequest struct {
 }
 
 type Message struct {
-	Role    string `json:"role"` // one of ["system", "user", "assistant"]
-	Content string `json:"content"`
+	Role    string      `json:"role"` // one of ["system", "user", "assistant"]
+	Content string      `json:"content"`
+	Images  []ImageData `json:"images,omitempty"` // optional; left out of the JSON encoding when empty
 }
 
 type ChatResponse struct {
diff --git a/server/images.go b/server/images.go
index 8735f4b4..9442bd74 100644
--- a/server/images.go
+++ b/server/images.go
@@ -86,9 +86,12 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
 	return prompt.String(), nil
 }
 
-func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
+// ChatPrompt builds a single prompt string from msgs and also returns the
+// images carried by the last message with the "user" role.
+func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
 	// build the prompt from the list of messages
 	var prompt strings.Builder
+	var currentImages []api.ImageData
 	currentVars := PromptVars{
 		First: true,
 	}
@@ -108,35 +111,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
 		case "system":
 			if currentVars.System != "" {
 				if err := writePrompt(); err != nil {
-					return "", err
+					return "", nil, err
 				}
 			}
 			currentVars.System = msg.Content
 		case "user":
 			if currentVars.Prompt != "" {
 				if err := writePrompt(); err != nil {
-					return "", err
+					return "", nil, err
 				}
 			}
 			currentVars.Prompt = msg.Content
+			currentImages = msg.Images // overwrites: images from earlier user messages are not kept
 		case "assistant":
 			currentVars.Response = msg.Content
 			if err := writePrompt(); err != nil {
-				return "", err
+				return "", nil, err
 			}
 		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
 
 	// Append the last set of vars if they are non-empty
 	if currentVars.Prompt != "" || currentVars.System != "" {
 		if err := writePrompt(); err != nil {
-			return "", err
+			return "", nil, err
 		}
 	}
 
-	return prompt.String(), nil
+	return prompt.String(), currentImages, nil
 }
 
 type ManifestV2 struct {
diff --git a/server/routes.go b/server/routes.go
index 04e32e2f..6df7d2e4 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) {
 
 	checkpointLoaded := time.Now()
 
-	prompt, err := model.ChatPrompt(req.Messages)
+	prompt, images, err := model.ChatPrompt(req.Messages)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
@@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) {
 			Format:           req.Format,
 			CheckpointStart:  checkpointStart,
 			CheckpointLoaded: checkpointLoaded,
+			Images:           images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}