From cca61181cb08995ffc2ac93439425ac3fa997a5b Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Tue, 25 Jul 2023 15:51:32 -0700
Subject: [PATCH] sample metrics

---
 api/types.go   | 11 +++++++++++
 llama/llama.go |  2 ++
 2 files changed, 13 insertions(+)

diff --git a/api/types.go b/api/types.go
index fc00adb1..6208cb6e 100644
--- a/api/types.go
+++ b/api/types.go
@@ -98,6 +98,8 @@ type GenerateResponse struct {
 
 	TotalDuration      time.Duration `json:"total_duration,omitempty"`
 	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	SampleCount        int           `json:"sample_count,omitempty"`
+	SampleDuration     time.Duration `json:"sample_duration,omitempty"`
 	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
 	EvalCount          int           `json:"eval_count,omitempty"`
@@ -113,6 +115,15 @@ func (r *GenerateResponse) Summary() {
 		fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
 	}
 
+	if r.SampleCount > 0 {
+		fmt.Fprintf(os.Stderr, "sample count: %d token(s)\n", r.SampleCount)
+	}
+
+	if r.SampleDuration > 0 {
+		fmt.Fprintf(os.Stderr, "sample duration: %s\n", r.SampleDuration)
+		fmt.Fprintf(os.Stderr, "sample rate: %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds())
+	}
+
 	if r.PromptEvalCount > 0 {
 		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
 	}

diff --git a/llama/llama.go b/llama/llama.go
index 07dd8a13..e2c30f1f 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -216,6 +216,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 	fn(api.GenerateResponse{
 		Done:    true,
 		Context: last,
+		SampleCount:        int(timings.n_sample),
+		SampleDuration:     parseDurationMs(float64(timings.t_sample_ms)),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),
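
Note on the helper used above: parseDurationMs is referenced by the patch but
not defined in it. Below is a minimal sketch of what such a helper and the
derived sample rate could look like, assuming parseDurationMs simply maps
llama.cpp's millisecond timings (e.g. timings.t_sample_ms) onto a
time.Duration. Names and values here are illustrative, not the repository's
actual code.

    package main

    import (
    	"fmt"
    	"time"
    )

    // Hypothetical reconstruction of the parseDurationMs helper the patch
    // calls: convert a float64 millisecond count into a time.Duration.
    func parseDurationMs(ms float64) time.Duration {
    	return time.Duration(ms * float64(time.Millisecond))
    }

    func main() {
    	// Example values standing in for llama.cpp's timings struct.
    	sampleCount := 256                      // e.g. timings.n_sample
    	sampleDuration := parseDurationMs(97.5) // e.g. timings.t_sample_ms

    	// Same rate computation as GenerateResponse.Summary in the patch:
    	// tokens sampled divided by seconds spent sampling.
    	fmt.Printf("sample rate: %.2f tokens/s\n",
    		float64(sampleCount)/sampleDuration.Seconds())
    }

With these inputs the sketch prints roughly 2625.64 tokens/s, which matches
the "%.2f tokens/s" formatting the patch adds to Summary.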