From 5d3f314b0bdcc7a599f65e947bee65e3cc4c73bd Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sun, 3 Sep 2023 14:10:03 -0400 Subject: [PATCH 1/3] remove marshalPrompt which is no longer needed --- llm/ggml_llama.go | 61 +++++++++++++++-------------------------------- 1 file changed, 19 insertions(+), 42 deletions(-) diff --git a/llm/ggml_llama.go b/llm/ggml_llama.go index 8858e22f..547f9bf6 100644 --- a/llm/ggml_llama.go +++ b/llm/ggml_llama.go @@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti runner.Path, append(params, "--port", strconv.Itoa(port))..., ) - var stderr bytes.Buffer - cmd.Stderr = &stderr + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}} @@ -437,15 +437,19 @@ type PredictRequest struct { Stop []string `json:"stop,omitempty"` } -func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error { - // we need to find the trimmed prompt context before predicting so that we can return it to the client - trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt) +func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error { + prevConvo, err := llm.Decode(ctx, prevContext) if err != nil { - return fmt.Errorf("marshaling prompt: %v", err) + return err } + + var nextContext strings.Builder + nextContext.WriteString(prevConvo) + nextContext.WriteString(prompt) + endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port) predReq := PredictRequest{ - Prompt: trimmedPrompt, + Prompt: nextContext.String(), Stream: true, NPredict: llm.NumPredict, NKeep: llm.NumKeep, @@ -491,7 +495,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, } scanner := bufio.NewScanner(resp.Body) - genCtx := trimmedPrompt // start with the trimmed prompt for scanner.Scan() { select { case <-ctx.Done(): @@ -512,11 +515,12 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, } if complete.Timings.PredictedMS > 0 { - genCtx += complete.Content - embd, err := llm.Encode(ctx, genCtx) + nextContext.WriteString(complete.Content) + embd, err := llm.Encode(ctx, nextContext.String()) if err != nil { return fmt.Errorf("encoding context: %v", err) } + fn(api.GenerateResponse{ Done: true, Context: embd, @@ -528,12 +532,13 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, return nil } - var pred Prediction - if err := json.Unmarshal([]byte(evt), &pred); err != nil { + var p Prediction + if err := json.Unmarshal([]byte(evt), &p); err != nil { return fmt.Errorf("error unmarshaling llm prediction response: %v", err) } - genCtx += pred.Content - fn(api.GenerateResponse{Response: pred.Content}) + + fn(api.GenerateResponse{Response: p.Content}) + nextContext.WriteString(p.Content) } } } @@ -545,34 +550,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, return nil } -func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) { - pEncode, err := llm.Encode(ctx, prompt) - if err != nil { - return "", fmt.Errorf("encoding prompt context: %w", err) - } - tokens := append(pCtx, pEncode...) 
- if llm.NumKeep < 0 { - llm.NumKeep = len(tokens) - } - - // min(llm.NumCtx - 4, llm.NumKeep) - if llm.NumCtx-4 < llm.NumKeep { - llm.NumKeep = llm.NumCtx - 4 - } - - if len(tokens) >= llm.NumCtx { - // truncate input - numLeft := (llm.NumCtx - llm.NumKeep) / 2 - truncated := tokens[:llm.NumKeep] - erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft - truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...) - tokens = truncated - log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated)) - } - - return llm.Decode(ctx, tokens) -} - type TokenizeRequest struct { Content string `json:"content"` } From 59a705525c91bfa407c2f5fa58eac33c109bdd4a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sun, 3 Sep 2023 17:46:35 -0400 Subject: [PATCH 2/3] fix not forwarding last token --- llm/ggml_llama.go | 68 ++++++++++++++++------------------------------- 1 file changed, 23 insertions(+), 45 deletions(-) diff --git a/llm/ggml_llama.go b/llm/ggml_llama.go index 547f9bf6..0c293732 100644 --- a/llm/ggml_llama.go +++ b/llm/ggml_llama.go @@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) { llm.Options = opts } -type Prediction struct { - Content string `json:"content"` - Stop bool `json:"stop"` -} - type GenerationSettings struct { FrequencyPenalty float64 `json:"frequency_penalty"` IgnoreEOS bool `json:"ignore_eos"` @@ -385,31 +380,19 @@ type GenerationSettings struct { } type Timings struct { - PredictedMS float64 `json:"predicted_ms"` - PredictedN int `json:"predicted_n"` - PredictedPerSecond float64 `json:"predicted_per_second"` - PredictedPerTokenMS float64 `json:"predicted_per_token_ms"` - PromptMS float64 `json:"prompt_ms"` - PromptN int `json:"prompt_n"` - PromptPerSecond float64 `json:"prompt_per_second"` - PromptPerTokenMS float64 `json:"prompt_per_token_ms"` + PredictedN int `json:"predicted_n"` + PredictedMS float64 `json:"predicted_ms"` + PromptN int `json:"prompt_n"` + PromptMS float64 `json:"prompt_ms"` } -type PredictComplete struct { - Content string `json:"content"` - GenerationSettings GenerationSettings `json:"generation_settings"` - Model string `json:"model"` - Prompt string `json:"prompt"` - Stop bool `json:"stop"` - StoppedEOS bool `json:"stopped_eos"` - StoppedLimit bool `json:"stopped_limit"` - StoppedWord bool `json:"stopped_word"` - StoppingWord string `json:"stopping_word"` - Timings Timings `json:"timings"` - TokensCached int `json:"tokens_cached"` - TokensEvaluated int `json:"tokens_evaluated"` - TokensPredicted int `json:"tokens_predicted"` - Truncated bool `json:"truncated"` +type Prediction struct { + Content string `json:"content"` + Model string `json:"model"` + Prompt string `json:"prompt"` + Stop bool `json:"stop"` + + Timings `json:"timings"` } type PredictRequest struct { @@ -509,13 +492,15 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, // Read data from the server-side event stream if strings.HasPrefix(line, "data: ") { evt := line[6:] - var complete PredictComplete - if err := json.Unmarshal([]byte(evt), &complete); err != nil { - return fmt.Errorf("error unmarshaling llm complete response: %v", err) + var p Prediction + if err := json.Unmarshal([]byte(evt), &p); err != nil { + return fmt.Errorf("error unmarshaling llm prediction response: %v", err) } - if complete.Timings.PredictedMS > 0 { - nextContext.WriteString(complete.Content) + fn(api.GenerateResponse{Response: p.Content}) + 
nextContext.WriteString(p.Content) + + if p.Stop { embd, err := llm.Encode(ctx, nextContext.String()) if err != nil { return fmt.Errorf("encoding context: %v", err) @@ -524,21 +509,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn(api.GenerateResponse{ Done: true, Context: embd, - PromptEvalCount: int(complete.Timings.PromptN), - PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)), - EvalCount: int(complete.Timings.PredictedN), - EvalDuration: parseDurationMs(float64(complete.Timings.PredictedMS)), + PromptEvalCount: p.PromptN, + PromptEvalDuration: parseDurationMs(p.PromptMS), + EvalCount: p.PredictedN, + EvalDuration: parseDurationMs(p.PredictedMS), }) + return nil } - - var p Prediction - if err := json.Unmarshal([]byte(evt), &p); err != nil { - return fmt.Errorf("error unmarshaling llm prediction response: %v", err) - } - - fn(api.GenerateResponse{Response: p.Content}) - nextContext.WriteString(p.Content) } } } From 681f3c4c42a645fb3e9cbc4311ae6662d838a684 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sun, 3 Sep 2023 17:36:14 -0400 Subject: [PATCH 3/3] fix num_keep --- server/routes.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/routes.go b/server/routes.go index f0762416..1a049cbd 100644 --- a/server/routes.go +++ b/server/routes.go @@ -117,12 +117,13 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses if err != nil { return err } + tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem) if err != nil { return err } - opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1 + opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) llmModel.SetOptions(opts) }
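
Note (not part of the patches above): PATCH 3 drops the stray "+ 1" so that num_keep is exactly the number of tokens contributed by the system prompt, i.e. the difference between the prompt tokenized with and without the system message. The following is a minimal, self-contained sketch of that idea only; encode here is a hypothetical stand-in tokenizer, not the real llmModel.Encode or the prompt templating in server/routes.go.

    package main

    import (
    	"fmt"
    	"strings"
    )

    // encode is a toy stand-in tokenizer: one token per whitespace-separated word.
    // The real implementation calls into the model's tokenizer.
    func encode(s string) []int {
    	words := strings.Fields(s)
    	tokens := make([]int, len(words))
    	for i := range words {
    		tokens[i] = i
    	}
    	return tokens
    }

    // computeNumKeep mirrors the PATCH 3 logic: the system prompt's token count
    // is the length difference between the full prompt and the prompt rendered
    // without the system message, so exactly those tokens are kept on truncation.
    func computeNumKeep(promptWithSystem, promptNoSystem string) int {
    	return len(encode(promptWithSystem)) - len(encode(promptNoSystem))
    }

    func main() {
    	withSystem := "You are a helpful assistant. USER: hello"
    	noSystem := "USER: hello"
    	fmt.Println("num_keep:", computeNumKeep(withSystem, noSystem)) // prints 5
    }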