diff --git a/llm/ggml_llama.go b/llm/ggml_llama.go
index 547f9bf6..0c293732 100644
--- a/llm/ggml_llama.go
+++ b/llm/ggml_llama.go
@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type Prediction struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
 	FrequencyPenalty float64 `json:"frequency_penalty"`
 	IgnoreEOS        bool    `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }
 
 type Timings struct {
-	PredictedMS         float64 `json:"predicted_ms"`
-	PredictedN          int     `json:"predicted_n"`
-	PredictedPerSecond  float64 `json:"predicted_per_second"`
-	PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-	PromptMS            float64 `json:"prompt_ms"`
-	PromptN             int     `json:"prompt_n"`
-	PromptPerSecond     float64 `json:"prompt_per_second"`
-	PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
 }
 
-type PredictComplete struct {
-	Content            string             `json:"content"`
-	GenerationSettings GenerationSettings `json:"generation_settings"`
-	Model              string             `json:"model"`
-	Prompt             string             `json:"prompt"`
-	Stop               bool               `json:"stop"`
-	StoppedEOS         bool               `json:"stopped_eos"`
-	StoppedLimit       bool               `json:"stopped_limit"`
-	StoppedWord        bool               `json:"stopped_word"`
-	StoppingWord       string             `json:"stopping_word"`
-	Timings            Timings            `json:"timings"`
-	TokensCached       int                `json:"tokens_cached"`
-	TokensEvaluated    int                `json:"tokens_evaluated"`
-	TokensPredicted    int                `json:"tokens_predicted"`
-	Truncated          bool               `json:"truncated"`
+type Prediction struct {
+	Content string `json:"content"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Stop    bool   `json:"stop"`
+
+	Timings `json:"timings"`
 }
 
 type PredictRequest struct {
@@ -509,13 +492,15 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// Read data from the server-side event stream
 			if strings.HasPrefix(line, "data: ") {
 				evt := line[6:]
-				var complete PredictComplete
-				if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-					return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 
-				if complete.Timings.PredictedMS > 0 {
-					nextContext.WriteString(complete.Content)
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
+
+				if p.Stop {
 					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
@@ -524,21 +509,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 					fn(api.GenerateResponse{
 						Done:               true,
 						Context:            embd,
-						PromptEvalCount:    int(complete.Timings.PromptN),
-						PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-						EvalCount:          int(complete.Timings.PredictedN),
-						EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+						PromptEvalCount:    p.PromptN,
+						PromptEvalDuration: parseDurationMs(p.PromptMS),
+						EvalCount:          p.PredictedN,
+						EvalDuration:       parseDurationMs(p.PredictedMS),
 					})
+
+					return nil
 				}
-
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
-
-				fn(api.GenerateResponse{Response: p.Content})
-				nextContext.WriteString(p.Content)
 			}
 		}
 	}
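
Note (commentary, not part of the diff): the new Prediction type embeds Timings with a `json:"timings"` tag, so encoding/json unmarshals the nested "timings" object into it and its fields are promoted onto Prediction. That is what the updated Predict loop relies on when it reads p.PredictedN, p.PredictedMS, p.PromptN, and p.PromptMS directly. A minimal standalone sketch of that behavior follows; the sample payload is hypothetical and only mirrors the fields used above.

// Sketch: tagged embedded struct + field promotion with encoding/json.
package main

import (
	"encoding/json"
	"fmt"
)

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type Prediction struct {
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	Stop    bool   `json:"stop"`

	// Tagged embedding: the nested "timings" JSON object decodes into this
	// field, and its members are promoted (p.PredictedN, p.PromptMS, ...).
	Timings `json:"timings"`
}

func main() {
	// Hypothetical final stream event in the shape the server sends.
	evt := []byte(`{"content":"", "stop":true,
		"timings":{"predicted_n":12,"predicted_ms":340.5,"prompt_n":8,"prompt_ms":52.1}}`)

	var p Prediction
	if err := json.Unmarshal(evt, &p); err != nil {
		panic(err)
	}

	// Promoted fields: no p.Timings.PredictedN indirection needed.
	fmt.Println(p.Stop, p.PredictedN, p.PredictedMS, p.PromptN, p.PromptMS)
}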