diff --git a/server/prompt.go b/server/prompt.go index 51d691a9..abc5e61e 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "log/slog" - "slices" "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" @@ -17,26 +16,18 @@ type tokenizeFunc func(context.Context, string) ([]int, error) // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) { - // pull out any system messages which should always be included in the prompt var system []api.Message - msgs = slices.DeleteFunc(msgs, func(m api.Message) bool { - if m.Role == "system" { - system = append(system, m) - return true - } - - return false - }) - - if len(system) == 0 && m.System != "" { - // add model system prompt since it wasn't provided - system = append(system, api.Message{Role: "system", Content: m.System}) - } - // always include the last message n := len(msgs) - 1 // in reverse, find all messages that fit into context window for i := n - 1; i >= 0; i-- { + system = make([]api.Message, 0) + for j := range i { + if msgs[j].Role == "system" { + system = append(system, msgs[j]) + } + } + var b bytes.Buffer if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil { return "", nil, err diff --git a/server/prompt_test.go b/server/prompt_test.go index 1435b143..d8caf3ed 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" "github.com/ollama/ollama/template" ) @@ -164,6 +165,19 @@ func TestChatPrompt(t *testing.T) { prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ", }, }, + { + name: "out of order system", + limit: 2048, + msgs: []api.Message{ + {Role: "user", Content: "You're a test, Harry!"}, + {Role: "assistant", Content: "I-I'm a what?"}, + {Role: "system", Content: "You are the Test Who Lived."}, + {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, + }, + expect: expect{ + prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ", + }, + }, } tmpl, err := template.Parse(` @@ -187,6 +201,10 @@ func TestChatPrompt(t *testing.T) { t.Errorf("expected %q, got %q", tt.prompt, prompt) } + if diff := cmp.Diff(prompt, tt.prompt); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + if len(images) != len(tt.images) { t.Fatalf("expected %d images, got %d", len(tt.images), len(images)) } diff --git a/template/template.go b/template/template.go index 9b351666..90014ec1 100644 --- a/template/template.go +++ b/template/template.go @@ -149,27 +149,19 @@ type Values struct { } func (t *Template) Execute(w io.Writer, v Values) error { - system, collated := collate(v.Messages) + system, messages := collate(v.Messages) if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { return t.Template.Execute(w, map[string]any{ "System": system, - "Messages": collated, + "Messages": messages, }) } + system = "" var b bytes.Buffer var prompt, response string - for i, m := range collated { - switch m.Role { - case "system": - system = m.Content - case "user": - prompt = m.Content - case "assistant": - response = m.Content - } - - if i != len(collated)-1 && prompt != "" && response != "" { + for _, m := range messages { + execute := func () error { if err := t.Template.Execute(&b, map[string]any{ "System": system, "Prompt": prompt, @@ -181,6 +173,26 @@ func (t *Template) Execute(w io.Writer, v Values) error { system = "" prompt = "" response = "" + return nil + } + + switch m.Role { + case "system": + if prompt != "" || response != "" { + if err := execute(); err != nil { + return err + } + } + system = m.Content + case "user": + if response != "" { + if err := execute(); err != nil { + return err + } + } + prompt = m.Content + case "assistant": + response = m.Content } } @@ -199,7 +211,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { tree := parse.Tree{Root: nodes.(*parse.ListNode)} if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ - "System": "", + "System": system, "Prompt": prompt, }); err != nil { return err