treat stop as stop sequences, not exact tokens (#442)
The `stop` option to the generate API is a list of sequences that should cause generation to stop. Although these are commonly called "stop tokens", they do not necessarily correspond to LLM tokens (per the LLM's tokenizer). For example, if the caller sends a generate request with `"stop":["\n"]`, then generation should stop on any token containing `\n` (and trim `\n` from the output), not just if the token exactly matches `\n`. If `stop` were interpreted strictly as LLM tokens, then it would require callers of the generate API to know the LLM's tokenizer and enumerate many tokens in the `stop` list. Fixes https://github.com/jmorganca/ollama/issues/295.
This commit is contained in:
parent
982c535428
commit
f4432e1dba
|
@ -123,7 +123,7 @@ PARAMETER <parameter> <parametervalue>
|
|||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
| stop | Sets the stop tokens to use. | string | stop "AI assistant:" |
|
||||
| stop | Sets the stop sequences to use. | string | stop "AI assistant:" |
|
||||
| tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 |
|
||||
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
|
||||
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
|
||||
|
|
43
llm/llama.go
43
llm/llama.go
|
@ -334,20 +334,18 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
|
|||
|
||||
b.WriteString(llm.Decode(int(token)))
|
||||
|
||||
if err := llm.checkStopConditions(b); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if errors.Is(err, errNeedMoreData) {
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
stop, endsWithStopPrefix := handleStopSequences(&b, llm.Stop)
|
||||
if endsWithStopPrefix {
|
||||
continue
|
||||
}
|
||||
|
||||
if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
|
||||
fn(api.GenerateResponse{Response: b.String()})
|
||||
b.Reset()
|
||||
}
|
||||
if stop {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
embd := make([]int, len(llm.embd))
|
||||
|
@ -370,16 +368,31 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
|
|||
return nil
|
||||
}
|
||||
|
||||
func (llm *llama) checkStopConditions(b bytes.Buffer) error {
|
||||
for _, stopCondition := range llm.Stop {
|
||||
if stopCondition == strings.TrimSpace(b.String()) {
|
||||
return io.EOF
|
||||
} else if strings.HasPrefix(stopCondition, strings.TrimSpace(b.String())) {
|
||||
return errNeedMoreData
|
||||
// handleStopSequences checks whether b contains any of the stop sequences, or ends with a prefix of
|
||||
// any stop sequence (and therefore might contain data that should not ultimately be returned to the
|
||||
// client).
|
||||
//
|
||||
// If b contains a stop sequence, it modifies b to remove the stop sequence and all subsequent data.
|
||||
func handleStopSequences(b *bytes.Buffer, stopSequences []string) (stop bool, endsWithStopPrefix bool) {
|
||||
s := b.String()
|
||||
for _, seq := range stopSequences {
|
||||
// Check for an exact or substring match.
|
||||
if i := strings.Index(s, seq); i != -1 {
|
||||
b.Truncate(i)
|
||||
return true, false
|
||||
}
|
||||
|
||||
// Check if b ends with a prefix of the stop sequence.
|
||||
if len(seq) > 1 {
|
||||
for i := 1; i < len(seq); i++ {
|
||||
if strings.HasSuffix(s, seq[:i]) {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
|
||||
|
|
79
llm/llama_test.go
Normal file
79
llm/llama_test.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckStopConditions(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
b string
|
||||
stop []string
|
||||
wantB string
|
||||
wantStop bool
|
||||
wantEndsWithStopPrefix bool
|
||||
}{
|
||||
"not present": {
|
||||
b: "abc",
|
||||
stop: []string{"x"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
"exact": {
|
||||
b: "abc",
|
||||
stop: []string{"abc"},
|
||||
wantStop: true,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
"substring": {
|
||||
b: "abc",
|
||||
stop: []string{"b"},
|
||||
wantB: "a",
|
||||
wantStop: true,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
"prefix 1": {
|
||||
b: "abc",
|
||||
stop: []string{"abcd"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: true,
|
||||
},
|
||||
"prefix 2": {
|
||||
b: "abc",
|
||||
stop: []string{"bcd"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: true,
|
||||
},
|
||||
"prefix 3": {
|
||||
b: "abc",
|
||||
stop: []string{"cd"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: true,
|
||||
},
|
||||
"no prefix": {
|
||||
b: "abc",
|
||||
stop: []string{"bx"},
|
||||
wantStop: false,
|
||||
wantEndsWithStopPrefix: false,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
b.WriteString(test.b)
|
||||
stop, endsWithStopPrefix := handleStopSequences(&b, test.stop)
|
||||
if test.wantB != "" {
|
||||
gotB := b.String()
|
||||
if gotB != test.wantB {
|
||||
t.Errorf("got b %q, want %q", gotB, test.wantB)
|
||||
}
|
||||
}
|
||||
if stop != test.wantStop {
|
||||
t.Errorf("got stop %v, want %v", stop, test.wantStop)
|
||||
}
|
||||
if endsWithStopPrefix != test.wantEndsWithStopPrefix {
|
||||
t.Errorf("got endsWithStopPrefix %v, want %v", endsWithStopPrefix, test.wantEndsWithStopPrefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -430,7 +430,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
|||
layer.MediaType = mediaType
|
||||
layers = append(layers, layer)
|
||||
default:
|
||||
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop tokens)
|
||||
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
|
||||
params[c.Name] = append(params[c.Name], c.Args)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue