ollama/server/prompt.go

75 lines
1.7 KiB
Go
Raw Normal View History

package server
import (
2024-06-17 13:38:55 -04:00
"bytes"
"context"
"log/slog"
2024-06-17 13:38:55 -04:00
"slices"
"github.com/ollama/ollama/api"
2024-06-17 13:38:55 -04:00
"github.com/ollama/ollama/llm"
2024-06-10 17:54:42 -04:00
"github.com/ollama/ollama/template"
)
2024-06-17 13:38:55 -04:00
func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// extract system messages which should always be included
var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
2024-06-17 13:38:55 -04:00
return false
})
2024-06-17 13:38:55 -04:00
if len(system) == 0 && r.model.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: r.model.System})
}
2024-06-17 13:38:55 -04:00
n := len(msgs) - 1
for i := n - 1; i >= 0; i-- {
var b bytes.Buffer
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err
}
2024-06-17 13:38:55 -04:00
s, err := r.llama.Tokenize(ctx, b.String())
if err != nil {
2024-06-17 13:38:55 -04:00
return "", nil, err
}
2024-06-17 13:38:55 -04:00
c := len(s)
if r.model.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// TODO: get image embedding length from project metadata
c += 768 * len(m.Images)
}
}
2024-06-17 13:38:55 -04:00
if c > r.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
2024-06-17 13:38:55 -04:00
} else {
n = i
}
2024-06-17 13:38:55 -04:00
}
2024-06-17 13:38:55 -04:00
var b bytes.Buffer
if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
return "", nil, err
}
2024-06-17 13:38:55 -04:00
for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}
2024-06-17 13:38:55 -04:00
return b.String(), images, nil
}