From 1775647f763f9785a0f06eed7cfaa310b6dc9519 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 13 Jul 2023 11:02:53 -0700
Subject: [PATCH] continue conversation

feed responses back into the llm
---
 api/types.go                        | 10 ++++++----
 cmd/cmd.go                          | 11 ++++++++++-
 llama/llama.go                      | 18 +++++++++++++++---
 main.go                             |  4 +++-
 server/routes.go                    |  2 +-
 server/templates/alpaca.prompt      |  2 ++
 server/templates/falcon.prompt      |  2 ++
 server/templates/mpt.prompt         |  2 ++
 server/templates/orca.prompt        |  2 ++
 server/templates/vicuna.prompt      |  2 ++
 server/templates/wizardcoder.prompt |  2 ++
 11 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/api/types.go b/api/types.go
index 34ba91b3..d84611e5 100644
--- a/api/types.go
+++ b/api/types.go
@@ -18,8 +18,9 @@ type PullProgress struct {
 }
 
 type GenerateRequest struct {
-	Model  string `json:"model"`
-	Prompt string `json:"prompt"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Context []int  `json:"context,omitempty"`
 
 	Options `json:"options"`
 }
@@ -29,7 +30,8 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
 
-	Done bool `json:"done"`
+	Done    bool  `json:"done"`
+	Context []int `json:"context,omitempty"`
 
 	TotalDuration   time.Duration `json:"total_duration,omitempty"`
 	PromptEvalCount int           `json:"prompt_eval_count,omitempty"`
@@ -104,7 +106,7 @@ func DefaultOptions() Options {
 		UseNUMA: false,
 
-		NumCtx:   512,
+		NumCtx:   2048,
 		NumBatch: 512,
 		NumGPU:   1,
 		LowVRAM:  false,
diff --git a/cmd/cmd.go b/cmd/cmd.go
index dede4600..16e8cae0 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -85,6 +85,8 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
+var generateContextKey struct{}
+
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 		client := api.NewClient()
@@ -110,7 +112,12 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 	var latest api.GenerateResponse
 
-	request := api.GenerateRequest{Model: model, Prompt: prompt}
+	generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+	if !ok {
+		generateContext = []int{}
+	}
+
+	request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
 	fn := func(resp api.GenerateResponse) error {
 		if !spinner.IsFinished() {
 			spinner.Finish()
@@ -119,6 +126,8 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 		latest = resp
 		fmt.Print(resp.Response)
+
+		cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
 		return nil
 	}
 
diff --git a/llama/llama.go b/llama/llama.go
index 80a1b420..c1220972 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -149,9 +149,14 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
-	if tokens := llm.tokenize(prompt); tokens != nil {
-		return llm.generate(tokens, fn)
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	if input := llm.tokenize(prompt); input != nil {
+		embd := make([]C.llama_token, len(ctx))
+		for i := range ctx {
+			embd[i] = C.llama_token(ctx[i])
+		}
+
+		return llm.generate(append(embd, input...), fn)
 	}
 
 	return errors.New("llama: tokenize")
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 
 	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
+	context := deque[int]{capacity: llm.NumCtx / 2}
+	for _, in := range input {
+		context.PushLeft(int(in))
+	}
+
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 		})
 
 		output.PushLeft(token)
+		context.PushLeft(int(token))
 
 		input = []C.llama_token{token}
 	}
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
+		Context:            context.Data(),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),
diff --git a/main.go b/main.go
index b445e7ce..e56c9b4f 100644
--- a/main.go
+++ b/main.go
@@ -1,9 +1,11 @@
 package main
 
 import (
+	"context"
+
 	"github.com/jmorganca/ollama/cmd"
 )
 
 func main() {
-	cmd.NewCLI().Execute()
+	cmd.NewCLI().ExecuteContext(context.Background())
 }
diff --git a/server/routes.go b/server/routes.go
index ace82213..ab651f42 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -94,7 +94,7 @@ func generate(c *gin.Context) {
 		ch <- r
 	}
 
-	if err := llm.Predict(req.Prompt, fn); err != nil {
+	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
diff --git a/server/templates/alpaca.prompt b/server/templates/alpaca.prompt
index e0574d25..fa480ce1 100644
--- a/server/templates/alpaca.prompt
+++ b/server/templates/alpaca.prompt
@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request.
+{{- end }}
 
 ### Instruction:
 {{ .Prompt }}
diff --git a/server/templates/falcon.prompt b/server/templates/falcon.prompt
index b0aaf3d7..f1267c5d 100644
--- a/server/templates/falcon.prompt
+++ b/server/templates/falcon.prompt
@@ -1,3 +1,5 @@
+{{- if not .Context }}
 A helpful assistant who helps the user with any questions asked.
+{{- end }}
 User: {{ .Prompt }}
 Assistant:
diff --git a/server/templates/mpt.prompt b/server/templates/mpt.prompt
index 4955ee3a..1ee5b8e9 100644
--- a/server/templates/mpt.prompt
+++ b/server/templates/mpt.prompt
@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request. Be concise. Once the request is completed, include no other text.
+{{- end }}
 ### Instruction:
 {{ .Prompt }}
 ### Response:
diff --git a/server/templates/orca.prompt b/server/templates/orca.prompt
index 3908fcde..94bef0c2 100644
--- a/server/templates/orca.prompt
+++ b/server/templates/orca.prompt
@@ -1,5 +1,7 @@
+{{- if not .Context }}
 ### System:
 You are an AI assistant that follows instruction extremely well. Help as much as you can.
+{{- end }}
 
 ### User:
 {{ .Prompt }}
diff --git a/server/templates/vicuna.prompt b/server/templates/vicuna.prompt
index 835d5023..8fa04629 100644
--- a/server/templates/vicuna.prompt
+++ b/server/templates/vicuna.prompt
@@ -1,4 +1,6 @@
+{{ if not .Context }}
 A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+{{- end }}
 
 USER: {{ .Prompt }}
 ASSISTANT:
diff --git a/server/templates/wizardcoder.prompt b/server/templates/wizardcoder.prompt
index 263c4440..500a2208 100644
--- a/server/templates/wizardcoder.prompt
+++ b/server/templates/wizardcoder.prompt
@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request
+{{- end }}
 
 ### Instruction:
 {{ .Prompt }}