diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index c65901c7..5717c17a 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -1223,9 +1223,7 @@ struct llama_server_context res.result_json = json { - {"id", res.id}, {"embedding", std::vector(embd, embd + n_embd)}, - {"timings", slot.get_formated_timings()}, }; } } @@ -3194,41 +3192,17 @@ int main(int argc, char **argv) { prompt = ""; } - if (prompt.size() == 1) { - prompt = prompt[0]; - } - // create and queue the task - json responses; - { - const int id_task = llama.queue_tasks.get_new_id(); - llama.queue_results.add_waiting_task_id(id_task); - llama.request_completion(id_task, {{"prompt", prompt}}, true, -1); + const int task_id = llama.queue_tasks.get_new_id(); + llama.queue_results.add_waiting_task_id(task_id); + llama.request_completion(task_id, {{"prompt", prompt}}, true, -1); - // get the result - task_result result = llama.queue_results.recv(id_task); - llama.queue_results.remove_waiting_task_id(id_task); - if (result.error) { - return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); - } + // get the result + task_result result = llama.queue_results.recv(task_id); + llama.queue_results.remove_waiting_task_id(task_id); - responses = result.result_json.value("results", std::vector{result.result_json}); - std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) { - return a["id"] < b["id"]; - }); - - json embeddings = json::array(); - - int prompt_n = 0; - for (auto & elem : responses) { - embeddings.push_back(elem.at("embedding")); - prompt_n += elem.at("timings").at("prompt_n").get(); - } - - // send the result - json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}}; - return res.set_content(embedding_res.dump(), "application/json; charset=utf-8"); - } + // send the result + return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); }); // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? diff --git a/llm/server.go b/llm/server.go index 0bd94f35..d2b8db9b 100644 --- a/llm/server.go +++ b/llm/server.go @@ -33,7 +33,7 @@ type LlamaServer interface { Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error - Embed(ctx context.Context, input []string) (*EmbedResponse, error) + Embedding(ctx context.Context, input string) ([]float32, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -883,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return nil } -type EmbedRequest struct { - Content []string `json:"content"` +type EmbeddingRequest struct { + Content string `json:"content"` } -type EmbedResponse struct { - Embedding [][]float32 `json:"embedding"` - PromptEvalCount int `json:"prompt_n"` +type EmbeddingResponse struct { + Embedding []float32 `json:"embedding"` } -func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) { - // each input will use a slot, so we need to acquire the semaphore for - // the number of inputs up to numParallel - slots := int64(min(len(input), s.numParallel)) - if err := s.sem.Acquire(ctx, slots); err != nil { +func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { + if err := s.sem.Acquire(ctx, 1); err != nil { slog.Error("Failed to acquire semaphore", "error", err) return nil, err } - defer s.sem.Release(slots) + defer s.sem.Release(1) // Make sure the server is ready status, err := s.getServerStatusRetry(ctx) @@ -910,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) } - data, err := json.Marshal(EmbedRequest{Content: input}) + data, err := json.Marshal(EmbeddingRequest{Content: input}) if err != nil { return nil, fmt.Errorf("error marshaling embed data: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) + r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) if err != nil { return nil, fmt.Errorf("error creating embed request: %w", err) } - req.Header.Set("Content-Type", "application/json") + r.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := http.DefaultClient.Do(r) if err != nil { return nil, fmt.Errorf("do embedding request: %w", err) } @@ -937,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, return nil, fmt.Errorf("%s", body) } - var e EmbedResponse + var e EmbeddingResponse if err := json.Unmarshal(body, &e); err != nil { return nil, fmt.Errorf("unmarshal tokenize response: %w", err) } - return &e, nil + return e.Embedding, nil } type TokenizeRequest struct { diff --git a/server/routes.go b/server/routes.go index e55eaa9d..e5a31002 100644 --- a/server/routes.go +++ b/server/routes.go @@ -23,6 +23,7 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" @@ -346,6 +347,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } + var count int for i, s := range input { tokens, err := r.Tokenize(c.Request.Context(), s) if err != nil { @@ -368,25 +370,36 @@ func (s *Server) EmbedHandler(c *gin.Context) { } } + count += len(tokens) + input[i] = s } - embeddings, err := r.Embed(c.Request.Context(), input) - if err != nil { - slog.Error("embedding generation failed", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) - return + + var g errgroup.Group + embeddings := make([][]float32, len(input)) + for i, text := range input { + g.Go(func() error { + embedding, err := r.Embedding(c.Request.Context(), text) + if err != nil { + return err + } + embeddings[i] = normalize(embedding) + return nil + }) } - for i, e := range embeddings.Embedding { - embeddings.Embedding[i] = normalize(e) + if err := g.Wait(); err != nil { + slog.Error("embedding generation failed", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)}) + return } resp := api.EmbedResponse{ Model: req.Model, - Embeddings: embeddings.Embedding, + Embeddings: embeddings, TotalDuration: time.Since(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart), - PromptEvalCount: embeddings.PromptEvalCount, + PromptEvalCount: count, } c.JSON(http.StatusOK, resp) } @@ -430,21 +443,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}) + embedding, err := r.Embedding(c.Request.Context(), req.Prompt) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) return } - embedding := make([]float64, len(embeddings.Embedding[0])) - - for i, v := range embeddings.Embedding[0] { - embedding[i] = float64(v) + var e []float64 + for _, v := range embedding { + e = append(e, float64(v)) } resp := api.EmbeddingResponse{ - Embedding: embedding, + Embedding: e, } c.JSON(http.StatusOK, resp) } diff --git a/server/sched_test.go b/server/sched_test.go index c8717430..713b9259 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -708,8 +708,8 @@ type mockLlm struct { pingResp error waitResp error completionResp error - embedResp *llm.EmbedResponse - embedRespErr error + embeddingResp []float32 + embeddingRespErr error tokenizeResp []int tokenizeRespErr error detokenizeResp string @@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn return s.completionResp } -func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) { - return s.embedResp, s.embedRespErr +func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) { + return s.embeddingResp, s.embeddingRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {