From 1f9078d6aef1ce5de4efdfcc12b56405d8c1fa0d Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 12 Feb 2024 11:16:20 -0800 Subject: [PATCH] Check image filetype in api handlers (#2467) --- cmd/interactive.go | 2 +- server/routes.go | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/cmd/interactive.go b/cmd/interactive.go index fe450cf9..c9836372 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -625,7 +625,7 @@ func getImageData(filePath string) ([]byte, error) { } contentType := http.DetectContentType(buf) - allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"} + allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"} if !slices.Contains(allowedTypes, contentType) { return nil, fmt.Errorf("invalid image type: %s", contentType) } diff --git a/server/routes.go b/server/routes.go index 3da62a07..9abaea42 100644 --- a/server/routes.go +++ b/server/routes.go @@ -22,6 +22,7 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "golang.org/x/exp/slices" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/gpu" @@ -136,6 +137,12 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options return opts, nil } +func isSupportedImageType(image []byte) bool { + contentType := http.DetectContentType(image) + allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"} + return slices.Contains(allowedTypes, contentType) +} + func GenerateHandler(c *gin.Context) { loaded.mu.Lock() defer loaded.mu.Unlock() @@ -166,6 +173,13 @@ func GenerateHandler(c *gin.Context) { return } + for _, img := range req.Images { + if !isSupportedImageType(img) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) + return + } + } + model, err := GetModel(req.Model) if err != nil { var pErr *fs.PathError @@ -1103,6 +1117,15 @@ func ChatHandler(c *gin.Context) { return } + for _, msg := range req.Messages { + for _, img := range msg.Images { + if !isSupportedImageType(img) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) + return + } + } + } + model, err := GetModel(req.Model) if err != nil { var pErr *fs.PathError