diff --git a/api/client.go b/api/client.go index 44af222c..250711dd 100644 --- a/api/client.go +++ b/api/client.go @@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate }) } +type ChatResponseFunc func(ChatResponse) error + +func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error { + return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error { + var resp ChatResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} + type PullProgressFunc func(ProgressResponse) error func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { diff --git a/api/types.go b/api/types.go index 692c4445..14a7059e 100644 --- a/api/types.go +++ b/api/types.go @@ -36,7 +36,7 @@ type GenerateRequest struct { Prompt string `json:"prompt"` System string `json:"system"` Template string `json:"template"` - Context []int `json:"context,omitempty"` + Context []int `json:"context,omitempty"` // DEPRECATED: context is deprecated, use the /chat endpoint instead for chat history Stream *bool `json:"stream,omitempty"` Raw bool `json:"raw,omitempty"` Format string `json:"format"` @@ -44,6 +44,41 @@ type GenerateRequest struct { Options map[string]interface{} `json:"options"` } +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Template string `json:"template"` + Stream *bool `json:"stream,omitempty"` + Format string `json:"format"` + + Options map[string]interface{} `json:"options"` +} + +type Message struct { + Role string `json:"role"` // one of ["system", "user", "assistant"] + Content string `json:"content"` +} + +type ChatResponse struct { + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Message *Message `json:"message,omitempty"` + + Done bool `json:"done"` + Context []int `json:"context,omitempty"` + + EvalMetrics +} + +type EvalMetrics struct { + TotalDuration time.Duration `json:"total_duration,omitempty"` + LoadDuration time.Duration `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration time.Duration `json:"eval_duration,omitempty"` +} + // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also type Options struct { Runner @@ -173,39 +208,34 @@ type GenerateResponse struct { Done bool `json:"done"` Context []int `json:"context,omitempty"` - TotalDuration time.Duration `json:"total_duration,omitempty"` - LoadDuration time.Duration `json:"load_duration,omitempty"` - PromptEvalCount int `json:"prompt_eval_count,omitempty"` - PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` - EvalCount int `json:"eval_count,omitempty"` - EvalDuration time.Duration `json:"eval_duration,omitempty"` + EvalMetrics } -func (r *GenerateResponse) Summary() { - if r.TotalDuration > 0 { - fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration) +func (m *EvalMetrics) Summary() { + if m.TotalDuration > 0 { + fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) } - if r.LoadDuration > 0 { - fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration) + if m.LoadDuration > 0 { + fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration) } - if r.PromptEvalCount > 0 { - fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount) + if 
m.PromptEvalCount > 0 { + fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount) } - if r.PromptEvalDuration > 0 { - fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration) - fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds()) + if m.PromptEvalDuration > 0 { + fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration) + fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds()) } - if r.EvalCount > 0 { - fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount) + if m.EvalCount > 0 { + fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount) } - if r.EvalDuration > 0 { - fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration) - fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds()) + if m.EvalDuration > 0 { + fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration) + fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds()) } } diff --git a/cmd/cmd.go b/cmd/cmd.go index df0d90c8..d3c3b777 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -159,7 +159,54 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - return RunGenerate(cmd, args) + interactive := true + + opts := runOptions{ + Model: name, + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]interface{}{}, + } + + format, err := cmd.Flags().GetString("format") + if err != nil { + return err + } + opts.Format = format + + prompts := args[1:] + + // prepend stdin to the prompt if provided + if !term.IsTerminal(int(os.Stdin.Fd())) { + in, err := io.ReadAll(os.Stdin) + if err != nil { + return err + } + + prompts = append([]string{string(in)}, prompts...) + opts.WordWrap = false + interactive = false + } + msg := api.Message{ + Role: "user", + Content: strings.Join(prompts, " "), + } + opts.Messages = append(opts.Messages, msg) + if len(prompts) > 0 { + interactive = false + } + + nowrap, err := cmd.Flags().GetBool("nowordwrap") + if err != nil { + return err + } + opts.WordWrap = !nowrap + + if !interactive { + _, err := chat(cmd, opts) + return err + } + + return chatInteractive(cmd, opts) } func PushHandler(cmd *cobra.Command, args []string) error { @@ -411,83 +458,26 @@ func PullHandler(cmd *cobra.Command, args []string) error { return nil } -func RunGenerate(cmd *cobra.Command, args []string) error { - interactive := true - - opts := generateOptions{ - Model: args[0], - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]interface{}{}, - } - - format, err := cmd.Flags().GetString("format") - if err != nil { - return err - } - opts.Format = format - - prompts := args[1:] - - // prepend stdin to the prompt if provided - if !term.IsTerminal(int(os.Stdin.Fd())) { - in, err := io.ReadAll(os.Stdin) - if err != nil { - return err - } - - prompts = append([]string{string(in)}, prompts...) 
- opts.WordWrap = false - interactive = false - } - opts.Prompt = strings.Join(prompts, " ") - if len(prompts) > 0 { - interactive = false - } - - nowrap, err := cmd.Flags().GetBool("nowordwrap") - if err != nil { - return err - } - opts.WordWrap = !nowrap - - if !interactive { - return generate(cmd, opts) - } - - return generateInteractive(cmd, opts) -} - -type generateContextKey string - -type generateOptions struct { +type runOptions struct { Model string - Prompt string + Messages []api.Message WordWrap bool Format string - System string Template string Options map[string]interface{} } -func generate(cmd *cobra.Command, opts generateOptions) error { +func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { client, err := api.ClientFromEnvironment() if err != nil { - return err + return nil, err } p := progress.NewProgress(os.Stderr) defer p.StopAndClear() - spinner := progress.NewSpinner("") p.Add("", spinner) - var latest api.GenerateResponse - - generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int) - if !ok { - generateContext = []int{} - } - termWidth, _, err := term.GetSize(int(os.Stdout.Fd())) if err != nil { opts.WordWrap = false @@ -506,24 +496,24 @@ func generate(cmd *cobra.Command, opts generateOptions) error { var currentLineLength int var wordBuffer string + var latest api.ChatResponse + var fullResponse strings.Builder + var role string - request := api.GenerateRequest{ - Model: opts.Model, - Prompt: opts.Prompt, - Context: generateContext, - Format: opts.Format, - System: opts.System, - Template: opts.Template, - Options: opts.Options, - } - fn := func(response api.GenerateResponse) error { + fn := func(response api.ChatResponse) error { p.StopAndClear() - latest = response + if response.Message == nil { + // warm-up response or done + return nil + } + role = response.Message.Role + content := response.Message.Content + fullResponse.WriteString(content) termWidth, _, _ = term.GetSize(int(os.Stdout.Fd())) if opts.WordWrap && termWidth >= 10 { - for _, ch := range response.Response { + for _, ch := range content { if currentLineLength+1 > termWidth-5 { if len(wordBuffer) > termWidth-10 { fmt.Printf("%s%c", wordBuffer, ch) @@ -551,7 +541,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error { } } } else { - fmt.Printf("%s%s", wordBuffer, response.Response) + fmt.Printf("%s%s", wordBuffer, content) if len(wordBuffer) > 0 { wordBuffer = "" } @@ -560,35 +550,35 @@ func generate(cmd *cobra.Command, opts generateOptions) error { return nil } - if err := client.Generate(cancelCtx, &request, fn); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } - return err + req := &api.ChatRequest{ + Model: opts.Model, + Messages: opts.Messages, + Format: opts.Format, + Template: opts.Template, + Options: opts.Options, } - if opts.Prompt != "" { - fmt.Println() - fmt.Println() + if err := client.Chat(cancelCtx, req, fn); err != nil { + if errors.Is(err, context.Canceled) { + return nil, nil + } + return nil, err } - if !latest.Done { - return nil + if len(opts.Messages) > 0 { + fmt.Println() + fmt.Println() } verbose, err := cmd.Flags().GetBool("verbose") if err != nil { - return err + return nil, err } if verbose { latest.Summary() } - ctx := cmd.Context() - ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) - cmd.SetContext(ctx) - - return nil + return &api.Message{Role: role, Content: fullResponse.String()}, nil } type MultilineState int @@ -600,13 +590,10 @@ const ( MultilineTemplate ) -func 
generateInteractive(cmd *cobra.Command, opts generateOptions) error { +func chatInteractive(cmd *cobra.Command, opts runOptions) error { // load the model - loadOpts := generateOptions{ - Model: opts.Model, - Prompt: "", - } - if err := generate(cmd, loadOpts); err != nil { + loadOpts := runOptions{Model: opts.Model} + if _, err := chat(cmd, loadOpts); err != nil { return err } @@ -677,7 +664,9 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { defer fmt.Printf(readline.EndBracketedPaste) var multiline MultilineState - var prompt string + var content string + var systemContent string + opts.Messages = make([]api.Message, 0) for { line, err := scanner.Readline() @@ -691,7 +680,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } scanner.Prompt.UseAlt = false - prompt = "" + content = "" continue case err != nil: @@ -699,37 +688,37 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } switch { - case strings.HasPrefix(prompt, `"""`): + case strings.HasPrefix(content, `"""`): // if the prompt so far starts with """ then we're in multiline mode // and we need to keep reading until we find a line that ends with """ cut, found := strings.CutSuffix(line, `"""`) - prompt += cut + "\n" + content += cut + "\n" if !found { continue } - prompt = strings.TrimPrefix(prompt, `"""`) + content = strings.TrimPrefix(content, `"""`) scanner.Prompt.UseAlt = false switch multiline { case MultilineSystem: - opts.System = prompt - prompt = "" + systemContent = content + content = "" fmt.Println("Set system template.\n") case MultilineTemplate: - opts.Template = prompt - prompt = "" + opts.Template = content + content = "" fmt.Println("Set model template.\n") } multiline = MultilineNone - case strings.HasPrefix(line, `"""`) && len(prompt) == 0: + case strings.HasPrefix(line, `"""`) && len(content) == 0: scanner.Prompt.UseAlt = true multiline = MultilinePrompt - prompt += line + "\n" + content += line + "\n" continue case scanner.Pasting: - prompt += line + "\n" + content += line + "\n" continue case strings.HasPrefix(line, "/list"): args := strings.Fields(line) @@ -791,17 +780,17 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { line = strings.TrimPrefix(line, `"""`) if strings.HasPrefix(args[2], `"""`) { cut, found := strings.CutSuffix(line, `"""`) - prompt += cut + "\n" + content += cut + "\n" if found { - opts.System = prompt + systemContent = content if args[1] == "system" { fmt.Println("Set system template.\n") } else { fmt.Println("Set prompt template.\n") } - prompt = "" + content = "" } else { - prompt = `"""` + prompt + content = `"""` + content if args[1] == "system" { multiline = MultilineSystem } else { @@ -810,7 +799,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { scanner.Prompt.UseAlt = true } } else { - opts.System = line + systemContent = line fmt.Println("Set system template.\n") } default: @@ -858,8 +847,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { } case "system": switch { - case opts.System != "": - fmt.Println(opts.System + "\n") + case systemContent != "": + fmt.Println(systemContent + "\n") case resp.System != "": fmt.Println(resp.System + "\n") default: @@ -899,16 +888,23 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error { fmt.Printf("Unknown command '%s'. Type /? 
for help\n", args[0]) continue default: - prompt += line + content += line } - if len(prompt) > 0 && multiline == MultilineNone { - opts.Prompt = prompt - if err := generate(cmd, opts); err != nil { + if len(content) > 0 && multiline == MultilineNone { + if systemContent != "" { + opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: systemContent}) + } + opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: content}) + assistant, err := chat(cmd, opts) + if err != nil { return err } + if assistant != nil { + opts.Messages = append(opts.Messages, *assistant) + } - prompt = "" + content = "" } } } diff --git a/docs/api.md b/docs/api.md index 0595fadd..9e39cb9b 100644 --- a/docs/api.md +++ b/docs/api.md @@ -24,7 +24,7 @@ All durations are returned in nanoseconds. ### Streaming responses -Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character. +Certain endpoints stream responses as JSON objects. ## Generate a completion @@ -32,10 +32,12 @@ Certain endpoints stream responses as JSON objects delineated with the newline ( POST /api/generate ``` -Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request. +Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request. ### Parameters +`model` is required. + - `model`: (required) the [model name](#model-names) - `prompt`: the prompt to generate a response for @@ -43,11 +45,10 @@ Advanced parameters (optional): - `format`: the format to return a response in. Currently the only accepted value is `json` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` -- `system`: system prompt to (overrides what is defined in the `Modelfile`) - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) -- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory +- `system`: system prompt to (overrides what is defined in the `Modelfile`) - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects -- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself. +- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API. ### JSON mode @@ -57,7 +58,7 @@ Enable JSON mode by setting the `format` parameter to `json`. 
This will structur ### Examples -#### Request +#### Request (Prompt) ```shell curl http://localhost:11434/api/generate -d '{ @@ -89,7 +90,7 @@ The final response in the stream also includes additional data about the generat - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt - `eval_count`: number of tokens the response - `eval_duration`: time in nanoseconds spent generating the response -- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory +- `context`: deprecated, an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory - `response`: empty if the response was streamed, if not streamed, this will contain the full response To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`. @@ -114,6 +115,8 @@ To calculate how fast the response is generated in tokens per second (token/s), #### Request (No streaming) +A response can be recieved in one reply when streaming is off. + ```shell curl http://localhost:11434/api/generate -d '{ "model": "llama2", @@ -144,9 +147,9 @@ If `stream` is set to `false`, the response will be a single JSON object: } ``` -#### Request (Raw mode) +#### Request (Raw Mode) -In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context. +In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting. ```shell curl http://localhost:11434/api/generate -d '{ @@ -164,6 +167,7 @@ curl http://localhost:11434/api/generate -d '{ "model": "mistral", "created_at": "2023-11-03T15:36:02.583064Z", "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.", + "context": [1, 2, 3], "done": true, "total_duration": 14648695333, "load_duration": 3302671417, @@ -275,7 +279,6 @@ curl http://localhost:11434/api/generate -d '{ "model": "llama2", "created_at": "2023-08-04T19:22:45.499127Z", "response": "The sky is blue because it is the color of the sky.", - "context": [1, 2, 3], "done": true, "total_duration": 5589157167, "load_duration": 3013701500, @@ -288,6 +291,135 @@ curl http://localhost:11434/api/generate -d '{ } ``` +## Send Chat Messages +```shell +POST /api/chat +``` + +Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request. + +### Parameters + +`model` is required. + +- `model`: (required) the [model name](#model-names) +- `messages`: the messages of the chat, this can be used to keep a chat memory + +Advanced parameters (optional): + +- `format`: the format to return a response in. Currently the only accepted value is `json` +- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` +- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) +- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects + +### Examples + +#### Request +Send a chat message with a streaming response. 
+ +```shell +curl http://localhost:11434/api/generate -d '{ + "model": "llama2", + "messages": [ + { + "role": "user", + "content": "why is the sky blue?" + } + ] +}' +``` + +#### Response + +A stream of JSON objects is returned: + +```json +{ + "model": "llama2", + "created_at": "2023-08-04T08:52:19.385406455-07:00", + "message": { + "role": "assisant", + "content": "The" + }, + "done": false +} +``` + +Final response: + +```json +{ + "model": "llama2", + "created_at": "2023-08-04T19:22:45.499127Z", + "done": true, + "total_duration": 5589157167, + "load_duration": 3013701500, + "sample_count": 114, + "sample_duration": 81442000, + "prompt_eval_count": 46, + "prompt_eval_duration": 1160282000, + "eval_count": 113, + "eval_duration": 1325948000 +} +``` + +#### Request (With History) +Send a chat message with a conversation history. + +```shell +curl http://localhost:11434/api/generate -d '{ + "model": "llama2", + "messages": [ + { + "role": "user", + "content": "why is the sky blue?" + }, + { + "role": "assistant", + "content": "due to rayleigh scattering." + }, + { + "role": "user", + "content": "how is that different than mie scattering?" + } + ] +}' +``` + +#### Response + +A stream of JSON objects is returned: + +```json +{ + "model": "llama2", + "created_at": "2023-08-04T08:52:19.385406455-07:00", + "message": { + "role": "assisant", + "content": "The" + }, + "done": false +} +``` + +Final response: + +```json +{ + "model": "llama2", + "created_at": "2023-08-04T19:22:45.499127Z", + "done": true, + "total_duration": 5589157167, + "load_duration": 3013701500, + "sample_count": 114, + "sample_duration": 81442000, + "prompt_eval_count": 46, + "prompt_eval_duration": 1160282000, + "eval_count": 113, + "eval_duration": 1325948000 +} +``` + ## Create a Model ```shell diff --git a/llm/llama.go b/llm/llama.go index fc033258..3cce7fef 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -531,21 +531,31 @@ type prediction struct { const maxBufferSize = 512 * format.KiloByte -func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error { - prevConvo, err := llm.Decode(ctx, prevContext) - if err != nil { - return err - } +type PredictRequest struct { + Model string + Prompt string + Format string + CheckpointStart time.Time + CheckpointLoaded time.Time +} - // Remove leading spaces from prevConvo if present - prevConvo = strings.TrimPrefix(prevConvo, " ") - - var nextContext strings.Builder - nextContext.WriteString(prevConvo) - nextContext.WriteString(prompt) +type PredictResponse struct { + Model string + CreatedAt time.Time + TotalDuration time.Duration + LoadDuration time.Duration + Content string + Done bool + PromptEvalCount int + PromptEvalDuration time.Duration + EvalCount int + EvalDuration time.Duration + Context []int +} +func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error { request := map[string]any{ - "prompt": nextContext.String(), + "prompt": predict.Prompt, "stream": true, "n_predict": llm.NumPredict, "n_keep": llm.NumKeep, @@ -567,7 +577,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, "stop": llm.Stop, } - if format == "json" { + if predict.Format == "json" { request["grammar"] = jsonGrammar } @@ -624,25 +634,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, } if p.Content != "" { - fn(api.GenerateResponse{Response: p.Content}) - nextContext.WriteString(p.Content) + fn(PredictResponse{ + 
Model: predict.Model, + CreatedAt: time.Now().UTC(), + Content: p.Content, + }) } if p.Stop { - embd, err := llm.Encode(ctx, nextContext.String()) - if err != nil { - return fmt.Errorf("encoding context: %v", err) - } + fn(PredictResponse{ + Model: predict.Model, + CreatedAt: time.Now().UTC(), + TotalDuration: time.Since(predict.CheckpointStart), - fn(api.GenerateResponse{ Done: true, - Context: embd, PromptEvalCount: p.Timings.PromptN, PromptEvalDuration: parseDurationMs(p.Timings.PromptMS), EvalCount: p.Timings.PredictedN, EvalDuration: parseDurationMs(p.Timings.PredictedMS), }) - return nil } } diff --git a/llm/llm.go b/llm/llm.go index 4901d9fe..703ea012 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -14,7 +14,7 @@ import ( ) type LLM interface { - Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error + Predict(context.Context, PredictRequest, func(PredictResponse)) error Embedding(context.Context, string) ([]float64, error) Encode(context.Context, string) ([]int, error) Decode(context.Context, []int) (string, error) diff --git a/server/images.go b/server/images.go index 294fdf2b..efc5e8bc 100644 --- a/server/images.go +++ b/server/images.go @@ -47,37 +47,82 @@ type Model struct { Options map[string]interface{} } -func (m *Model) Prompt(request api.GenerateRequest) (string, error) { - t := m.Template - if request.Template != "" { - t = request.Template - } +type PromptVars struct { + System string + Prompt string + Response string +} - tmpl, err := template.New("").Parse(t) +func (m *Model) Prompt(p PromptVars) (string, error) { + var prompt strings.Builder + tmpl, err := template.New("").Parse(m.Template) if err != nil { return "", err } - var vars struct { - First bool - System string - Prompt string - } - - vars.First = len(request.Context) == 0 - vars.System = m.System - vars.Prompt = request.Prompt - - if request.System != "" { - vars.System = request.System + if p.System == "" { + // use the default system prompt for this model if one is not specified + p.System = m.System } var sb strings.Builder - if err := tmpl.Execute(&sb, vars); err != nil { + if err := tmpl.Execute(&sb, p); err != nil { return "", err } + prompt.WriteString(sb.String()) + prompt.WriteString(p.Response) + return prompt.String(), nil +} - return sb.String(), nil +func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { + // build the prompt from the list of messages + var prompt strings.Builder + currentVars := PromptVars{} + + writePrompt := func() error { + p, err := m.Prompt(currentVars) + if err != nil { + return err + } + prompt.WriteString(p) + currentVars = PromptVars{} + return nil + } + + for _, msg := range msgs { + switch msg.Role { + case "system": + if currentVars.Prompt != "" || currentVars.System != "" { + if err := writePrompt(); err != nil { + return "", err + } + } + currentVars.System = msg.Content + case "user": + if currentVars.Prompt != "" || currentVars.System != "" { + if err := writePrompt(); err != nil { + return "", err + } + } + currentVars.Prompt = msg.Content + case "assistant": + currentVars.Response = msg.Content + if err := writePrompt(); err != nil { + return "", err + } + default: + return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) + } + } + + // Append the last set of vars if they are non-empty + if currentVars.Prompt != "" || currentVars.System != "" { + if err := writePrompt(); err != nil { + return "", err + } + } + + return prompt.String(), nil } type ManifestV2 struct { diff --git 
a/server/images_test.go b/server/images_test.go index 5e6a197b..85e8d4bd 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -2,17 +2,15 @@ package server import ( "testing" - - "github.com/jmorganca/ollama/api" ) func TestModelPrompt(t *testing.T) { - var m Model - req := api.GenerateRequest{ + m := Model{ Template: "a{{ .Prompt }}b", - Prompt: "
<h1>", } - s, err := m.Prompt(req) + s, err := m.Prompt(PromptVars{ + Prompt: "<h1>
", + }) if err != nil { t.Fatal(err) } diff --git a/server/routes.go b/server/routes.go index bc8ea804..385af66a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -60,17 +60,26 @@ var loaded struct { var defaultSessionDuration = 5 * time.Minute // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function -func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { +func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) { + model, err := GetModel(modelName) + if err != nil { + return nil, err + } + + workDir := c.GetString("workDir") + opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { log.Printf("could not load model options: %v", err) - return err + return nil, err } if err := opts.FromMap(reqOpts); err != nil { - return err + return nil, err } + ctx := c.Request.Context() + // check if the loaded model is still running in a subprocess, in case something unexpected happened if loaded.runner != nil { if err := loaded.runner.Ping(ctx); err != nil { @@ -106,7 +115,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string] err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName) } - return err + return nil, err } loaded.Model = model @@ -140,7 +149,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string] } loaded.expireTimer.Reset(sessionDuration) - return nil + return model, nil } func GenerateHandler(c *gin.Context) { @@ -173,88 +182,262 @@ func GenerateHandler(c *gin.Context) { return } - model, err := GetModel(req.Model) + sessionDuration := defaultSessionDuration + model, err := load(c, req.Model, req.Options, sessionDuration) if err != nil { var pErr *fs.PathError - if errors.As(err, &pErr) { + switch { + case errors.As(err, &pErr): c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) - return + case errors.Is(err, api.ErrInvalidOpts): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - workDir := c.GetString("workDir") - - // TODO: set this duration from the request if specified - sessionDuration := defaultSessionDuration - if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil { - if errors.Is(err, api.ErrInvalidOpts) { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + // an empty request loads the model + if req.Prompt == "" && req.Template == "" && req.System == "" { + c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) return } checkpointLoaded := time.Now() - prompt := req.Prompt - if !req.Raw { - prompt, err = model.Prompt(req) + var prompt string + sendContext := false + switch { + case req.Raw: + prompt = req.Prompt + case req.Prompt != "": + if req.Template != "" { + // override the default model template + model.Template = req.Template + } + + var rebuild strings.Builder + if req.Context != nil { + // TODO: context is deprecated, at some point the 
context logic within this conditional should be removed + prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Remove leading spaces from prevCtx if present + prevCtx = strings.TrimPrefix(prevCtx, " ") + rebuild.WriteString(prevCtx) + } + p, err := model.Prompt(PromptVars{ + System: req.System, + Prompt: req.Prompt, + }) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + rebuild.WriteString(p) + prompt = rebuild.String() + sendContext = true } ch := make(chan any) + var generated strings.Builder go func() { defer close(ch) - // an empty request loads the model - if req.Prompt == "" && req.Template == "" && req.System == "" { - ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true} - return - } - fn := func(r api.GenerateResponse) { + fn := func(r llm.PredictResponse) { + // Update model expiration loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireTimer.Reset(sessionDuration) - r.Model = req.Model - r.CreatedAt = time.Now().UTC() - if r.Done { - r.TotalDuration = time.Since(checkpointStart) - r.LoadDuration = checkpointLoaded.Sub(checkpointStart) + // Build up the full response + if _, err := generated.WriteString(r.Content); err != nil { + ch <- gin.H{"error": err.Error()} + return } - if req.Raw { - // in raw mode the client must manage history on their own - r.Context = nil + resp := api.GenerateResponse{ + Model: r.Model, + CreatedAt: r.CreatedAt, + Done: r.Done, + Response: r.Content, + EvalMetrics: api.EvalMetrics{ + TotalDuration: r.TotalDuration, + LoadDuration: r.LoadDuration, + PromptEvalCount: r.PromptEvalCount, + PromptEvalDuration: r.PromptEvalDuration, + EvalCount: r.EvalCount, + EvalDuration: r.EvalDuration, + }, } - ch <- r + if r.Done && sendContext { + embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String()) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } + r.Context = embd + } + + ch <- resp } - if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil { + // Start prediction + predictReq := llm.PredictRequest{ + Model: model.Name, + Prompt: prompt, + Format: req.Format, + CheckpointStart: checkpointStart, + CheckpointLoaded: checkpointLoaded, + } + if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() if req.Stream != nil && !*req.Stream { - var response api.GenerateResponse - generated := "" + // Wait for the channel to close + var r api.GenerateResponse + var sb strings.Builder for resp := range ch { - if r, ok := resp.(api.GenerateResponse); ok { - generated += r.Response - response = r - } else { + var ok bool + if r, ok = resp.(api.GenerateResponse); !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + sb.WriteString(r.Response) } - response.Response = generated - c.JSON(http.StatusOK, response) + r.Response = sb.String() + c.JSON(http.StatusOK, r) + return + } + + streamResponse(c, ch) +} + +func ChatHandler(c *gin.Context) { + loaded.mu.Lock() + defer loaded.mu.Unlock() + + checkpointStart := time.Now() + + var req api.ChatRequest + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, 
gin.H{"error": err.Error()}) + return + } + + // validate the request + switch { + case req.Model == "": + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + case len(req.Format) > 0 && req.Format != "json": + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"}) + return + } + + sessionDuration := defaultSessionDuration + model, err := load(c, req.Model, req.Options, sessionDuration) + if err != nil { + var pErr *fs.PathError + switch { + case errors.As(err, &pErr): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) + case errors.Is(err, api.ErrInvalidOpts): + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + return + } + + // an empty request loads the model + if len(req.Messages) == 0 { + c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) + return + } + + checkpointLoaded := time.Now() + + if req.Template != "" { + // override the default model template + model.Template = req.Template + } + prompt, err := model.ChatPrompt(req.Messages) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + ch := make(chan any) + + go func() { + defer close(ch) + + fn := func(r llm.PredictResponse) { + // Update model expiration + loaded.expireAt = time.Now().Add(sessionDuration) + loaded.expireTimer.Reset(sessionDuration) + + resp := api.ChatResponse{ + Model: r.Model, + CreatedAt: r.CreatedAt, + Done: r.Done, + EvalMetrics: api.EvalMetrics{ + TotalDuration: r.TotalDuration, + LoadDuration: r.LoadDuration, + PromptEvalCount: r.PromptEvalCount, + PromptEvalDuration: r.PromptEvalDuration, + EvalCount: r.EvalCount, + EvalDuration: r.EvalDuration, + }, + } + + if !r.Done { + resp.Message = &api.Message{Role: "assistant", Content: r.Content} + } + + ch <- resp + } + + // Start prediction + predictReq := llm.PredictRequest{ + Model: model.Name, + Prompt: prompt, + Format: req.Format, + CheckpointStart: checkpointStart, + CheckpointLoaded: checkpointLoaded, + } + if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { + ch <- gin.H{"error": err.Error()} + } + }() + + if req.Stream != nil && !*req.Stream { + // Wait for the channel to close + var r api.ChatResponse + var sb strings.Builder + for resp := range ch { + var ok bool + if r, ok = resp.(api.ChatResponse); !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if r.Message != nil { + sb.WriteString(r.Message.Content) + } + } + r.Message = &api.Message{Role: "assistant", Content: sb.String()} + c.JSON(http.StatusOK, r) return } @@ -281,15 +464,18 @@ func EmbeddingHandler(c *gin.Context) { return } - model, err := GetModel(req.Model) + sessionDuration := defaultSessionDuration + _, err = load(c, req.Model, req.Options, sessionDuration) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - workDir := c.GetString("workDir") - if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + var pErr *fs.PathError + switch { + case errors.As(err, &pErr): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) + case errors.Is(err, api.ErrInvalidOpts): + c.JSON(http.StatusBadRequest, 
gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } return } @@ -767,6 +953,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { r.POST("/api/pull", PullModelHandler) r.POST("/api/generate", GenerateHandler) + r.POST("/api/chat", ChatHandler) r.POST("/api/embeddings", EmbeddingHandler) r.POST("/api/create", CreateModelHandler) r.POST("/api/push", PushModelHandler)