diff --git a/server/prompt.go b/server/prompt.go
index 5016fbe1..51d691a9 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -11,8 +11,13 @@ import (
 	"github.com/ollama/ollama/template"
 )
 
-func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
-	// extract system messages which should always be included
+type tokenizeFunc func(context.Context, string) ([]int, error)
+
+// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
+// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
+// latest message and 2) system messages
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+	// pull out any system messages which should always be included in the prompt
 	var system []api.Message
 	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
 		if m.Role == "system" {
@@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
 		return false
 	})
 
-	if len(system) == 0 && r.model.System != "" {
+	if len(system) == 0 && m.System != "" {
 		// add model system prompt since it wasn't provided
-		system = append(system, api.Message{Role: "system", Content: r.model.System})
+		system = append(system, api.Message{Role: "system", Content: m.System})
 	}
 
+	// always include the last message
 	n := len(msgs) - 1
+	// in reverse, find all messages that fit into the context window
 	for i := n - 1; i >= 0; i-- {
 		var b bytes.Buffer
-		if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
 			return "", nil, err
 		}
 
-		s, err := r.llama.Tokenize(ctx, b.String())
+		s, err := tokenize(ctx, b.String())
 		if err != nil {
 			return "", nil, err
 		}
 
 		c := len(s)
-		if r.model.ProjectorPaths != nil {
+		if m.ProjectorPaths != nil {
 			for _, m := range msgs[i:] {
-				// TODO: get image embedding length from project metadata
+				// images are represented as 768-token embeddings
+				// TODO: get embedding length from projector metadata
 				c += 768 * len(m.Images)
 			}
 		}
 
-		if c > r.NumCtx {
+		if c > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
@@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
 		}
 	}
 
+	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
 		return "", nil, err
 	}
 
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 59288b46..d4cee98c 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -7,15 +7,10 @@ import (
 	"testing"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/template"
 )
 
-type mock struct {
-	llm.LlamaServer
-}
-
-func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) {
+func tokenize(_ context.Context, s string) (tokens []int, err error) {
 	for range strings.Fields(s) {
 		tokens = append(tokens, len(tokens))
 	}
@@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "truncate messages",
"truncate messages", limit: 1, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages with image", + name: "truncate messages with image", limit: 64, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate messages with images", + name: "truncate messages with images", limit: 64, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, @@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "messages with images", + name: "messages with images", limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}}, @@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "message with image tag", + name: "message with image tag", limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}}, @@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "messages with interleaved images", + name: "messages with interleaved images", limit: 2048, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "truncate message with interleaved images", + name: "truncate message with interleaved images", limit: 1024, msgs: []api.Message{ {Role: "user", Content: "You're a test, Harry!"}, @@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) { }, }, { - name: "message with system prompt", + name: "message with system prompt", limit: 2048, msgs: []api.Message{ {Role: "system", Content: "You are the Test Who Lived."}, @@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - r := runnerRef{ - llama: mock{}, - model: &Model{Template: tmpl, ProjectorPaths: []string{"vision"}}, - Options: &api.Options{}, - } - - r.NumCtx = tt.limit - prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs) + model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} + opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} + prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs) if err != nil { t.Fatal(err) } diff --git a/server/routes.go b/server/routes.go index 35e64511..1a93e977 100644 --- a/server/routes.go +++ b/server/routes.go @@ -54,6 +54,8 @@ func init() { gin.SetMode(mode) } +var errRequired = errors.New("is required") + func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { @@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) { if name == "" { - return nil, errors.New("model is required") + return nil, fmt.Errorf("model %w", errRequired) } model, err := GetModel(name) @@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) return } else if err != nil { - handleScheduleError(c, err) + handleScheduleError(c, req.Model, err) + return + } + + if req.Prompt == "" { + 
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Done:       true,
+			DoneReason: "load",
+		})
 		return
 	}
 
@@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
 	}
 
-	if req.Prompt != "" {
-		for _, i := range images {
-			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
-		}
-
-		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
+	for _, i := range images {
+		msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
 	}
 
-	if len(msgs) == 0 {
-		c.JSON(http.StatusOK, api.GenerateResponse{
-			Model:      req.Model,
-			CreatedAt:  time.Now().UTC(),
-			Done:       true,
-			DoneReason: "load",
-		})
-		return
-	}
+	msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 
 	tmpl := r.model.Template
 	if req.Template != "" {
@@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 
 	r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		handleScheduleError(c, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
 	} else if err != nil {
-		handleScheduleError(c, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func handleScheduleError(c *gin.Context, err error) {
+func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
+	case errors.Is(err, errRequired):
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})
 	case errors.Is(err, ErrMaxQueue):
 		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
+	case errors.Is(err, os.ErrNotExist):
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
 	default:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
diff --git a/template/template.go b/template/template.go
index cfba5a23..c8f8f6d0 100644
--- a/template/template.go
+++ b/template/template.go
@@ -83,6 +83,7 @@ type Template struct {
 	raw string
 }
 
+// response is a template node that can be added to templates that don't already have one
 var response = parse.ActionNode{
 	NodeType: parse.NodeAction,
 	Pipe: &parse.PipeNode{
@@ -101,28 +102,25 @@ var response = parse.ActionNode{
 	},
 }
 
+var funcs = template.FuncMap{
+	"toJson": func(v any) string {
+		b, err := json.Marshal(v)
+		if err != nil {
+			return ""
+		}
+
+		return string(b)
+	},
+	"add": func(a, b int) int {
+		return a + b
+	},
+	"sub": func(a, b int) int {
+		return a - b
+	},
+}
+
 func Parse(s string) (*Template, error) {
-	tmpl := template.New("").Option("missingkey=zero").Funcs(template.FuncMap{
-		"toJson": func(v any) string {
-			b, err := json.Marshal(v)
-			if err != nil {
-				return ""
-			}
-
-			return string(b)
-		},
-		"isLastMessage": func(s []*api.Message, m *api.Message) bool {
-			for i := len(s) - 1; i >= 0; i-- {
-				if m.Role != s[i].Role {
-					continue
-				}
-
-				return m == s[i]
-			}
-
-			return false
-		},
-	})
+	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
 
 	tmpl, err := tmpl.Parse(s)
 	if err != nil {
@@ -218,7 +216,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	return err
 }
 
-func collate(msgs []api.Message) (system string, collated []*api.Message) {
+type messages []*api.Message
+
+// collate merges consecutive messages of the same role into a single message.
+// collate also pulls out and merges messages with Role == "system", which are
+// templated separately. As a side effect, it mutates message content, adding
+// image tags ([img-%d]) as needed
+func collate(msgs []api.Message) (system string, collated messages) {
 	var n int
 	for i := range msgs {
 		msg := msgs[i]
diff --git a/template/template_test.go b/template/template_test.go
index 5d5dad4b..ac92bf48 100644
--- a/template/template_test.go
+++ b/template/template_test.go
@@ -8,6 +8,7 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"testing"
 	"text/template"
 
@@ -15,6 +16,98 @@ import (
 	"github.com/ollama/ollama/llm"
 )
 
+func TestFuncs(t *testing.T) {
+	t.Run("toJson", func(t *testing.T) {
+		cases := []struct {
+			input    any
+			expected string
+		}{
+			{nil, "null"},
+			{true, "true"},
+			{false, "false"},
+			{0, "0"},
+			{1, "1"},
+			{1.0, "1"},
+			{1.1, "1.1"},
+			{"", `""`},
+			{"hello", `"hello"`},
+			{[]int{1, 2, 3}, "[1,2,3]"},
+			{[]string{"a", "b", "c"}, `["a","b","c"]`},
+			{map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`},
+			{map[string]string{"a": "b", "c": "d"}, `{"a":"b","c":"d"}`},
+		}
+
+		for _, tt := range cases {
+			t.Run(tt.expected, func(t *testing.T) {
+				toJson, ok := funcs["toJson"].(func(any) string)
+				if !ok {
+					t.Fatal("toJson is not a function")
+				}
+
+				if s := toJson(tt.input); s != tt.expected {
+					t.Errorf("expected %q, got %q", tt.expected, s)
+				}
+			})
+		}
+	})
+
+	t.Run("add", func(t *testing.T) {
+		cases := []struct {
+			a, b     int
+			expected int
+		}{
+			{0, 0, 0},
+			{0, 1, 1},
+			{1, 0, 1},
+			{1, 1, 2},
+			{1, -1, 0},
+			{-1, 1, 0},
+			{-1, -1, -2},
+		}
+
+		for _, tt := range cases {
+			t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
+				add, ok := funcs["add"].(func(int, int) int)
+				if !ok {
+					t.Fatal("add is not a function")
+				}
+
+				if n := add(tt.a, tt.b); n != tt.expected {
+					t.Errorf("expected %d, got %d", tt.expected, n)
+				}
+			})
+		}
+	})
+
+	t.Run("sub", func(t *testing.T) {
+		cases := []struct {
+			a, b     int
+			expected int
+		}{
+			{0, 0, 0},
+			{0, 1, -1},
+			{1, 0, 1},
+			{1, 1, 0},
+			{1, -1, 2},
+			{-1, 1, -2},
+			{-1, -1, 0},
+		}
+
+		for _, tt := range cases {
+			t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
+				sub, ok := funcs["sub"].(func(int, int) int)
+				if !ok {
+					t.Fatal("sub is not a function")
+				}
+
+				if n := sub(tt.a, tt.b); n != tt.expected {
+					t.Errorf("expected %d, got %d", tt.expected, n)
+				}
+			})
+		}
+	})
+}
+
 func TestNamed(t *testing.T) {
 	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
 	if err != nil {
@@ -89,77 +182,86 @@ func TestParse(t *testing.T) {
 }
 
 func TestExecuteWithMessages(t *testing.T) {
+	type template struct {
+		name     string
+		template string
+	}
 	cases := []struct {
-		templates []string
+		name      string
+		templates []template
 		values    Values
 		expected  string
 	}{
 		{
-			[]string{
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
-				`{{- range .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
+			"mistral",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `{{- range .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
 {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
 {{- end }}
-{{- end }}`,
+{{- end }}`},
 			},
 			Values{
 				Messages: []api.Message{
 					{Role: "user", Content: "Hello friend!"},
 					{Role: "assistant", Content: "Hello human!"},
-					{Role: "user", Content: "Yay!"},
+					{Role: "user", Content: "What is your name?"},
 				},
 			},
-			`[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `,
+			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
 		},
 		{
-			[]string{
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
-				`
+			"mistral system",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `
 {{- range .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
 {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
 {{- end }}
-{{- end }}`,
+{{- end }}`},
 			},
 			Values{
 				Messages: []api.Message{
 					{Role: "system", Content: "You are a helpful assistant!"},
 					{Role: "user", Content: "Hello friend!"},
 					{Role: "assistant", Content: "Hello human!"},
-					{Role: "user", Content: "Yay!"},
+					{Role: "user", Content: "What is your name?"},
 				},
 			},
 			`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
 
-Yay![/INST] `,
+What is your name?[/INST] `,
 		},
 		{
-			[]string{
-				`{{ if .System }}<|im_start|>system
+			"chatml",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
 {{ end }}{{ if .Prompt }}<|im_start|>user
 {{ .Prompt }}<|im_end|>
 {{ end }}<|im_start|>assistant
 {{ .Response }}<|im_end|>
-`,
-			`
+`},
+				{"messages", `
 {{- range .Messages }}
-{{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system
-{{ $.System }}<|im_end|>{{ print "\n" }}
+{{- if and (eq .Role "user") (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}<|im_start|>system
+{{ $.System }}<|im_end|>{{ "\n" }}
 {{- end }}<|im_start|>{{ .Role }}
-{{ .Content }}<|im_end|>{{ print "\n" }}
+{{ .Content }}<|im_end|>{{ "\n" }}
 {{- end }}<|im_start|>assistant
-`,
+`},
 			},
 			Values{
 				Messages: []api.Message{
 					{Role: "system", Content: "You are a helpful assistant!"},
 					{Role: "user", Content: "Hello friend!"},
 					{Role: "assistant", Content: "Hello human!"},
-					{Role: "user", Content: "Yay!"},
+					{Role: "user", Content: "What is your name?"},
 				},
 			},
 			`<|im_start|>user
@@ -169,23 +271,25 @@ Hello human!<|im_end|>
 <|im_start|>system
 You are a helpful assistant!<|im_end|>
 <|im_start|>user
-Yay!<|im_end|>
+What is your name?<|im_end|>
 <|im_start|>assistant
 `,
 		},
 		{
-			[]string{
-				`{{ if .Prompt }}Question: {{ .Prompt }}
+			"moondream",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
 {{ end }}Answer: {{ .Response }}
-`,
-			`
+`},
+				{"messages", `
 {{- range .Messages }}
-{{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }}
-{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }}
+{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
+{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
 {{- end }}
-{{- end }}Answer: `,
+{{- end }}Answer: `},
 			},
 			Values{
 				Messages: []api.Message{
@@ -211,10 +315,10 @@ Answer: `,
 	}
 
 	for _, tt := range cases {
-		t.Run("", func(t *testing.T) {
-			for _, tmpl := range tt.templates {
-				t.Run("", func(t *testing.T) {
-					tmpl, err := Parse(tmpl)
+		t.Run(tt.name, func(t *testing.T) {
+			for _, ttt := range tt.templates {
+				t.Run(ttt.name, func(t *testing.T) {
+					tmpl, err := Parse(ttt.template)
 					if err != nil {
 						t.Fatal(err)
 					}
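
Reviewer note: because chatPrompt now depends only on a tokenizeFunc rather than a live runner, its truncation behavior can be exercised without loading a model. The test below is an illustrative sketch only, not part of this patch; it assumes the server package internals shown above (Model, chatPrompt, the ollama template package) and uses the same one-token-per-word stub as prompt_test.go. The test name, template string, and message contents are hypothetical.

package server

import (
	"context"
	"strings"
	"testing"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/template"
)

// hypothetical example: with NumCtx = 4 and a one-token-per-word tokenizer,
// only the latest message fits, so the older message should be dropped
func TestChatPromptDropsOldMessages(t *testing.T) {
	tmpl, err := template.Parse(`{{ range .Messages }}{{ .Content }} {{ end }}`)
	if err != nil {
		t.Fatal(err)
	}

	// stand-in for r.llama.Tokenize: count whitespace-separated words
	tokenize := func(_ context.Context, s string) ([]int, error) {
		return make([]int, len(strings.Fields(s))), nil
	}

	model := Model{Template: tmpl}
	opts := api.Options{Runner: api.Runner{NumCtx: 4}}
	msgs := []api.Message{
		{Role: "user", Content: "this older message will not fit"},
		{Role: "user", Content: "keep me"},
	}

	prompt, _, err := chatPrompt(context.TODO(), &model, tokenize, &opts, msgs)
	if err != nil {
		t.Fatal(err)
	}

	// the latest message is always kept; everything that overflows is dropped
	if strings.Contains(prompt, "older") {
		t.Errorf("expected older message to be truncated, got %q", prompt)
	}
}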