diff --git a/README.md b/README.md
index 2f4bbdba..39d2582c 100644
--- a/README.md
+++ b/README.md
@@ -214,6 +214,17 @@ curl http://localhost:11434/api/generate -d '{
 }'
 ```
 
+Or send a chat message:
+
+```
+curl http://localhost:11434/api/chat -d '{
+  "model": "mistral",
+  "messages": [
+    { "role": "user", "content": "why is the sky blue?" }
+  ]
+}'
+```
+
 See the [API documentation](./docs/api.md) for all endpoints.
 
 ## Community Integrations
diff --git a/api/client.go b/api/client.go
index 44af222c..250711dd 100644
--- a/api/client.go
+++ b/api/client.go
@@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
 	})
 }
 
+type ChatResponseFunc func(ChatResponse) error
+
+func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
+	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
+		var resp ChatResponse
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
+
+		return fn(resp)
+	})
+}
+
 type PullProgressFunc func(ProgressResponse) error
 
 func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
diff --git a/api/types.go b/api/types.go
index 692c4445..3fbc5829 100644
--- a/api/types.go
+++ b/api/types.go
@@ -44,6 +44,39 @@ type GenerateRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+type ChatRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+	Stream   *bool     `json:"stream,omitempty"`
+	Format   string    `json:"format"`
+
+	Options map[string]interface{} `json:"options"`
+}
+
+type Message struct {
+	Role    string `json:"role"` // one of ["system", "user", "assistant"]
+	Content string `json:"content"`
+}
+
+type ChatResponse struct {
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Message   *Message  `json:"message,omitempty"`
+
+	Done bool `json:"done"`
+
+	Metrics
+}
+
+type Metrics struct {
+	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+	EvalCount          int           `json:"eval_count,omitempty"`
+	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+}
+
 // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
 type Options struct {
 	Runner
@@ -173,39 +206,34 @@ type GenerateResponse struct {
 	Done bool `json:"done"`
 
 	Context []int `json:"context,omitempty"`
 
-	TotalDuration      time.Duration `json:"total_duration,omitempty"`
-	LoadDuration       time.Duration `json:"load_duration,omitempty"`
-	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
-	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
-	EvalCount          int           `json:"eval_count,omitempty"`
-	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+	Metrics
 }
 
-func (r *GenerateResponse) Summary() {
-	if r.TotalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
+func (m *Metrics) Summary() {
+	if m.TotalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
 	}
 
-	if r.LoadDuration > 0 {
-		fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
+	if m.LoadDuration > 0 {
+		fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
 	}
 
-	if r.PromptEvalCount > 0 {
-		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
+	if m.PromptEvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount)
 	}
 
-	if r.PromptEvalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
-		fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
+	if m.PromptEvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
+		fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
 	}
 
-	if r.EvalCount > 0 {
-		fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
+	if m.EvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount)
 	}
 
-	if r.EvalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration)
-		fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
+	if m.EvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration)
+		fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
 	}
 }
diff --git a/docs/api.md b/docs/api.md
index 0595fadd..cac8ddd0 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
 
 ### Streaming responses
 
-Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
+Certain endpoints stream responses as JSON objects.
 
 ## Generate a completion
 
@@ -32,7 +32,7 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
 POST /api/generate
 ```
 
-Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
+Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
 
 ### Parameters
 
@@ -47,7 +47,7 @@ Advanced parameters (optional):
 - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
 - `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
-- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
+- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
 
 ### JSON mode
 
@@ -57,7 +57,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
 
 ### Examples
 
-#### Request
+#### Request (Prompt)
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
@@ -114,6 +114,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
 
 #### Request (No streaming)
 
+A response can be received in one reply when streaming is off.
+
 ```shell
 curl http://localhost:11434/api/generate -d '{
   "model": "llama2",
   "prompt": "Why is the sky blue?",
   "stream": false
 }'
 ```
@@ -144,9 +146,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
 }
 ```
 
-#### Request (Raw mode)
+#### Request (Raw Mode)
 
-In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
+In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
@@ -164,6 +166,7 @@ curl http://localhost:11434/api/generate -d '{
   "model": "mistral",
   "created_at": "2023-11-03T15:36:02.583064Z",
   "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
+  "context": [1, 2, 3],
   "done": true,
   "total_duration": 14648695333,
   "load_duration": 3302671417,
@@ -275,7 +278,6 @@ curl http://localhost:11434/api/generate -d '{
   "model": "llama2",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
-  "context": [1, 2, 3],
   "done": true,
   "total_duration": 5589157167,
   "load_duration": 3013701500,
@@ -288,6 +290,133 @@ curl http://localhost:11434/api/generate -d '{
 }
 ```
 
+## Send Chat Messages
+```shell
+POST /api/chat
+```
+
+Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
+
+### Parameters
+
+- `model`: (required) the [model name](#model-names)
+- `messages`: the messages of the chat, this can be used to keep a chat memory
+
+Advanced parameters (optional):
+
+- `format`: the format to return a response in. Currently the only accepted value is `json`
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
+- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
+
+### Examples
+
+#### Request
+Send a chat message with a streaming response.
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "why is the sky blue?"
+    }
+  ]
+}'
+```
+
+#### Response
+
+A stream of JSON objects is returned:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T08:52:19.385406455-07:00",
+  "message": {
+    "role": "assistant",
+    "content": "The"
+  },
+  "done": false
+}
+```
+
+Final response:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T19:22:45.499127Z",
+  "done": true,
+  "total_duration": 5589157167,
+  "load_duration": 3013701500,
+  "sample_count": 114,
+  "sample_duration": 81442000,
+  "prompt_eval_count": 46,
+  "prompt_eval_duration": 1160282000,
+  "eval_count": 113,
+  "eval_duration": 1325948000
+}
+```
+
+#### Request (With History)
+Send a chat message with a conversation history.
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "why is the sky blue?"
+    },
+    {
+      "role": "assistant",
+      "content": "due to rayleigh scattering."
+    },
+    {
+      "role": "user",
+      "content": "how is that different than mie scattering?"
+    }
+  ]
+}'
+```
+
+#### Response
+
+A stream of JSON objects is returned:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T08:52:19.385406455-07:00",
+  "message": {
+    "role": "assistant",
+    "content": "The"
+  },
+  "done": false
+}
+```
+
+Final response:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T19:22:45.499127Z",
+  "done": true,
+  "total_duration": 5589157167,
+  "load_duration": 3013701500,
+  "sample_count": 114,
+  "sample_duration": 81442000,
+  "prompt_eval_count": 46,
+  "prompt_eval_duration": 1160282000,
+  "eval_count": 113,
+  "eval_duration": 1325948000
+}
+```
+
 ## Create a Model
 
 ```shell
diff --git a/llm/llama.go b/llm/llama.go
index fc033258..bd2cb1f1 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -531,21 +531,30 @@ type prediction struct {
 const maxBufferSize = 512 * format.KiloByte
 
-func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
-	prevConvo, err := llm.Decode(ctx, prevContext)
-	if err != nil {
-		return err
-	}
+type PredictOpts struct {
+	Model            string
+	Prompt           string
+	Format           string
+	CheckpointStart  time.Time
+	CheckpointLoaded time.Time
+}
 
-	// Remove leading spaces from prevConvo if present
-	prevConvo = strings.TrimPrefix(prevConvo, " ")
-
-	var nextContext strings.Builder
-	nextContext.WriteString(prevConvo)
-	nextContext.WriteString(prompt)
+type PredictResult struct {
+	Model              string
+	CreatedAt          time.Time
+	TotalDuration      time.Duration
+	LoadDuration       time.Duration
+	Content            string
+	Done               bool
+	PromptEvalCount    int
+	PromptEvalDuration time.Duration
+	EvalCount          int
+	EvalDuration       time.Duration
+}
 
+func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	request := map[string]any{
-		"prompt":    nextContext.String(),
+		"prompt":    predict.Prompt,
 		"stream":    true,
 		"n_predict": llm.NumPredict,
 		"n_keep":    llm.NumKeep,
@@ -567,7 +576,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 		"stop":      llm.Stop,
 	}
 
-	if format == "json" {
+	if predict.Format == "json" {
 		request["grammar"] = jsonGrammar
 	}
 
@@ -624,25 +633,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			}
 
 			if p.Content != "" {
-				fn(api.GenerateResponse{Response: p.Content})
-				nextContext.WriteString(p.Content)
+				fn(PredictResult{
+					Model:     predict.Model,
+					CreatedAt: time.Now().UTC(),
+					Content:   p.Content,
+				})
 			}
 
 			if p.Stop {
-				embd, err := llm.Encode(ctx, nextContext.String())
-				if err != nil {
-					return fmt.Errorf("encoding context: %v", err)
-				}
+				fn(PredictResult{
+					Model:         predict.Model,
+					CreatedAt:     time.Now().UTC(),
+					TotalDuration: time.Since(predict.CheckpointStart),
 
-				fn(api.GenerateResponse{
 					Done:               true,
-					Context:            embd,
 					PromptEvalCount:    p.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
 					EvalCount:          p.Timings.PredictedN,
 					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 				})
-
 				return nil
 			}
 		}
diff --git a/llm/llm.go b/llm/llm.go
index 4901d9fe..d7b2a7cd 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -14,7 +14,7 @@ import (
 )
 
 type LLM interface {
-	Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
+	Predict(context.Context, PredictOpts, func(PredictResult)) error
 	Embedding(context.Context, string) ([]float64, error)
 	Encode(context.Context, string) ([]int, error)
 	Decode(context.Context, []int) (string, error)
diff --git a/server/images.go b/server/images.go
index 294fdf2b..3337058d 100644
--- a/server/images.go
+++ b/server/images.go
@@ -47,37 +47,85 @@ type Model struct {
 	Options map[string]interface{}
 }
 
-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
-	t := m.Template
-	if request.Template != "" {
-		t = request.Template
-	}
+type PromptVars struct {
+	System   string
+	Prompt   string
+	Response string
+	First    bool
+}
 
-	tmpl, err := template.New("").Parse(t)
+func (m *Model) Prompt(p PromptVars) (string, error) {
+	var prompt strings.Builder
+	tmpl, err := template.New("").Parse(m.Template)
 	if err != nil {
 		return "", err
 	}
 
-	var vars struct {
-		First  bool
-		System string
-		Prompt string
-	}
-
-	vars.First = len(request.Context) == 0
-	vars.System = m.System
-	vars.Prompt = request.Prompt
-
-	if request.System != "" {
-		vars.System = request.System
+	if p.System == "" {
+		// use the default system prompt for this model if one is not specified
+		p.System = m.System
 	}
 
 	var sb strings.Builder
-	if err := tmpl.Execute(&sb, vars); err != nil {
+	if err := tmpl.Execute(&sb, p); err != nil {
 		return "", err
 	}
+	prompt.WriteString(sb.String())
+	prompt.WriteString(p.Response)
+	return prompt.String(), nil
+}
 
-	return sb.String(), nil
+func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
+	// build the prompt from the list of messages
+	var prompt strings.Builder
+	currentVars := PromptVars{
+		First: true,
+	}
+
+	writePrompt := func() error {
+		p, err := m.Prompt(currentVars)
+		if err != nil {
+			return err
+		}
+		prompt.WriteString(p)
+		currentVars = PromptVars{}
+		return nil
+	}
+
+	for _, msg := range msgs {
+		switch msg.Role {
+		case "system":
+			if currentVars.Prompt != "" || currentVars.System != "" {
+				if err := writePrompt(); err != nil {
+					return "", err
+				}
+			}
+			currentVars.System = msg.Content
+		case "user":
+			if currentVars.Prompt != "" || currentVars.System != "" {
+				if err := writePrompt(); err != nil {
+					return "", err
+				}
+			}
+			currentVars.Prompt = msg.Content
+		case "assistant":
+			currentVars.Response = msg.Content
+			if err := writePrompt(); err != nil {
+				return "", err
+			}
+		default:
+			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+		}
+	}
+
+	// Append the last set of vars if they are non-empty
+	if currentVars.Prompt != "" || currentVars.System != "" {
+		if err := writePrompt(); err != nil {
+			return "", err
+		}
+	}
+
+	return prompt.String(), nil
 }
 
 type ManifestV2 struct {
@@ -383,7 +431,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 				c.Args = blobPath
 			}
 
-
+			fn(api.ProgressResponse{Status: "creating adapter layer"})
 			bin, err := os.Open(realpath(modelFileDir, c.Args))
 			if err != nil {
diff --git a/server/images_test.go b/server/images_test.go
index 5e6a197b..85e8d4bd 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -2,17 +2,15 @@ package server
 
 import (
 	"testing"
-
-	"github.com/jmorganca/ollama/api"
 )
 
 func TestModelPrompt(t *testing.T) {
-	var m Model
-	req := api.GenerateRequest{
+	m := Model{
 		Template: "a{{ .Prompt }}b",
-		Prompt: "