diff --git a/api/types.go b/api/types.go
index 3b607cec..97af4aed 100644
--- a/api/types.go
+++ b/api/types.go
@@ -97,6 +97,9 @@ type ChatRequest struct {
 	// followin the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Tools is an optional list of tools the model has access to.
+	Tools []Tool `json:"tools,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -105,9 +108,39 @@ type ChatRequest struct {
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
-	Role    string      `json:"role"`
-	Content string      `json:"content"`
-	Images  []ImageData `json:"images,omitempty"`
+	Role      string      `json:"role"`
+	Content   string      `json:"content,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
+	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+}
+
+// ToolCall is a single function call requested by the model, carried in
+// Message.ToolCalls and GenerateResponse.ToolCalls.
+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string         `json:"name"`
+		Arguments map[string]any `json:"arguments"`
+	} `json:"function"`
+}
+
+// Tool describes a function the model may call; see ChatRequest.Tools.
+type Tool struct {
+	Type     string `json:"type"`
+	Function struct {
+		Name        string `json:"name"`
+		Description string `json:"description"`
+		Parameters  struct {
+			Type       string   `json:"type"`
+			Required   []string `json:"required"`
+			Properties map[string]struct {
+				Type        string   `json:"type"`
+				Description string   `json:"description"`
+				Enum        []string `json:"enum,omitempty"`
+			} `json:"properties"`
+		} `json:"parameters"`
+	} `json:"function"`
 }
 
 func (m *Message) UnmarshalJSON(b []byte) error {
@@ -374,6 +407,9 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`
 
+	// ToolCalls is the list of tools the model wants to call.
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
 
diff --git a/server/images.go b/server/images.go
index 688d5dca..1b87888e 100644
--- a/server/images.go
+++ b/server/images.go
@@ -38,7 +38,10 @@ var errCapabilityCompletion = errors.New("completion")
 
 type Capability string
 
-const CapabilityCompletion = Capability("completion")
+const (
+	CapabilityCompletion = Capability("completion")
+	CapabilityTools      = Capability("tools")
+)
 
 type registryOptions struct {
 	Insecure bool
@@ -88,6 +91,10 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
 				errs = append(errs, errCapabilityCompletion)
 			}
+		case CapabilityTools:
+			if !slices.Contains(m.Template.Vars(), "tools") {
+				errs = append(errs, errors.New("tools"))
+			}
 		default:
 			slog.Error("unknown capability", "capability", cap)
 			return fmt.Errorf("unknown capability: %s", cap)
@@ -95,7 +102,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 	}
 
 	if err := errors.Join(errs...); err != nil {
-		return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
+		return fmt.Errorf("does not support %w", err)
 	}
 
 	return nil
diff --git a/server/model.go b/server/model.go
index a79f549a..be318db9 100644
--- a/server/model.go
+++ b/server/model.go
@@ -4,6 +4,7 @@ import (
 	"archive/zip"
 	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -11,7 +12,11 @@ import (
 	"net/http"
 	"os"
 	"path/filepath"
+	"slices"
+	"strings"
+	"text/template/parse"
 
+	"github.com/google/uuid"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/llm"
@@ -289,3 +294,122 @@ func detectContentType(r io.Reader) (string, error) {
 
 	return "unknown", nil
 }
+
+// parseToolCalls attempts to parse tool calls out of the model output s. It
+// renders the part of the model template that ranges over .ToolCalls with
+// placeholder values to learn which JSON keys hold the function name and
+// arguments, then tries to decode a JSON array of objects from s. It returns
+// the parsed calls and true, or nil and false if the template has no such
+// section or no usable tool call is found.
+// mxyng: this only really works if the input contains tool calls in some JSON format
+func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
+	// create a subtree from the node that ranges over .ToolCalls
+	tmpl := m.Template.Subtree(func(n parse.Node) bool {
+		if t, ok := n.(*parse.RangeNode); ok {
+			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
+		}
+
+		return false
+	})
+
+	if tmpl == nil {
+		return nil, false
+	}
+
+	// execute the subtree with placeholders to identify the keys
+	var b bytes.Buffer
+	if err := tmpl.Execute(&b, map[string][]map[string]any{
+		"ToolCalls": {
+			{
+				"Function": map[string]any{
+					"Name":      "@@name@@",
+					"Arguments": "@@arguments@@",
+				},
+			},
+		},
+	}); err != nil {
+		return nil, false
+	}
+
+	var kv map[string]string
+	if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
+		return nil, false
+	}
+
+	// find the keys that correspond to the name and arguments fields
+	var name, arguments string
+	for k, v := range kv {
+		switch v {
+		case "@@name@@":
+			name = k
+		case "@@arguments@@":
+			arguments = k
+		}
+	}
+
+	// without a key for the function name no tool call can be recognized
+	if name == "" {
+		return nil, false
+	}
+
+	var sm []map[string]any
+	decoder := json.NewDecoder(strings.NewReader(s))
+	for {
+		// incrementally decode the JSON into a list of JSON objects
+		// skipping over any invalid tokens
+		if err := decoder.Decode(&sm); err != nil {
+			if errors.Is(err, io.EOF) {
+				break
+			}
+
+			if errors.As(err, new(*json.SyntaxError)) {
+				r := decoder.Buffered()
+				if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
+					break
+				}
+
+				decoder = json.NewDecoder(r)
+				continue
+			}
+
+			return nil, false
+		}
+
+		// break as soon as a valid object is decoded
+		break
+	}
+
+	var toolCalls []api.ToolCall
+	for _, obj := range sm {
+		// the model output is untrusted, so use checked type assertions and
+		// skip malformed entries instead of panicking
+		fn, ok := obj[name].(string)
+		if !ok || fn == "" {
+			continue
+		}
+
+		call := api.ToolCall{
+			ID:   uuid.New().String(),
+			Type: "function",
+		}
+		call.Function.Name = fn
+
+		// arguments are optional, but when present they must be a JSON object
+		if v, ok := obj[arguments]; ok && v != nil {
+			args, ok := v.(map[string]any)
+			if !ok {
+				continue
+			}
+
+			call.Function.Arguments = args
+		}
+
+		toolCalls = append(toolCalls, call)
+	}
+
+	if len(toolCalls) > 0 {
+		return toolCalls, true
+	}
+
+	return nil, false
+}
diff --git a/server/prompt.go b/server/prompt.go
index abc5e61e..be0d4969 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -15,7 +15,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 	// always include the last message
 	n := len(msgs) - 1
@@ -29,7 +29,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 
 		var b bytes.Buffer
-		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
 			return "", nil, err
 		}
 
@@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
 		return "", nil, err
 	}
 
diff --git a/server/prompt_test.go b/server/prompt_test.go
index d8caf3ed..9c4da068 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
-			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
+			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
 			if err != nil {
 				t.Fatal(err)
 			}
diff --git a/server/routes.go b/server/routes.go
index c5c3a19c..9712d895 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		r.Response = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			r.ToolCalls = toolCalls
+			r.Response = ""
+		}
+
 		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
+	if len(req.Tools) > 0 {
+		caps = append(caps, CapabilityTools)
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		var r api.ChatResponse
+		var resp api.ChatResponse
 		var sb strings.Builder
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.ChatResponse:
 				sb.WriteString(t.Message.Content)
-				r = t
+				resp = t
 			case gin.H:
 				msg, ok := t["error"].(string)
 				if !ok {
@@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 		}
 
-		r.Message.Content = sb.String()
-		c.JSON(http.StatusOK, r)
+		resp.Message.Content = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			resp.Message.ToolCalls = toolCalls
+			resp.Message.Content = ""
+		}
+
+		c.JSON(http.StatusOK, resp)
 		return
 	}
 
diff --git a/template/template.go b/template/template.go
index 90014ec1..0e23cf1c 100644
--- a/template/template.go
+++ b/template/template.go
@@ -13,6 +13,7 @@ import (
 	"sync"
 	"text/template"
 	"text/template/parse"
+	"time"
 
 	"github.com/agnivade/levenshtein"
 	"github.com/ollama/ollama/api"
@@ -102,8 +103,26 @@ var response = parse.ActionNode{
 	},
 }
 
+// funcs are the helper functions registered on templates by Parse and Subtree.
+var funcs = template.FuncMap{
+	// json renders v as JSON; a marshaling error aborts template execution
+	// rather than silently producing an empty string
+	"json": func(v any) (string, error) {
+		b, err := json.Marshal(v)
+		if err != nil {
+			return "", err
+		}
+
+		return string(b), nil
+	},
+	// now returns the current local time formatted as YYYY-MM-DD HH:MM:SS
+	"now": func() string {
+		return time.Now().Format("2006-01-02 15:04:05")
+	},
+}
+
 func Parse(s string) (*Template, error) {
-	tmpl := template.New("").Option("missingkey=zero")
+	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
 
 	tmpl, err := tmpl.Parse(s)
 	if err != nil {
@@ -127,7 +146,7 @@ func (t *Template) Vars() []string {
 	var vars []string
 	for _, tt := range t.Templates() {
 		for _, n := range tt.Root.Nodes {
-			vars = append(vars, parseNode(n)...)
+			vars = append(vars, Identifiers(n)...)
 		}
 	}
 
@@ -143,17 +162,68 @@ func (t *Template) Vars() []string {
 
 type Values struct {
 	Messages []api.Message
+	Tools    []api.Tool
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
 }
 
+// Subtree returns a new template holding the first node, in depth-first
+// order, for which fn returns true. Only list and branch (if/with/range)
+// bodies are descended into. It returns nil if no node matches.
+func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
+	var walk func(parse.Node) parse.Node
+	walk = func(n parse.Node) parse.Node {
+		if fn(n) {
+			return n
+		}
+
+		switch t := n.(type) {
+		case *parse.ListNode:
+			for _, c := range t.Nodes {
+				if n := walk(c); n != nil {
+					return n
+				}
+			}
+		case *parse.BranchNode:
+			for _, n := range []*parse.ListNode{t.List, t.ElseList} {
+				if n != nil {
+					if n := walk(n); n != nil {
+						return n
+					}
+				}
+			}
+		case *parse.IfNode:
+			return walk(&t.BranchNode)
+		case *parse.WithNode:
+			return walk(&t.BranchNode)
+		case *parse.RangeNode:
+			return walk(&t.BranchNode)
+		}
+
+		return nil
+	}
+
+	if n := walk(t.Tree.Root); n != nil {
+		return (&template.Template{
+			Tree: &parse.Tree{
+				Root: &parse.ListNode{
+					Nodes: []parse.Node{n},
+				},
+			},
+		}).Funcs(funcs)
+	}
+
+	return nil
+}
+
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
 	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
 			"Messages": messages,
+			"Tools":    v.Tools,
 		})
 	}
 
@@ -161,7 +231,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	var b bytes.Buffer
 	var prompt, response string
 	for _, m := range messages {
-		execute := func () error {
+		execute := func() error {
 			if err := t.Template.Execute(&b, map[string]any{
 				"System":   system,
 				"Prompt":   prompt,
@@ -198,12 +268,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 
 	var cut bool
 	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
-		switch t := n.(type) {
-		case *parse.ActionNode:
-		case *parse.FieldNode:
-			if slices.Contains(t.Ident, "Response") {
-				cut = true
-			}
+		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
+			cut = true
 		}
 
 		return cut
@@ -255,50 +321,46 @@ func collate(msgs []api.Message) (string, []*api.Message) {
 	return strings.Join(system, "\n\n"), collated
 }
 
-func parseNode(n parse.Node) []string {
+// Identifiers walks the node tree and returns every field and variable identifier it finds.
+func Identifiers(n parse.Node) []string {
 	switch n := n.(type) {
+	case *parse.ListNode:
+		var names []string
+		for _, n := range n.Nodes {
+			names = append(names, Identifiers(n)...)
+		}
+
+		return names
+	case *parse.TemplateNode:
+		return Identifiers(n.Pipe)
 	case *parse.ActionNode:
-		return parseNode(n.Pipe)
+		return Identifiers(n.Pipe)
+	case *parse.BranchNode:
+		names := Identifiers(n.Pipe)
+		for _, n := range []*parse.ListNode{n.List, n.ElseList} {
+			if n != nil {
+				names = append(names, Identifiers(n)...)
+			}
+		}
+		return names
 	case *parse.IfNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
-		}
-		return names
+		return Identifiers(&n.BranchNode)
 	case *parse.RangeNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
-		}
-		return names
+		return Identifiers(&n.BranchNode)
 	case *parse.WithNode:
-		names := parseNode(n.Pipe)
-		names = append(names, parseNode(n.List)...)
-		if n.ElseList != nil {
-			names = append(names, parseNode(n.ElseList)...)
-		}
-		return names
+		return Identifiers(&n.BranchNode)
 	case *parse.PipeNode:
 		var names []string
 		for _, c := range n.Cmds {
 			for _, a := range c.Args {
-				names = append(names, parseNode(a)...)
+				names = append(names, Identifiers(a)...)
 			}
 		}
-		return names
-	case *parse.ListNode:
-		var names []string
-		for _, n := range n.Nodes {
-			names = append(names, parseNode(n)...)
-		}
-
 		return names
 	case *parse.FieldNode:
 		return n.Ident
-	case *parse.TemplateNode:
-		return parseNode(n.Pipe)
+	case *parse.VariableNode:
+		return n.Ident
 	}
 
 	return nil