From 40c9dc0a3191dbb85f6dbfbdd1a1952b5e8073d0 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 14 Jul 2023 18:30:32 -0700 Subject: [PATCH] fix multibyte responses --- llama/llama.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index 95f6d311..a3458a24 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -78,12 +78,14 @@ llama_token llama_sample( */ import "C" import ( + "bytes" "errors" "fmt" "io" "os" "strings" "time" + "unicode/utf8" "unsafe" "github.com/jmorganca/ollama/api" @@ -204,6 +206,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) context.PushLeft(int(in)) } + var b bytes.Buffer for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) { if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 { return errors.New("llama: eval") @@ -216,13 +219,17 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) return err } - // call the callback - fn(api.GenerateResponse{ - Response: llm.detokenize(token), - }) + b.WriteString(llm.detokenize(token)) + if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax { + // call the callback + fn(api.GenerateResponse{ + Response: b.String(), + }) - output.PushLeft(token) - context.PushLeft(int(token)) + output.PushLeft(token) + context.PushLeft(int(token)) + b.Reset() + } input = []C.llama_token{token} }