Merge pull request #93 from jmorganca/split-prompt

separate prompt into template and system
This commit is contained in:
Michael Yang 2023-07-19 23:25:33 -07:00 committed by GitHub
commit 6984171cfd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 161 additions and 127 deletions

View file

@ -12,11 +12,12 @@ INSTRUCTION arguments
``` ```
| Instruction | Description | | Instruction | Description |
|------------------------- |--------------------------------------------------------- | | ------------------------- | ----------------------------------------------------- |
| FROM<br>(required) | Defines the base model to be used when creating a model | | `FROM`<br>(required) | Defines the base model to use |
| PARAMETER<br>(optional) | Sets the parameters for how the model will be run | | `PARAMETER`<br>(optional) | Sets the parameters for how Ollama will run the model |
| PROMPT <br>(optional) | Sets the prompt to use when the model will be run | | `SYSTEM`<br>(optional) | Specifies the system prompt that will set the context |
| LICENSE<br>(optional) | Specify the license of the model. It is additive, and | | `TEMPLATE`<br>(optional) | The full prompt template to be sent to the model |
| `LICENSE`<br>(optional) | Specifies the legal license |
## Examples ## Examples
@ -24,12 +25,13 @@ An example of a model file creating a mario blueprint:
``` ```
FROM llama2 FROM llama2
# sets the temperature to 1 [higher is more creative, lower is more coherent]
# sets the context size to 4096
PARAMETER temperature 1 PARAMETER temperature 1
PROMPT """ PARAMETER num_ctx 4096
System: You are Mario from super mario bros, acting as an assistant.
User: {{ .Prompt }} # Overriding the system prompt
Assistant: SYSTEM You are Mario from super mario bros, acting as an assistant.
"""
``` ```
To use this: To use this:
@ -41,7 +43,7 @@ To use this:
## FROM (Required) ## FROM (Required)
The FROM instruction defines the base model to be used when creating a model. The FROM instruction defines the base model to use when creating a model.
``` ```
FROM <model name>:<tag> FROM <model name>:<tag>
@ -64,7 +66,7 @@ FROM ./ollama-model.bin
## PARAMETER (Optional) ## PARAMETER (Optional)
The PARAMETER instruction defines a parameter that can be set when the model is run. The `PARAMETER` instruction defines a parameter that can be set when the model is run.
``` ```
PARAMETER <parameter> <parametervalue> PARAMETER <parameter> <parametervalue>
@ -73,37 +75,27 @@ PARAMETER <parameter> <parametervalue>
### Valid Parameters and Values ### Valid Parameters and Values
| Parameter | Description | Value Type | Example Usage | | Parameter | Description | Value Type | Example Usage |
|---------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|-------------------| | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | ------------------ |
| NumCtx | Sets the size of the prompt context size length model. (Default: 2048) | int | Numctx 4096 | | num_ctx | Sets the size of the prompt context size length model. (Default: 2048) | int | num_ctx 4096 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | Temperature 0.7 | | temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
| TopK | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | TopK 40 | | top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| TopP | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | TopP 0.9 | | top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
| NumGPU | The number of GPUs to use. On macOS it defaults to 1 to enable metal support, 0 to disable. | int | numGPU 1 | | num_gpu | The number of GPUs to use. On macOS it defaults to 1 to enable metal support, 0 to disable. | int | num_gpu 1 |
| RepeatLastN | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = ctx-size) | int | RepeatLastN 64 | | repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = ctx-size) | int | repeat_last_n 64 |
| RepeatPenalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | RepeatPenalty 1.1 | | repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| TFSZ | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | TFSZ 1 | | tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
| Mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | Mirostat 0 | | mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| MirostatTau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | MirostatTau 5.0 | | mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| MirostatEta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | MirostatEta 0.1 | | mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| NumThread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | NumThread 8 | | num_thread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | num_thread 8 |
## PROMPT ## Prompt
Prompt is a set of instructions to an LLM to cause the model to return desired response(s). Typically there are 3-4 components to a prompt: System, context, user, and response. When building on top of the base models supplied by Ollama, it comes with the prompt template predefined. To override the supplied system prompt, simply add `SYSTEM insert system prompt` to change the system prompt.
```modelfile ### Prompt Template
PROMPT """
{{- if not .Context }}
### System:
You are a content marketer who needs to come up with a short but succinct tweet. Make sure to include the appropriate hashtags and links. Sometimes when appropriate, describe a meme that can be includes as well. All answers should be in the form of a tweet which has a max size of 280 characters. Every instruction will be the topic to create a tweet about.
{{- end }}
### Instruction:
{{ .Prompt }}
### Response: `TEMPLATE` the full prompt template to be passed into the model. It may include (optionally) a system prompt, user prompt, and assistant prompt. This is used to create a full custom prompt, and syntax may be model specific.
"""
```
## Notes ## Notes

View file

@ -2,76 +2,91 @@ package parser
import ( import (
"bufio" "bufio"
"bytes"
"errors"
"fmt" "fmt"
"io" "io"
"strings"
) )
type Command struct { type Command struct {
Name string Name string
Arg string Args string
}
func (c *Command) Reset() {
c.Name = ""
c.Args = ""
} }
func Parse(reader io.Reader) ([]Command, error) { func Parse(reader io.Reader) ([]Command, error) {
var commands []Command var commands []Command
var foundModel bool
var command, modelCommand Command
scanner := bufio.NewScanner(reader) scanner := bufio.NewScanner(reader)
multiline := false scanner.Split(scanModelfile)
var multilineCommand *Command
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Bytes()
if multiline {
// If we're in a multiline string and the line is """, end the multiline string. fields := bytes.SplitN(line, []byte(" "), 2)
if strings.TrimSpace(line) == `"""` {
multiline = false
commands = append(commands, *multilineCommand)
} else {
// Otherwise, append the line to the multiline string.
multilineCommand.Arg += "\n" + line
}
continue
}
fields := strings.Fields(line)
if len(fields) == 0 { if len(fields) == 0 {
continue continue
} }
command := Command{} switch string(bytes.ToUpper(fields[0])) {
switch strings.ToUpper(fields[0]) {
case "FROM": case "FROM":
command.Name = "model" command.Name = "model"
command.Arg = fields[1] command.Args = string(fields[1])
if command.Arg == "" { // copy command for validation
return nil, fmt.Errorf("no model specified in FROM line") modelCommand = command
} case "LICENSE", "TEMPLATE", "SYSTEM":
foundModel = true command.Name = string(bytes.ToLower(fields[0]))
case "PROMPT", "LICENSE": command.Args = string(fields[1])
command.Name = strings.ToLower(fields[0])
if fields[1] == `"""` {
multiline = true
multilineCommand = &command
multilineCommand.Arg = ""
} else {
command.Arg = strings.Join(fields[1:], " ")
}
case "PARAMETER": case "PARAMETER":
command.Name = fields[1] fields = bytes.SplitN(fields[1], []byte(" "), 2)
command.Arg = strings.Join(fields[2:], " ") command.Name = string(fields[0])
command.Args = string(fields[1])
default: default:
continue continue
} }
if !multiline {
commands = append(commands, command) commands = append(commands, command)
} command.Reset()
} }
if !foundModel { if modelCommand.Args == "" {
return nil, fmt.Errorf("no FROM line for the model was specified") return nil, fmt.Errorf("no FROM line for the model was specified")
} }
if multiline {
return nil, fmt.Errorf("unclosed multiline string")
}
return commands, scanner.Err() return commands, scanner.Err()
} }
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF || len(data) == 0 {
return 0, nil, nil
}
newline := bytes.IndexByte(data, '\n')
if start := bytes.Index(data, []byte(`"""`)); start >= 0 && start < newline {
end := bytes.Index(data[start+3:], []byte(`"""`))
if end < 0 {
return 0, nil, errors.New(`unterminated multiline string: """`)
}
n := start + 3 + end + 3
return n, bytes.Replace(data[:n], []byte(`"""`), []byte(""), 2), nil
}
if start := bytes.Index(data, []byte(`'''`)); start >= 0 && start < newline {
end := bytes.Index(data[start+3:], []byte(`'''`))
if end < 0 {
return 0, nil, errors.New("unterminated multiline string: '''")
}
n := start + 3 + end + 3
return n, bytes.Replace(data[:n], []byte("'''"), []byte(""), 2), nil
}
return bufio.ScanLines(data, atEOF)
}

View file

@ -16,6 +16,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"text/template"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
@ -24,10 +25,39 @@ import (
type Model struct { type Model struct {
Name string `json:"name"` Name string `json:"name"`
ModelPath string ModelPath string
Prompt string Template string
System string
Options api.Options Options api.Options
} }
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
tmpl, err := template.New("").Parse(m.Template)
if err != nil {
return "", err
}
var vars struct {
First bool
System string
Prompt string
// deprecated: versions <= 0.0.7 used this to omit the system prompt
Context []int
}
vars.First = len(vars.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
vars.Context = request.Context
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
type ManifestV2 struct { type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"` SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"` MediaType string `json:"mediaType"`
@ -71,20 +101,19 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) { if _, err = os.Stat(fp); err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname()) return nil, fmt.Errorf("couldn't find model '%s'", mp.GetShortTagname())
} }
var manifest *ManifestV2 var manifest *ManifestV2
f, err := os.Open(fp) bts, err := os.ReadFile(fp)
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't open file '%s'", fp) return nil, fmt.Errorf("couldn't open file '%s'", fp)
} }
decoder := json.NewDecoder(f) if err := json.Unmarshal(bts, &manifest); err != nil {
err = decoder.Decode(&manifest)
if err != nil {
return nil, err return nil, err
} }
@ -112,12 +141,28 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename model.ModelPath = filename
case "application/vnd.ollama.image.prompt": case "application/vnd.ollama.image.template":
data, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.Prompt = string(data)
model.Template = string(bts)
case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.System = string(bts)
case "application/vnd.ollama.image.prompt":
log.Printf("PROMPT is deprecated. Please use TEMPLATE and SYSTEM instead.")
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.Template = string(bts)
case "application/vnd.ollama.image.params": case "application/vnd.ollama.image.params":
params, err := os.Open(filename) params, err := os.Open(filename)
if err != nil { if err != nil {
@ -156,13 +201,13 @@ func CreateModel(name string, path string, fn func(status string)) error {
params := make(map[string]string) params := make(map[string]string)
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Arg) log.Printf("[%s] - %s\n", c.Name, c.Args)
switch c.Name { switch c.Name {
case "model": case "model":
fn("looking for model") fn("looking for model")
mf, err := GetManifest(ParseModelPath(c.Arg)) mf, err := GetManifest(ParseModelPath(c.Args))
if err != nil { if err != nil {
fp := c.Arg fp := c.Args
// If filePath starts with ~/, replace it with the user's home directory. // If filePath starts with ~/, replace it with the user's home directory.
if strings.HasPrefix(fp, "~/") { if strings.HasPrefix(fp, "~/") {
@ -183,7 +228,7 @@ func CreateModel(name string, path string, fn func(status string)) error {
fn("creating model layer") fn("creating model layer")
file, err := os.Open(fp) file, err := os.Open(fp)
if err != nil { if err != nil {
fn(fmt.Sprintf("couldn't find model '%s'", c.Arg)) fn(fmt.Sprintf("couldn't find model '%s'", c.Args))
return fmt.Errorf("failed to open file: %v", err) return fmt.Errorf("failed to open file: %v", err)
} }
defer file.Close() defer file.Close()
@ -206,31 +251,21 @@ func CreateModel(name string, path string, fn func(status string)) error {
layers = append(layers, newLayer) layers = append(layers, newLayer)
} }
} }
case "prompt": case "license", "template", "system":
fn("creating prompt layer") fn(fmt.Sprintf("creating %s layer", c.Name))
// remove the prompt layer if one exists // remove the prompt layer if one exists
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.prompt") mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layers = removeLayerFromLayers(layers, mediaType)
prompt := strings.NewReader(c.Arg) layer, err := CreateLayer(strings.NewReader(c.Args))
l, err := CreateLayer(prompt)
if err != nil { if err != nil {
fn(fmt.Sprintf("couldn't create prompt layer: %v", err)) return err
return fmt.Errorf("failed to create layer: %v", err)
} }
l.MediaType = "application/vnd.ollama.image.prompt"
layers = append(layers, l) layer.MediaType = mediaType
case "license": layers = append(layers, layer)
fn("creating license layer")
license := strings.NewReader(c.Arg)
l, err := CreateLayer(license)
if err != nil {
fn(fmt.Sprintf("couldn't create license layer: %v", err))
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.license"
layers = append(layers, l)
default: default:
params[c.Name] = c.Arg params[c.Name] = c.Args
} }
} }

View file

@ -9,7 +9,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"text/template"
"time" "time"
"dario.cat/mergo" "dario.cat/mergo"
@ -54,19 +53,12 @@ func generate(c *gin.Context) {
return return
} }
templ, err := template.New("").Parse(model.Prompt) prompt, err := model.Prompt(req)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
var sb strings.Builder
if err = templ.Execute(&sb, req); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
req.Prompt = sb.String()
llm, err := llama.New(model.ModelPath, opts) llm, err := llama.New(model.ModelPath, opts)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -77,7 +69,7 @@ func generate(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) { llm.Predict(req.Context, prompt, func(r api.GenerateResponse) {
r.Model = req.Model r.Model = req.Model
r.CreatedAt = time.Now().UTC() r.CreatedAt = time.Now().UTC()
if r.Done { if r.Done {