restore model load duration on generate response (#1524)

* restore model load duration on generate response

- set model load duration on generate and chat done response
- calculate createdAt time when the response is created

* remove checkpoints predict opts

* Update routes.go
This commit is contained in:
Bruce MacDonald 2023-12-14 12:15:50 -05:00 committed by GitHub
parent 31f0551dab
commit 6ee8c80199
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 36 deletions

View file

@ -551,14 +551,9 @@ type PredictOpts struct {
Prompt string Prompt string
Format string Format string
Images []api.ImageData Images []api.ImageData
CheckpointStart time.Time
CheckpointLoaded time.Time
} }
type PredictResult struct { type PredictResult struct {
CreatedAt time.Time
TotalDuration time.Duration
LoadDuration time.Duration
Content string Content string
Done bool Done bool
PromptEvalCount int PromptEvalCount int
@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
if p.Content != "" { if p.Content != "" {
fn(PredictResult{ fn(PredictResult{
CreatedAt: time.Now().UTC(),
Content: p.Content, Content: p.Content,
}) })
} }
if p.Stop { if p.Stop {
fn(PredictResult{ fn(PredictResult{
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
Done: true, Done: true,
PromptEvalCount: p.Timings.PromptN, PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS), PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),

View file

@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context) {
resp := api.GenerateResponse{ resp := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: r.CreatedAt, CreatedAt: time.Now().UTC(),
Done: r.Done, Done: r.Done,
Response: r.Content, Response: r.Content,
Metrics: api.Metrics{ Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: r.EvalCount,
@ -274,7 +272,11 @@ func GenerateHandler(c *gin.Context) {
}, },
} }
if r.Done && !req.Raw { if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String()) embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
@ -282,6 +284,7 @@ func GenerateHandler(c *gin.Context) {
} }
resp.Context = embd resp.Context = embd
} }
}
ch <- resp ch <- resp
} }
@ -290,8 +293,6 @@ func GenerateHandler(c *gin.Context) {
predictReq := llm.PredictOpts{ predictReq := llm.PredictOpts{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
Images: req.Images, Images: req.Images,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context) {
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: r.CreatedAt, CreatedAt: time.Now().UTC(),
Done: r.Done, Done: r.Done,
Metrics: api.Metrics{ Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: r.EvalCount,
@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context) {
}, },
} }
if !r.Done { if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} else {
resp.Message = &api.Message{Role: "assistant", Content: r.Content} resp.Message = &api.Message{Role: "assistant", Content: r.Content}
} }
@ -1035,8 +1037,6 @@ func ChatHandler(c *gin.Context) {
predictReq := llm.PredictOpts{ predictReq := llm.PredictOpts{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
Images: images, Images: images,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {