diff --git a/api/types.go b/api/types.go index d4e385bf..585daf6c 100644 --- a/api/types.go +++ b/api/types.go @@ -171,6 +171,7 @@ type ShowResponse struct { Template string `json:"template,omitempty"` System string `json:"system,omitempty"` Details ModelDetails `json:"details,omitempty"` + Messages []Message `json:"messages,omitempty"` } type CopyRequest struct { @@ -236,6 +237,7 @@ type GenerateResponse struct { } type ModelDetails struct { + ParentModel string `json:"parent_model"` Format string `json:"format"` Family string `json:"family"` Families []string `json:"families"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 76e3c7a9..915fa993 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -458,15 +458,17 @@ func RunGenerate(cmd *cobra.Command, args []string) error { type generateContextKey string type runOptions struct { - Model string - Prompt string - Messages []api.Message - WordWrap bool - Format string - System string - Template string - Images []api.ImageData - Options map[string]interface{} + Model string + ParentModel string + Prompt string + Messages []api.Message + WordWrap bool + Format string + System string + Template string + Images []api.ImageData + Options map[string]interface{} + MultiModal bool } type displayResponseState struct { diff --git a/cmd/interactive.go b/cmd/interactive.go index da3c5b72..d337e555 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -7,12 +7,14 @@ import ( "net/http" "os" "regexp" + "sort" "strings" "github.com/spf13/cobra" "golang.org/x/exp/slices" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/progress" "github.com/jmorganca/ollama/readline" ) @@ -25,43 +27,75 @@ const ( MultilineTemplate ) -func modelIsMultiModal(cmd *cobra.Command, name string) bool { - // get model details +func loadModel(cmd *cobra.Command, opts *runOptions) error { client, err := api.ClientFromEnvironment() if err != nil { - fmt.Println("error: couldn't connect to ollama server") - return false + return err } - req := 
api.ShowRequest{Name: name} - resp, err := client.Show(cmd.Context(), &req) + p := progress.NewProgress(os.Stderr) + defer p.StopAndClear() + + spinner := progress.NewSpinner("") + p.Add("", spinner) + + showReq := api.ShowRequest{Name: opts.Model} + showResp, err := client.Show(cmd.Context(), &showReq) if err != nil { - return false + return err + } + opts.MultiModal = slices.Contains(showResp.Details.Families, "clip") + opts.ParentModel = showResp.Details.ParentModel + + if len(showResp.Messages) > 0 { + opts.Messages = append(opts.Messages, showResp.Messages...) } - return slices.Contains(resp.Details.Families, "clip") + chatReq := &api.ChatRequest{ + Model: opts.Model, + Messages: []api.Message{}, + } + err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error { + p.StopAndClear() + if len(opts.Messages) > 0 { + for _, msg := range opts.Messages { + switch msg.Role { + case "user": + fmt.Printf(">>> %s\n", msg.Content) + case "assistant": + state := &displayResponseState{} + displayResponse(msg.Content, opts.WordWrap, state) + fmt.Println() + fmt.Println() + } + } + } + return nil + }) + if err != nil { + return err + } + + return nil } func generateInteractive(cmd *cobra.Command, opts runOptions) error { - multiModal := modelIsMultiModal(cmd, opts.Model) + opts.Messages = make([]api.Message, 0) - // load the model - loadOpts := runOptions{ - Model: opts.Model, - Prompt: "", - Messages: []api.Message{}, - } - if _, err := chat(cmd, loadOpts); err != nil { + err := loadModel(cmd, &opts) + if err != nil { return err } usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") - fmt.Fprintln(os.Stderr, " /set Set session variables") - fmt.Fprintln(os.Stderr, " /show Show model information") - fmt.Fprintln(os.Stderr, " /bye Exit") - fmt.Fprintln(os.Stderr, " /?, /help Help for a command") - fmt.Fprintln(os.Stderr, " /? 
shortcuts Help for keyboard shortcuts") + fmt.Fprintln(os.Stderr, " /set Set session variables") + fmt.Fprintln(os.Stderr, " /show Show model information") + fmt.Fprintln(os.Stderr, " /load Load a session or model") + fmt.Fprintln(os.Stderr, " /save Save your current session") + fmt.Fprintln(os.Stderr, " /bye Exit") + fmt.Fprintln(os.Stderr, " /?, /help Help for a command") + fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "") @@ -140,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { var sb strings.Builder var multiline MultilineState - opts.Messages = make([]api.Message, 0) for { line, err := scanner.Readline() @@ -203,6 +236,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { if err := ListHandler(cmd, args[1:]); err != nil { return err } + case strings.HasPrefix(line, "/load"): + args := strings.Fields(line) + if len(args) != 2 { + fmt.Println("Usage:\n /load <modelname>") + continue + } + opts.Model = args[1] + opts.Messages = []api.Message{} + fmt.Printf("Loading model '%s'\n", opts.Model) + if err := loadModel(cmd, &opts); err != nil { + return err + } + continue + case strings.HasPrefix(line, "/save"): + args := strings.Fields(line) + if len(args) != 2 { + fmt.Println("Usage:\n /save <modelname>") + continue + } + + client, err := api.ClientFromEnvironment() + if err != nil { + fmt.Println("error: couldn't connect to ollama server") + return err + } + + req := &api.CreateRequest{ + Name: args[1], + Modelfile: buildModelfile(opts), + } + fn := func(resp api.ProgressResponse) error { return nil } + err = client.Create(cmd.Context(), req, fn) + if err != nil { + fmt.Println("error: couldn't save model") + return err + } + fmt.Printf("Created new model '%s'\n", args[1]) + continue case strings.HasPrefix(line, "/set"): args := strings.Fields(line) if len(args) > 1 { @@ -389,7 +460,7 @@ func 
generateInteractive(cmd *cobra.Command, opts runOptions) error { args := strings.Fields(line) isFile := false - if multiModal { + if opts.MultiModal { for _, f := range extractFileNames(line) { if strings.HasPrefix(f, args[0]) { isFile = true @@ -411,7 +482,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { if sb.Len() > 0 && multiline == MultilineNone { newMessage := api.Message{Role: "user", Content: sb.String()} - if multiModal { + if opts.MultiModal { msg, images, err := extractFileData(sb.String()) if err != nil { return err @@ -454,6 +525,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } } +func buildModelfile(opts runOptions) string { + var mf strings.Builder + model := opts.ParentModel + if model == "" { + model = opts.Model + } + fmt.Fprintf(&mf, "FROM %s\n", model) + if opts.System != "" { + fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System) + } + + if opts.Template != "" { + fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template) + } + + keys := make([]string, 0) + for k := range opts.Options { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k]) + } + fmt.Fprintln(&mf) + + for _, msg := range opts.Messages { + fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content) + } + + return mf.String() +} + func normalizeFilePath(fp string) string { // Define a map of escaped characters and their replacements replacements := map[string]string{ diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index 1bd5058a..19e43287 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -1,9 +1,13 @@ package cmd import ( + "bytes" "testing" + "text/template" "github.com/stretchr/testify/assert" + + "github.com/jmorganca/ollama/api" ) func TestExtractFilenames(t *testing.T) { @@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8 assert.Contains(t, 
res[9], "ten.svg") assert.Contains(t, res[9], "E:") } + +func TestModelfileBuilder(t *testing.T) { + opts := runOptions{ + Model: "hork", + System: "You are part horse and part shark, but all hork. Do horklike things", + Template: "This is a template.", + Messages: []api.Message{ + {Role: "user", Content: "Hey there hork!"}, + {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, + }, + Options: map[string]interface{}{}, + } + + opts.Options["temperature"] = 0.9 + opts.Options["seed"] = 42 + opts.Options["penalize_newline"] = false + opts.Options["stop"] = []string{"hi", "there"} + + mf := buildModelfile(opts) + expectedModelfile := `FROM {{.Model}} +SYSTEM """{{.System}}""" +TEMPLATE """{{.Template}}""" +PARAMETER penalize_newline false +PARAMETER seed 42 +PARAMETER stop [hi there] +PARAMETER temperature 0.9 + +MESSAGE user """Hey there hork!""" +MESSAGE assistant """Yes it is true, I am half horse, half shark.""" +` + + tmpl, err := template.New("").Parse(expectedModelfile) + assert.Nil(t, err) + + var buf bytes.Buffer + err = tmpl.Execute(&buf, opts) + assert.Nil(t, err) + assert.Equal(t, buf.String(), mf) + + opts.ParentModel = "horseshark" + mf = buildModelfile(opts) + expectedModelfile = `FROM {{.ParentModel}} +SYSTEM """{{.System}}""" +TEMPLATE """{{.Template}}""" +PARAMETER penalize_newline false +PARAMETER seed 42 +PARAMETER stop [hi there] +PARAMETER temperature 0.9 + +MESSAGE user """Hey there hork!""" +MESSAGE assistant """Yes it is true, I am half horse, half shark.""" +` + + tmpl, err = template.New("").Parse(expectedModelfile) + assert.Nil(t, err) + + var parentBuf bytes.Buffer + err = tmpl.Execute(&parentBuf, opts) + assert.Nil(t, err) + assert.Equal(t, parentBuf.String(), mf) +} diff --git a/parser/parser.go b/parser/parser.go index 2fbd3cc5..947848b2 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log/slog" + "slices" ) type Command struct { @@ -56,6 +57,16 @@ func Parse(reader 
io.Reader) ([]Command, error) { command.Args = string(bytes.TrimSpace(fields[1])) case "EMBED": return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") + case "MESSAGE": + command.Name = string(bytes.ToLower(fields[0])) + fields = bytes.SplitN(fields[1], []byte(" "), 2) + if len(fields) < 2 { + return nil, fmt.Errorf("should be in the format <role> <message>") + } + if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) { + return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"") + } + command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1])) default: if !bytes.HasPrefix(fields[0], []byte("#")) { // log a warning for unknown commands diff --git a/parser/parser_test.go b/parser/parser_test.go index 53555ad1..25e849b5 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -61,3 +61,38 @@ PARAMETER param1 assert.ErrorContains(t, err, "missing value for [param1]") } + +func Test_Parser_Messages(t *testing.T) { + + input := ` +FROM foo +MESSAGE system You are a Parser. Always Parse things. +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +` + + reader := strings.NewReader(input) + commands, err := Parse(reader) + assert.Nil(t, err) + + expectedCommands := []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "user: Hey there!"}, + {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, + } + + assert.Equal(t, expectedCommands, commands) +} + +func Test_Parser_Messages_BadRole(t *testing.T) { + + input := ` +FROM foo +MESSAGE badguy I'm a bad guy! 
+` + + reader := strings.NewReader(input) + _, err := Parse(reader) + assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"") +} diff --git a/server/images.go b/server/images.go index a20f6bd7..ab3b4faa 100644 --- a/server/images.go +++ b/server/images.go @@ -41,7 +41,7 @@ type Model struct { Config ConfigV2 ShortName string ModelPath string - OriginalModel string + ParentModel string AdapterPaths []string ProjectorPaths []string Template string @@ -50,6 +50,12 @@ type Model struct { Digest string Size int64 Options map[string]interface{} + Messages []Message +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` } type PromptVars struct { @@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) { switch layer.MediaType { case "application/vnd.ollama.image.model": model.ModelPath = filename - model.OriginalModel = layer.From + model.ParentModel = layer.From case "application/vnd.ollama.image.embed": // Deprecated in versions > 0.1.2 // TODO: remove this warning in a future version @@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) { if err = json.NewDecoder(params).Decode(&model.Options); err != nil { return nil, err } + case "application/vnd.ollama.image.messages": + msgs, err := os.Open(filename) + if err != nil { + return nil, err + } + defer msgs.Close() + + if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil { + return nil, err + } case "application/vnd.ollama.image.license": bts, err := os.ReadFile(filename) if err != nil { @@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars } var layers Layers + messages := []string{} params := make(map[string][]string) fromParams := make(map[string]any) for _, c := range commands { - slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args)) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) switch c.Name { @@ -607,11 +623,37 @@ func CreateModel(ctx 
context.Context, name, modelFileDir string, commands []pars } layers.Replace(layer) + case "message": + messages = append(messages, c.Args) default: params[c.Name] = append(params[c.Name], c.Args) } } + if len(messages) > 0 { + fn(api.ProgressResponse{Status: "creating parameters layer"}) + + msgs := make([]api.Message, 0) + + for _, m := range messages { + // todo: handle images + msg := strings.SplitN(m, ": ", 2) + msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]}) + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(msgs); err != nil { + return err + } + + layer, err := NewLayer(&b, "application/vnd.ollama.image.messages") + if err != nil { + return err + } + + layers.Replace(layer) + } + if len(params) > 0 { fn(api.ProgressResponse{Status: "creating parameters layer"}) @@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) { mt.Model = model mt.From = model.ModelPath - if model.OriginalModel != "" { - mt.From = model.OriginalModel + if model.ParentModel != "" { + mt.From = model.ParentModel } modelFile := `# Modelfile generated by "ollama show" diff --git a/server/routes.go b/server/routes.go index 0c145ae6..141f05d4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -659,6 +659,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } modelDetails := api.ModelDetails{ + ParentModel: model.ParentModel, Format: model.Config.ModelFormat, Family: model.Config.ModelFamily, Families: model.Config.ModelFamilies, @@ -674,11 +675,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { model.Template = req.Template } + msgs := make([]api.Message, 0) + for _, msg := range model.Messages { + msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content}) + } + resp := &api.ShowResponse{ License: strings.Join(model.License, "\n"), System: model.System, Template: model.Template, Details: modelDetails, + Messages: msgs, } var params []string @@ -1075,7 +1082,13 @@ func ChatHandler(c 
*gin.Context) { // an empty request loads the model if len(req.Messages) == 0 { - c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}}) + resp := api.ChatResponse{ + CreatedAt: time.Now().UTC(), + Model: req.Model, + Done: true, + Message: api.Message{Role: "assistant"}, + } + c.JSON(http.StatusOK, resp) return }