diff --git a/llm/ggml_llama.go b/llm/ggml_llama.go
index 8858e22f..547f9bf6 100644
--- a/llm/ggml_llama.go
+++ b/llm/ggml_llama.go
@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 		runner.Path,
 		append(params, "--port", strconv.Itoa(port))...,
 	)
-	var stderr bytes.Buffer
-	cmd.Stderr = &stderr
+	cmd.Stdout = os.Stderr
+	cmd.Stderr = os.Stderr
 
 	llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
 
@@ -437,15 +437,19 @@ type PredictRequest struct {
 	Stop             []string `json:"stop,omitempty"`
 }
 
-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-	// we need to find the trimmed prompt context before predicting so that we can return it to the client
-	trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+	prevConvo, err := llm.Decode(ctx, prevContext)
 	if err != nil {
-		return fmt.Errorf("marshaling prompt: %v", err)
+		return err
 	}
+
+	var nextContext strings.Builder
+	nextContext.WriteString(prevConvo)
+	nextContext.WriteString(prompt)
+
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	predReq := PredictRequest{
-		Prompt:           trimmedPrompt,
+		Prompt:           nextContext.String(),
 		Stream:           true,
 		NPredict:         llm.NumPredict,
 		NKeep:            llm.NumKeep,
@@ -491,7 +495,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	}
 
 	scanner := bufio.NewScanner(resp.Body)
-	genCtx := trimmedPrompt // start with the trimmed prompt
 	for scanner.Scan() {
 		select {
 		case <-ctx.Done():
@@ -512,11 +515,12 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 				}
 
 				if complete.Timings.PredictedMS > 0 {
-					genCtx += complete.Content
-					embd, err := llm.Encode(ctx, genCtx)
+					nextContext.WriteString(complete.Content)
+					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
 					}
+
 					fn(api.GenerateResponse{
 						Done:    true,
 						Context: embd,
@@ -528,12 +532,13 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 					return nil
 				}
 
-				var pred Prediction
-				if err := json.Unmarshal([]byte(evt), &pred); err != nil {
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
-				genCtx += pred.Content
-				fn(api.GenerateResponse{Response: pred.Content})
+
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
 			}
 		}
 	}
@@ -545,34 +550,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	return nil
 }
 
-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-	pEncode, err := llm.Encode(ctx, prompt)
-	if err != nil {
-		return "", fmt.Errorf("encoding prompt context: %w", err)
-	}
-	tokens := append(pCtx, pEncode...)
-	if llm.NumKeep < 0 {
-		llm.NumKeep = len(tokens)
-	}
-
-	// min(llm.NumCtx - 4, llm.NumKeep)
-	if llm.NumCtx-4 < llm.NumKeep {
-		llm.NumKeep = llm.NumCtx - 4
-	}
-
-	if len(tokens) >= llm.NumCtx {
-		// truncate input
-		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		tokens = truncated
-		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-	}
-
-	return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }
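Note on the reworked context handling (reviewer sketch, not part of the patch): Predict now decodes the previous context tokens back into text with llm.Decode, appends the incoming prompt, accumulates every streamed chunk in a strings.Builder, and re-encodes the whole conversation with llm.Encode to return as the next Context. The snippet below sketches that round trip under the assumption that Decode/Encode have the signatures implied by the calls in the diff; the tokenCodec interface and the helper names buildNextContext/finishContext are illustrative and do not exist in llm/ggml_llama.go.

package llm

import (
	"context"
	"fmt"
	"strings"
)

// tokenCodec captures the two llama methods the new Predict relies on.
// Signatures are inferred from the calls in the diff; the real receiver is *llama.
type tokenCodec interface {
	Decode(ctx context.Context, tokens []int) (string, error)
	Encode(ctx context.Context, text string) ([]int, error)
}

// buildNextContext mirrors the start of the new Predict: decode the previous
// context tokens back into text, then append the incoming prompt to it.
func buildNextContext(ctx context.Context, c tokenCodec, prevContext []int, prompt string) (*strings.Builder, error) {
	prevConvo, err := c.Decode(ctx, prevContext)
	if err != nil {
		return nil, err
	}

	var next strings.Builder
	next.WriteString(prevConvo)
	next.WriteString(prompt)
	return &next, nil
}

// finishContext mirrors the end of the stream: append the final generated
// content and re-encode the accumulated conversation so it can be returned
// to the client as the next Context.
func finishContext(ctx context.Context, c tokenCodec, next *strings.Builder, finalContent string) ([]int, error) {
	next.WriteString(finalContent)
	embd, err := c.Encode(ctx, next.String())
	if err != nil {
		return nil, fmt.Errorf("encoding context: %v", err)
	}
	return embd, nil
}

One consequence of dropping marshalPrompt is that the NumKeep/NumCtx truncation no longer happens on the Go side, so any context-window trimming presumably has to be handled by the running llama.cpp server instead.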