From 8b00a415ab5170a5a75b105402ca262d1fb7ac12 Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Tue, 13 Aug 2024 13:19:56 -0400 Subject: [PATCH] Load Embedding Model on Empty Input (#6325) * load on empty input * no load on invalid input --- server/routes.go | 16 +++++----- server/routes_test.go | 70 ------------------------------------------- 2 files changed, 9 insertions(+), 77 deletions(-) diff --git a/server/routes.go b/server/routes.go index e5a31002..6c470c17 100644 --- a/server/routes.go +++ b/server/routes.go @@ -324,13 +324,10 @@ func (s *Server) EmbedHandler(c *gin.Context) { input = append(input, v.(string)) } default: - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) - return - } - - if len(input) == 0 { - c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) - return + if req.Input != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } } r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) @@ -341,6 +338,11 @@ func (s *Server) EmbedHandler(c *gin.Context) { checkpointLoaded := time.Now() + if len(input) == 0 { + c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) + return + } + kvData, err := getKVData(m.ModelPath, false) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) diff --git a/server/routes_test.go b/server/routes_test.go index ef7248ef..242875d6 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -272,76 +272,6 @@ func Test_Routes(t *testing.T) { assert.Equal(t, "library", retrieveResp.OwnedBy) }, }, - { - Name: "Embed Handler Empty Input", - Method: http.MethodPost, - Path: "/api/embed", - Setup: func(t *testing.T, req *http.Request) { - embedReq := api.EmbedRequest{ - Model: "t-bone", - Input: "", - } - jsonData, err := json.Marshal(embedReq) - require.NoError(t, err) - req.Body = io.NopCloser(bytes.NewReader(jsonData)) - }, - Expected: func(t *testing.T, resp *http.Response) { - contentType := resp.Header.Get("Content-Type") - if contentType != "application/json; charset=utf-8" { - t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType) - } - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - var embedResp api.EmbedResponse - err = json.Unmarshal(body, &embedResp) - if err != nil { - t.Fatal(err) - } - - if embedResp.Model != "t-bone" { - t.Fatalf("expected model t-bone, got %s", embedResp.Model) - } - - if embedResp.Embeddings == nil { - t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings) - } - - if len(embedResp.Embeddings) != 0 { - t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings) - } - }, - }, - { - Name: "Embed Handler Invalid Input", - Method: http.MethodPost, - Path: "/api/embed", - Setup: func(t *testing.T, req *http.Request) { - embedReq := api.EmbedRequest{ - Model: "t-bone", - Input: 2, - } - jsonData, err := json.Marshal(embedReq) - require.NoError(t, err) - req.Body = io.NopCloser(bytes.NewReader(jsonData)) - }, - Expected: func(t *testing.T, resp *http.Response) { - contentType := resp.Header.Get("Content-Type") - if contentType != "application/json; charset=utf-8" { - t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType) - } - _, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected status code 400, got %d", resp.StatusCode) - } - }, - }, } t.Setenv("OLLAMA_MODELS", t.TempDir())