From 910e9401d0068190137e0ddabd0c2b216bfea6f2 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 11 Dec 2023 13:56:22 -0800 Subject: [PATCH] Multimodal support (#1216) --------- Co-authored-by: Matt Apperson --- api/types.go | 47 +++++++++------ cmd/cmd.go | 151 +++++++++++++++++++++++++++++++++++++++++++++- docs/modelfile.md | 1 + llm/llama.go | 16 +++++ server/images.go | 3 + server/routes.go | 45 +++++++++++--- 6 files changed, 235 insertions(+), 28 deletions(-) diff --git a/api/types.go b/api/types.go index 3fbc5829..340c91c4 100644 --- a/api/types.go +++ b/api/types.go @@ -31,15 +31,18 @@ func (e StatusError) Error() string { } } +type ImageData []byte + type GenerateRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - System string `json:"system"` - Template string `json:"template"` - Context []int `json:"context,omitempty"` - Stream *bool `json:"stream,omitempty"` - Raw bool `json:"raw,omitempty"` - Format string `json:"format"` + Model string `json:"model"` + Prompt string `json:"prompt"` + System string `json:"system"` + Template string `json:"template"` + Context []int `json:"context,omitempty"` + Stream *bool `json:"stream,omitempty"` + Raw bool `json:"raw,omitempty"` + Format string `json:"format"` + Images []ImageData `json:"images,omitempty"` Options map[string]interface{} `json:"options"` } @@ -148,11 +151,12 @@ type ShowRequest struct { } type ShowResponse struct { - License string `json:"license,omitempty"` - Modelfile string `json:"modelfile,omitempty"` - Parameters string `json:"parameters,omitempty"` - Template string `json:"template,omitempty"` - System string `json:"system,omitempty"` + License string `json:"license,omitempty"` + Modelfile string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + System string `json:"system,omitempty"` + Details ModelDetails `json:"details,omitempty"` } type CopyRequest struct { @@ -188,10 +192,11 @@ type ListResponse struct { } type ModelResponse struct { - Name string `json:"name"` - ModifiedAt time.Time `json:"modified_at"` - Size int64 `json:"size"` - Digest string `json:"digest"` + Name string `json:"name"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` } type TokenResponse struct { @@ -209,6 +214,14 @@ type GenerateResponse struct { Metrics } +type ModelDetails struct { + Format string `json:"format"` + Family string `json:"family"` + Families []string `json:"families"` + ParameterSize string `json:"parameter_size"` + QuantizationLevel string `json:"quantization_level"` +} + func (m *Metrics) Summary() { if m.TotalDuration > 0 { fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) diff --git a/cmd/cmd.go b/cmd/cmd.go index d0b295a6..114080db 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -17,7 +17,9 @@ import ( "os/exec" "os/signal" "path/filepath" + "regexp" "runtime" + "slices" "strings" "syscall" "time" @@ -36,6 +38,8 @@ import ( "github.com/jmorganca/ollama/version" ) +type ImageData []byte + func CreateHandler(cmd *cobra.Command, args []string) error { filename, _ := cmd.Flags().GetString("file") filename, err := filepath.Abs(filename) @@ -418,6 +422,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error { Model: args[0], WordWrap: os.Getenv("TERM") == "xterm-256color", Options: map[string]interface{}{}, + Images: []ImageData{}, } format, err := cmd.Flags().GetString("format") @@ -427,7 +432,6 @@ 
func RunGenerate(cmd *cobra.Command, args []string) error { opts.Format = format prompts := args[1:] - // prepend stdin to the prompt if provided if !term.IsTerminal(int(os.Stdin.Fd())) { in, err := io.ReadAll(os.Stdin) @@ -466,6 +470,7 @@ type generateOptions struct { Format string System string Template string + Images []ImageData Options map[string]interface{} } @@ -551,6 +556,10 @@ func generate(cmd *cobra.Command, opts generateOptions) error { return nil } + images := make([]api.ImageData, 0) + for _, i := range opts.Images { + images = append(images, api.ImageData(i)) + } request := api.GenerateRequest{ Model: opts.Model, Prompt: opts.Prompt, @@ -559,6 +568,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error { System: opts.System, Template: opts.Template, Options: opts.Options, + Images: images, } if err := client.Generate(ctx, &request, fn); err != nil { @@ -585,7 +595,9 @@ func generate(cmd *cobra.Command, opts generateOptions) error { latest.Summary() } - cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)) + ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context) + cmd.SetContext(ctx) + return nil } @@ -598,11 +610,31 @@ const ( MultilineTemplate ) +func modelIsMultiModal(cmd *cobra.Command, name string) bool { + // get model details + client, err := api.ClientFromEnvironment() + if err != nil { + fmt.Println("error: couldn't connect to ollama server") + return false + } + + req := api.ShowRequest{Name: name} + resp, err := client.Show(cmd.Context(), &req) + if err != nil { + return false + } + + return slices.Contains(resp.Details.Families, "clip") +} + func generateInteractive(cmd *cobra.Command, opts generateOptions) error { + multiModal := modelIsMultiModal(cmd, opts.Model) + // load the model loadOpts := generateOptions{ Model: opts.Model, Prompt: "", + Images: []ImageData{}, } if err := generate(cmd, loadOpts); err != nil { return err @@ -902,6 +934,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { if len(prompt) > 0 && multiline == MultilineNone { opts.Prompt = prompt + if multiModal { + newPrompt, images, err := extractFileNames(prompt) + if err != nil { + return err + } + opts.Prompt = newPrompt + + // reset the context if we find another image + if len(images) > 0 { + opts.Images = images + ctx := cmd.Context() + ctx = context.WithValue(ctx, generateContextKey("context"), []int{}) + cmd.SetContext(ctx) + } + if len(opts.Images) == 0 { + fmt.Println("This model requires you to add a jpeg, png, or svg image.\n") + prompt = "" + continue + } + } if err := generate(cmd, opts); err != nil { return err } @@ -911,6 +963,57 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } } +func normalizeFilePath(fp string) string { + // Define a map of escaped characters and their replacements + replacements := map[string]string{ + "\\ ": " ", // Escaped space + "\\(": "(", // Escaped left parenthesis + "\\)": ")", // Escaped right parenthesis + "\\[": "[", // Escaped left square bracket + "\\]": "]", // Escaped right square bracket + "\\{": "{", // Escaped left curly brace + "\\}": "}", // Escaped right curly brace + "\\$": "$", // Escaped dollar sign + "\\&": "&", // Escaped ampersand + "\\;": ";", // Escaped semicolon + "\\'": "'", // Escaped single quote + "\\\\": "\\", // Escaped backslash + "\\*": "*", // Escaped asterisk + "\\?": "?", // Escaped question mark + } + + for escaped, actual := range replacements { + fp = 
strings.ReplaceAll(fp, escaped, actual) + } + return fp +} + +func extractFileNames(input string) (string, []ImageData, error) { + // Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20) + // and followed by more characters and a file extension + regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b` + re := regexp.MustCompile(regexPattern) + + filePaths := re.FindAllString(input, -1) + var imgs []ImageData + + for _, fp := range filePaths { + nfp := normalizeFilePath(fp) + data, err := getImageData(nfp) + if err != nil { + if os.IsNotExist(err) { + continue + } + fmt.Printf("Couldn't process image: %q\n", err) + return "", imgs, err + } + fmt.Printf("Added image '%s'\n", nfp) + input = strings.ReplaceAll(input, fp, "") + imgs = append(imgs, data) + } + return input, imgs, nil +} + func RunServer(cmd *cobra.Command, _ []string) error { host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST")) if err != nil { @@ -937,6 +1040,50 @@ func RunServer(cmd *cobra.Command, _ []string) error { return server.Serve(ln, origins) } +func getImageData(filePath string) ([]byte, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer file.Close() + + buf := make([]byte, 512) + _, err = file.Read(buf) + if err != nil { + return nil, err + } + + contentType := http.DetectContentType(buf) + allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"} + if !slices.Contains(allowedTypes, contentType) { + return nil, fmt.Errorf("invalid image type: %s", contentType) + } + + info, err := file.Stat() + if err != nil { + return nil, err + } + + // Check if the file size exceeds 100MB + var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes + if info.Size() > maxSize { + return nil, fmt.Errorf("file size exceeds maximum limit (100MB).") + } + + buf = make([]byte, info.Size()) + _, err = file.Seek(0, 0) + if err != nil { + return nil, err + } + + _, err = io.ReadFull(file, buf) + if err != nil { + return nil, err + } + + return buf, nil +} + func initializeKeypair() error { home, err := os.UserHomeDir() if err != nil { diff --git a/docs/modelfile.md b/docs/modelfile.md index 6f1d8e88..98941e0b 100644 --- a/docs/modelfile.md +++ b/docs/modelfile.md @@ -150,6 +150,7 @@ PARAMETER | top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 | | top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 | + ### TEMPLATE `TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system prompt and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model. 
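
Taken together, the api/types.go and cmd/cmd.go changes above let a caller attach raw image bytes to a generate request. Below is a minimal sketch of how a Go client might exercise the new `Images` field; the model name ("llava") and image path ("cat.jpg") are placeholders, and the streaming-callback signature and the `Response` field are assumed from the existing `client.Generate` usage in cmd.go rather than shown in this patch:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// api.ImageData is a plain []byte, so the raw file contents are enough.
	img, err := os.ReadFile("cat.jpg") // placeholder path
	if err != nil {
		log.Fatal(err)
	}

	req := api.GenerateRequest{
		Model:  "llava", // placeholder: any model whose Details.Families includes "clip"
		Prompt: "What is in this picture?",
		Images: []api.ImageData{api.ImageData(img)},
	}

	// Stream tokens as they arrive (callback signature assumed to match cmd.go).
	err = client.Generate(context.Background(), &req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```

In the interactive CLI the same path is taken indirectly: `extractFileNames` pulls any `/` or `./` path ending in .jpg/.jpeg/.png/.svg out of the typed prompt, `getImageData` loads and validates it, and the bytes are attached to the request as above. Since `api.ImageData` has an underlying `[]byte`, encoding/json should base64-encode it on the wire, so raw HTTP clients would send the image as a base64 string in the `images` array.
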
diff --git a/llm/llama.go b/llm/llama.go index 6ccac861..33052e3c 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -223,8 +223,14 @@ type Running struct { *StatusWriter // captures error messages from the llama runner process } +type ImageData struct { + Data []byte `json:"data"` + ID int `json:"id"` +} + type llama struct { api.Options + ImageData []ImageData Running } @@ -547,6 +553,7 @@ const maxBufferSize = 512 * format.KiloByte type PredictOpts struct { Prompt string Format string + Images []api.ImageData CheckpointStart time.Time CheckpointLoaded time.Time } @@ -564,6 +571,14 @@ type PredictResult struct { } func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error { + imageData := llm.ImageData + if len(predict.Images) > 0 { + for cnt, i := range predict.Images { + imageData = append(imageData, ImageData{Data: i, ID: cnt}) + } + } + log.Printf("loaded %d images", len(imageData)) + request := map[string]any{ "prompt": predict.Prompt, "stream": true, @@ -585,6 +600,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred "penalize_nl": llm.PenalizeNewline, "seed": llm.Seed, "stop": llm.Stop, + "image_data": imageData, } if predict.Format == "json" { diff --git a/server/images.go b/server/images.go index 097cbe2f..5ac570de 100644 --- a/server/images.go +++ b/server/images.go @@ -46,6 +46,7 @@ type Model struct { System string License []string Digest string + Size int64 Options map[string]interface{} } @@ -242,6 +243,7 @@ func GetModel(name string) (*Model, error) { Digest: digest, Template: "{{ .Prompt }}", License: []string{}, + Size: manifest.GetTotalSize(), } filename, err := GetBlobsPath(manifest.Config.Digest) @@ -545,6 +547,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars } } + // xxx - can this be removed? 
if config.ModelType == "65B" { if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 { config.ModelType = "70B" diff --git a/server/routes.go b/server/routes.go index 9716c84e..04e32e2f 100644 --- a/server/routes.go +++ b/server/routes.go @@ -156,9 +156,9 @@ func GenerateHandler(c *gin.Context) { defer loaded.mu.Unlock() checkpointStart := time.Now() - var req api.GenerateRequest err := c.ShouldBindJSON(&req) + switch { case errors.Is(err, io.EOF): c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -292,6 +292,7 @@ func GenerateHandler(c *gin.Context) { Format: req.Format, CheckpointStart: checkpointStart, CheckpointLoaded: checkpointLoaded, + Images: req.Images, } if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { ch <- gin.H{"error": err.Error()} @@ -614,10 +615,19 @@ func GetModelInfo(name string) (*api.ShowResponse, error) { return nil, err } + modelDetails := api.ModelDetails{ + Format: model.Config.ModelFormat, + Family: model.Config.ModelFamily, + Families: model.Config.ModelFamilies, + ParameterSize: model.Config.ModelType, + QuantizationLevel: model.Config.FileType, + } + resp := &api.ShowResponse{ License: strings.Join(model.License, "\n"), System: model.System, Template: model.Template, + Details: modelDetails, } mf, err := ShowModelfile(model) @@ -667,25 +677,42 @@ func ListModelsHandler(c *gin.Context) { return } + modelResponse := func(modelName string) (api.ModelResponse, error) { + model, err := GetModel(modelName) + if err != nil { + return api.ModelResponse{}, err + } + + modelDetails := api.ModelDetails{ + Format: model.Config.ModelFormat, + Family: model.Config.ModelFamily, + Families: model.Config.ModelFamilies, + ParameterSize: model.Config.ModelType, + QuantizationLevel: model.Config.FileType, + } + + return api.ModelResponse{ + Name: model.ShortName, + Size: model.Size, + Digest: model.Digest, + Details: modelDetails, + }, nil + } + walkFunc := func(path string, info os.FileInfo, _ error) error { if !info.IsDir() { dir, file := filepath.Split(path) dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator)) tag := strings.Join([]string{dir, file}, ":") - mp := ParseModelPath(tag) - manifest, digest, err := GetManifest(mp) + resp, err := modelResponse(tag) if err != nil { log.Printf("skipping file: %s", fp) return nil } - models = append(models, api.ModelResponse{ - Name: mp.GetShortTagname(), - Size: manifest.GetTotalSize(), - Digest: digest, - ModifiedAt: info.ModTime(), - }) + resp.ModifiedAt = info.ModTime() + models = append(models, resp) } return nil
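
The llm/llama.go changes forward images to the llama runner as `image_data` entries, each pairing the raw bytes with a numeric ID. A small sketch of the payload shape that `Predict` builds, with most request fields omitted; the three literal bytes stand in for a real image, and the struct simply mirrors the `ImageData` type added in this patch:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the ImageData struct added in llm/llama.go.
type ImageData struct {
	Data []byte `json:"data"`
	ID   int    `json:"id"`
}

func main() {
	// Stand-in bytes; in the patch these come from PredictOpts.Images.
	imgs := []ImageData{{Data: []byte("\xff\xd8\xff"), ID: 0}}

	request := map[string]any{
		"prompt":     "What is in this picture?",
		"image_data": imgs,
	}

	b, _ := json.Marshal(request)
	fmt.Println(string(b))
	// []byte fields are base64-encoded by encoding/json, so the runner receives
	// {"image_data":[{"data":"<base64>","id":0}],"prompt":"..."}
}
```

The ID lets a prompt template refer to a specific image when several are attached; in this patch the IDs are just the slice indices assigned in `Predict`.
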