From 1b44d873e74f62de4f53f154da386919c1426f8b Mon Sep 17 00:00:00 2001
From: royjhan <65097070+royjhan@users.noreply.github.com>
Date: Tue, 30 Jul 2024 13:12:21 -0700
Subject: [PATCH] Add Metrics to `api/embed` response (#5709)

* add prompt tokens to embed response

* rm slog

* metrics

* types

* prompt n

* clean up

* reset submodule

* update tests

* test name

* list metrics
---
 api/types.go              |  4 ++++
 integration/embed_test.go |  8 ++++++++
 llm/ext_server/server.cpp |  7 ++++++-
 llm/server.go             | 13 +++++++------
 server/routes.go          | 18 ++++++++++++------
 server/sched_test.go      |  4 ++--
 6 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/api/types.go b/api/types.go
index ea5161ff..c2529652 100644
--- a/api/types.go
+++ b/api/types.go
@@ -267,6 +267,10 @@ type EmbedRequest struct {
 type EmbedResponse struct {
 	Model      string      `json:"model"`
 	Embeddings [][]float32 `json:"embeddings"`
+
+	TotalDuration   time.Duration `json:"total_duration,omitempty"`
+	LoadDuration    time.Duration `json:"load_duration,omitempty"`
+	PromptEvalCount int           `json:"prompt_eval_count,omitempty"`
 }
 
 // EmbeddingRequest is the request passed to [Client.Embeddings].
diff --git a/integration/embed_test.go b/integration/embed_test.go
index 61b36fa2..10333d5d 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -69,6 +69,10 @@ func TestAllMiniLMEmbed(t *testing.T) {
 	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
 		t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
 	}
+
+	if res.PromptEvalCount != 8 {
+		t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
+	}
 }
 
 func TestAllMiniLMBatchEmbed(t *testing.T) {
@@ -97,6 +101,10 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
 		t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
 	}
+
+	if res.PromptEvalCount != 16 {
+		t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
+	}
 }
 
 func TestAllMiniLMEmbedTruncate(t *testing.T) {
diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index 0d51460c..d72bb1b1 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -1221,6 +1221,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"embedding", std::vector<float>(embd, embd + n_embd)},
+                {"timings", slot.get_formated_timings()},
             };
         }
     }
@@ -3203,11 +3204,15 @@ int main(int argc, char **argv) {
                 responses = result.result_json.value("results", std::vector<json>{result.result_json});
 
                 json embeddings = json::array();
+
+                int prompt_n = 0;
                 for (auto & elem : responses) {
                     embeddings.push_back(elem.at("embedding"));
+                    prompt_n += elem.at("timings").at("prompt_n").get<int>();
                 }
+
                 // send the result
-                json embedding_res = json{{"embedding", embeddings}};
+                json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
                 return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
             }
         });
diff --git a/llm/server.go b/llm/server.go
index 8127960f..afde077e 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embed(ctx context.Context, input []string) ([][]float32, error)
+	Embed(ctx context.Context, input []string) (*EmbedResponse, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -879,10 +879,11 @@ type EmbedRequest struct {
 }
 
 type EmbedResponse struct {
-	Embedding [][]float32 `json:"embedding"`
+	Embedding       [][]float32 `json:"embedding"`
+	PromptEvalCount int         `json:"prompt_n"`
 }
 
-func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
+func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -924,12 +925,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
 		return nil, fmt.Errorf("%s", body)
 	}
 
-	var embedding EmbedResponse
-	if err := json.Unmarshal(body, &embedding); err != nil {
+	var e EmbedResponse
+	if err := json.Unmarshal(body, &e); err != nil {
 		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}
 
-	return embedding.Embedding, nil
+	return &e, nil
 }
 
 type TokenizeRequest struct {
diff --git a/server/routes.go b/server/routes.go
index e6ffe526..a560f369 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -284,6 +284,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 }
 
 func (s *Server) EmbedHandler(c *gin.Context) {
+	checkpointStart := time.Now()
 	var req api.EmbedRequest
 	err := c.ShouldBindJSON(&req)
 	switch {
@@ -332,6 +333,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	checkpointLoaded := time.Now()
+
 	kvData, err := getKVData(m.ModelPath, false)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -370,13 +373,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	for i, e := range embeddings {
-		embeddings[i] = normalize(e)
+	for i, e := range embeddings.Embedding {
+		embeddings.Embedding[i] = normalize(e)
 	}
 
 	resp := api.EmbedResponse{
-		Model:      req.Model,
-		Embeddings: embeddings,
+		Model:           req.Model,
+		Embeddings:      embeddings.Embedding,
+		TotalDuration:   time.Since(checkpointStart),
+		LoadDuration:    checkpointLoaded.Sub(checkpointStart),
+		PromptEvalCount: embeddings.PromptEvalCount,
 	}
 	c.JSON(http.StatusOK, resp)
 }
@@ -428,9 +434,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding := make([]float64, len(embeddings[0]))
+	embedding := make([]float64, len(embeddings.Embedding[0]))
 
-	for i, v := range embeddings[0] {
+	for i, v := range embeddings.Embedding[0] {
 		embedding[i] = float64(v)
 	}
 
diff --git a/server/sched_test.go b/server/sched_test.go
index a186ce0e..4f8789fa 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -709,7 +709,7 @@ type mockLlm struct {
 	pingResp        error
 	waitResp        error
 	completionResp  error
-	embedResp       [][]float32
+	embedResp       *llm.EmbedResponse
 	embedRespErr    error
 	tokenizeResp    []int
 	tokenizeRespErr error
@@ -727,7 +727,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
+func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
 	return s.embedResp, s.embedRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
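
---

A minimal usage sketch of the new fields, assuming the `api.ClientFromEnvironment` and `Client.Embed` helpers in this tree; the model name and inputs are illustrative:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Resolves the server address from OLLAMA_HOST (default localhost:11434).
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model: "all-minilm", // illustrative; any embedding model works
		Input: []string{"why is the sky blue?", "why is the grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}

	// The three metrics added by this patch. All are tagged `omitempty`,
	// so a server without this change simply leaves them at zero values.
	fmt.Println("embeddings:    ", len(resp.Embeddings))
	fmt.Println("prompt tokens: ", resp.PromptEvalCount)
	fmt.Println("total duration:", resp.TotalDuration)
	fmt.Println("load duration: ", resp.LoadDuration)
}
```

Note that the runner reports a per-input token count as `prompt_n`; `EmbedHandler` sums these and re-exposes the total as `prompt_eval_count`, which is why the batch test above expects 16 prompt tokens for two 8-token inputs.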