diff --git a/api/types.go b/api/types.go index dccfbf7a..5f8c3891 100644 --- a/api/types.go +++ b/api/types.go @@ -134,6 +134,7 @@ type Options struct { // Model options NumCtx int `json:"num_ctx,omitempty"` + NumKeep int `json:"num_keep,omitempty"` NumBatch int `json:"num_batch,omitempty"` NumGPU int `json:"num_gpu,omitempty"` MainGPU int `json:"main_gpu,omitempty"` @@ -158,6 +159,7 @@ type Options struct { Mirostat int `json:"mirostat,omitempty"` MirostatTau float32 `json:"mirostat_tau,omitempty"` MirostatEta float32 `json:"mirostat_eta,omitempty"` + PenalizeNewline bool `json:"penalize_newline,omitempty"` NumThread int `json:"num_thread,omitempty"` } @@ -176,7 +178,7 @@ func DefaultOptions() Options { UseMMap: true, UseMLock: false, - RepeatLastN: 512, + RepeatLastN: 64, RepeatPenalty: 1.1, FrequencyPenalty: 0.0, PresencePenalty: 0.0, @@ -188,6 +190,7 @@ func DefaultOptions() Options { Mirostat: 0, MirostatTau: 5.0, MirostatEta: 0.1, + PenalizeNewline: true, NumThread: runtime.NumCPU(), } diff --git a/llama/llama.go b/llama/llama.go index a48c5965..9f5066f3 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -1,8 +1,8 @@ package llama /* -#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS -#cgo CXXFLAGS: -std=c++11 +#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS +#cgo CXXFLAGS: -std=gnu++11 #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders #include <stdlib.h> @@ -21,6 +21,7 @@ struct llama_sample_options int mirostat; float mirostat_tau; float mirostat_eta; + bool penalize_newline; }; llama_token llama_sample( @@ -37,6 +38,8 @@ llama_token llama_sample( false, }; + struct llama_token_data newline = candidates_p.data[llama_token_nl()]; + llama_sample_repetition_penalty( ctx, &candidates_p, last_tokens, n_last_tokens, @@ -47,6 +50,10 @@ 
llama_token llama_sample( last_tokens, n_last_tokens, opts->frequency_penalty, opts->presence_penalty); + if (!opts->penalize_newline) { + candidates_p.data[llama_token_nl()] = newline; + } + if (opts->temperature <= 0) { return llama_sample_token_greedy(ctx, &candidates_p); } @@ -82,9 +89,9 @@ import ( "errors" "fmt" "io" + "log" "os" "strings" - "time" "unicode/utf8" "unsafe" @@ -96,6 +103,10 @@ type LLM struct { model *C.struct_llama_model ctx *C.struct_llama_context + last []C.llama_token + embd []C.llama_token + cursor int + api.Options } @@ -152,16 +163,98 @@ func (llm *LLM) Close() { } func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { - if input := llm.tokenize(prompt); input != nil { - embd := make([]C.llama_token, len(ctx)) - for i := range ctx { - embd[i] = C.llama_token(ctx[i]) - } + C.llama_reset_timings(llm.ctx) - return llm.generate(append(embd, input...), fn) + tokens := make([]C.llama_token, len(ctx)) + for i := range tokens { + tokens[i] = C.llama_token(ctx[i]) } - return errors.New("llama: tokenize") + if len(tokens) == 0 { + tokens = llm.tokenize(" ") + } + + llm.marshalPrompt(tokens, prompt) + + C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed)) + + var b bytes.Buffer + for { + token, err := llm.next() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + + b.WriteString(llm.detokenize(token)) + if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax { + fn(api.GenerateResponse{Response: b.String()}) + b.Reset() + } + } + + last := make([]int, 0, len(llm.last)) + for _, i := range llm.last { + if i != 0 { + last = append(last, int(i)) + } + } + + timings := C.llama_get_timings(llm.ctx) + fn(api.GenerateResponse{ + Done: true, + Context: last, + PromptEvalCount: int(timings.n_p_eval), + PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)), + EvalCount: int(timings.n_eval), + EvalDuration: parseDurationMs(float64(timings.t_eval_ms)), + }) + + return nil +} + +func (llm *LLM) 
marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token { + tokens := append(ctx, llm.tokenize(prompt)...) + if llm.NumKeep < 0 { + llm.NumKeep = len(tokens) + } + + // min(llm.NumCtx - 4, llm.NumKeep) + if llm.NumCtx-4 < llm.NumKeep { + llm.NumKeep = llm.NumCtx - 4 + } + + if len(tokens) >= llm.NumCtx { + // truncate input + numLeft := (llm.NumCtx - llm.NumKeep) / 2 + truncated := tokens[:llm.NumKeep] + erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft + truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...) + copy(llm.last, tokens[len(tokens)-llm.NumCtx:]) + + tokens = truncated + log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated)) + } else { + llm.last = make([]C.llama_token, llm.NumCtx-len(tokens)) + llm.last = append(llm.last, tokens...) + } + + var i int + for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ { + // noop + } + + llm.embd = tokens + if i == len(tokens) { + // evaluate at least one token to generate logits + i-- + } + + llm.cursor = i + + log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:])) + return tokens } func (llm *LLM) tokenize(prompt string) []C.llama_token { @@ -185,98 +278,86 @@ func (llm *LLM) detokenize(tokens ...C.llama_token) string { return sb.String() } -func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { - var opts C.struct_llama_sample_options - opts.repeat_penalty = C.float(llm.RepeatPenalty) - opts.frequency_penalty = C.float(llm.FrequencyPenalty) - opts.presence_penalty = C.float(llm.PresencePenalty) - opts.temperature = C.float(llm.Temperature) - opts.top_k = C.int(llm.TopK) - opts.top_p = C.float(llm.TopP) - opts.tfs_z = C.float(llm.TFSZ) - opts.typical_p = C.float(llm.TypicalP) - opts.mirostat = C.int(llm.Mirostat) - opts.mirostat_tau = C.float(llm.MirostatTau) - opts.mirostat_eta = 
C.float(llm.MirostatEta) +func (llm *LLM) next() (C.llama_token, error) { + if len(llm.embd) >= llm.NumCtx { + numLeft := (llm.NumCtx - llm.NumKeep) / 2 + truncated := llm.embd[:llm.NumKeep] + truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...) - output := deque[C.llama_token]{capacity: llm.NumCtx} - - context := deque[int]{capacity: llm.NumCtx / 2} - for _, in := range input { - context.PushLeft(int(in)) + llm.embd = truncated + llm.cursor = llm.NumKeep + log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor) } - var b bytes.Buffer - for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) { - if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 { - return errors.New("llama: eval") - } - - token, err := llm.sample(output, &opts) - if errors.Is(err, io.EOF) { + for { + if llm.cursor >= len(llm.embd) { break - } else if err != nil { - return err } - b.WriteString(llm.detokenize(token)) - if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax { - // call the callback - fn(api.GenerateResponse{ - Response: b.String(), - }) - - output.PushLeft(token) - context.PushLeft(int(token)) - b.Reset() + numEval := len(llm.embd) - llm.cursor + if numEval > llm.NumBatch { + numEval = llm.NumBatch } - input = []C.llama_token{token} + if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 { + return 0, fmt.Errorf("llama_eval: %d", retval) + } + + llm.cursor += numEval } - dur := func(ms float64) time.Duration { - d, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) - if err != nil { - panic(err) - } + var sampleOpts C.struct_llama_sample_options + sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty) + sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty) + 
sampleOpts.presence_penalty = C.float(llm.PresencePenalty) + sampleOpts.temperature = C.float(llm.Temperature) + sampleOpts.top_k = C.int(llm.TopK) + sampleOpts.top_p = C.float(llm.TopP) + sampleOpts.tfs_z = C.float(llm.TFSZ) + sampleOpts.typical_p = C.float(llm.TypicalP) + sampleOpts.mirostat = C.int(llm.Mirostat) + sampleOpts.mirostat_tau = C.float(llm.MirostatTau) + sampleOpts.mirostat_eta = C.float(llm.MirostatEta) + sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline) - return d - } - - timings := C.llama_get_timings(llm.ctx) - fn(api.GenerateResponse{ - Done: true, - Context: context.Data(), - PromptEvalCount: int(timings.n_p_eval), - PromptEvalDuration: dur(float64(timings.t_p_eval_ms)), - EvalCount: int(timings.n_eval), - EvalDuration: dur(float64(timings.t_eval_ms)), - }) - - return nil -} - -func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { - numVocab := int(C.llama_n_vocab(llm.ctx)) + numVocab := C.llama_n_vocab(llm.ctx) logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab) - candidates := deque[C.struct_llama_token_data]{capacity: numVocab} - for i := 0; i < candidates.Cap(); i++ { - candidates.PushLeft(C.struct_llama_token_data{ + // TODO: logit bias + + candidates := make([]C.llama_token_data, numVocab) + for i := range logits { + candidates[i] = C.llama_token_data{ id: C.int(i), logit: logits[i], p: 0, - }) + } } + repeatLastN := llm.RepeatLastN + if len(llm.last) < repeatLastN { + repeatLastN = len(llm.last) + } + + if llm.NumCtx < repeatLastN { + repeatLastN = llm.NumCtx + } + + lastN := llm.last[len(llm.last)-repeatLastN:] + token := C.llama_sample( llm.ctx, - unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()), - unsafe.SliceData(output.Data()), C.size_t(output.Len()), - opts) - if token != C.llama_token_eos() { - return token, nil + unsafe.SliceData(candidates), C.size_t(len(candidates)), + unsafe.SliceData(lastN), C.size_t(len(lastN)), + &sampleOpts, + ) + 
+ llm.last = append(llm.last, token) + llm.embd = append(llm.embd, token) + + if token == C.llama_token_eos() { + return 0, io.EOF } - return 0, io.EOF + return token, nil } diff --git a/llama/utils.go b/llama/utils.go index b0db27d4..8b52ad5c 100644 --- a/llama/utils.go +++ b/llama/utils.go @@ -1,104 +1,15 @@ package llama -type node[T any] struct { - t T - next *node[T] - prev *node[T] -} +import ( + "fmt" + "time" +) -type deque[T any] struct { - head *node[T] - tail *node[T] - size int - capacity int -} - -func (d *deque[T]) Empty() bool { - return d.size == 0 -} - -func (d *deque[T]) Len() int { - return d.size -} - -func (d *deque[T]) Cap() int { - return d.capacity -} - -func (d *deque[T]) Push(t T) { - if d.capacity > 0 && d.size >= d.capacity { - d.PopLeft() +func parseDurationMs(ms float64) time.Duration { + dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) + if err != nil { + panic(err) } - n := node[T]{t: t} - if d.head != nil { - n.next = d.head - d.head.prev = &n - d.head = &n - } else { - d.head = &n - d.tail = &n - } - - d.size++ -} - -func (d *deque[T]) PushLeft(t T) { - if d.capacity > 0 && d.size >= d.capacity { - d.Pop() - } - - n := node[T]{t: t} - if d.tail != nil { - n.prev = d.tail - d.tail.next = &n - d.tail = &n - } else { - d.head = &n - d.tail = &n - } - - d.size++ -} - -func (d *deque[T]) Pop() *T { - if d.Empty() { - return nil - } - - head := d.head - d.head = head.next - if d.head != nil { - d.head.prev = nil - } else { - d.tail = nil - } - - d.size-- - return &head.t -} - -func (d *deque[T]) PopLeft() *T { - if d.Empty() { - return nil - } - - tail := d.tail - d.tail = tail.prev - if d.tail != nil { - d.tail.next = nil - } else { - d.head = nil - } - - d.size-- - return &tail.t -} - -func (d *deque[T]) Data() (data []T) { - for n := d.head; n != nil; n = n.next { - data = append(data, n.t) - } - - return data + return dur }