diff --git a/server/routes.go b/server/routes.go index 4059c7c5..5b6d0978 100644 --- a/server/routes.go +++ b/server/routes.go @@ -102,6 +102,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil } func (s *Server) GenerateHandler(c *gin.Context) { + checkpointStart := time.Now() var req api.GenerateRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -129,6 +130,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + checkpointLoaded := time.Now() + if req.Prompt == "" { c.JSON(http.StatusOK, api.GenerateResponse{ Model: req.Model, @@ -191,26 +194,48 @@ func (s *Server) GenerateHandler(c *gin.Context) { ch := make(chan any) go func() { + // TODO (jmorganca): avoid building the response twice both here and below + var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, Format: req.Format, Options: opts, - }, func(r llm.CompletionResponse) { - ch <- api.GenerateResponse{ + }, func(cr llm.CompletionResponse) { + res := api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - Response: r.Content, - Done: r.Done, - DoneReason: r.DoneReason, + Response: cr.Content, + Done: cr.Done, + DoneReason: cr.DoneReason, Metrics: api.Metrics{ - PromptEvalCount: r.PromptEvalCount, - PromptEvalDuration: r.PromptEvalDuration, - EvalCount: r.EvalCount, - EvalDuration: r.EvalDuration, + PromptEvalCount: cr.PromptEvalCount, + PromptEvalDuration: cr.PromptEvalDuration, + EvalCount: cr.EvalCount, + EvalDuration: cr.EvalDuration, }, } + + if _, err := sb.WriteString(cr.Content); err != nil { + ch <- gin.H{"error": err.Error()} + } + + if cr.Done { + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + + if !req.Raw { + tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String()) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } + res.Context = append(req.Context, tokens...) + } + } + + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} } @@ -1122,6 +1147,8 @@ func (s *Server) ProcessHandler(c *gin.Context) { } func (s *Server) ChatHandler(c *gin.Context) { + checkpointStart := time.Now() + var req api.ChatRequest if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -1141,6 +1168,8 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + checkpointLoaded := time.Now() + if len(req.Messages) == 0 { c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, @@ -1169,7 +1198,7 @@ func (s *Server) ChatHandler(c *gin.Context) { Format: req.Format, Options: opts, }, func(r llm.CompletionResponse) { - ch <- api.ChatResponse{ + res := api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), Message: api.Message{Role: "assistant", Content: r.Content}, @@ -1182,6 +1211,13 @@ func (s *Server) ChatHandler(c *gin.Context) { EvalDuration: r.EvalDuration, }, } + + if r.Done { + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + } + + ch <- res }); err != nil { ch <- gin.H{"error": err.Error()} }