diff --git a/openai/openai.go b/openai/openai.go index 93b63296..de6f4bd5 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -877,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc { chatReq, err := fromChatRequest(req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return } if err := json.NewEncoder(&b).Encode(chatReq); err != nil { diff --git a/openai/openai_test.go b/openai/openai_test.go index ad056e6d..f978d46c 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -20,113 +20,59 @@ const prefix = `data:image/jpeg;base64,` const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` const imageURL = prefix + image -func TestMiddlewareRequests(t *testing.T) { +func prepareRequest(req *http.Request, body any) { + bodyBytes, _ := json.Marshal(body) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") +} + +func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { + return func(c *gin.Context) { + bodyBytes, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + err := json.Unmarshal(bodyBytes, capturedRequest) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request") + } + c.Next() + } +} + +func TestChatMiddleware(t *testing.T) { type testCase struct { Name string - Method string - Path string - Handler func() gin.HandlerFunc Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *http.Request) + Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) } - var capturedRequest *http.Request - - captureRequestMiddleware := func() gin.HandlerFunc { - return func(c *gin.Context) { - bodyBytes, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - capturedRequest = c.Request - c.Next() - } - } + var capturedRequest *api.ChatRequest testCases := []testCase{ { - Name: "chat handler", - Method: http.MethodPost, - Path: "/api/chat", - Handler: ChatMiddleware, + Name: "chat handler", Setup: func(t *testing.T, req *http.Request) { body := ChatCompletionRequest{ Model: "test-model", Messages: []Message{{Role: "user", Content: "Hello"}}, } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var chatReq api.ChatRequest - if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil { - t.Fatal(err) + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.Code) } - if chatReq.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role) + if req.Messages[0].Role != "user" { + t.Fatalf("expected 'user', got %s", req.Messages[0].Role) } - if chatReq.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content) + if req.Messages[0].Content != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) } }, }, { - Name: "completions handler", - Method: http.MethodPost, - Path: "/api/generate", - Handler: CompletionsMiddleware, - Setup: func(t *testing.T, req *http.Request) { - temp := float32(0.8) - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - Temperature: &temp, - Stop: []string{"\n", "stop"}, - Suffix: "suffix", - } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") - }, - Expected: func(t *testing.T, req *http.Request) { - var genReq api.GenerateRequest - if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil { - t.Fatal(err) - } - - if genReq.Prompt != "Hello" { - t.Fatalf("expected 'Hello', got %s", genReq.Prompt) - } - - if genReq.Options["temperature"] != 1.6 { - t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"]) - } - - stopTokens, ok := genReq.Options["stop"].([]any) - - if !ok { - t.Fatalf("expected stop tokens to be a list") - } - - if stopTokens[0] != "\n" || stopTokens[1] != "stop" { - t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens) - } - - if genReq.Suffix != "suffix" { - t.Fatalf("expected 'suffix', got %s", genReq.Suffix) - } - }, - }, - { - Name: "chat handler with image content", - Method: http.MethodPost, - Path: "/api/chat", - Handler: ChatMiddleware, + Name: "chat handler with image content", Setup: func(t *testing.T, req *http.Request) { body := ChatCompletionRequest{ Model: "test-model", @@ -139,91 +85,254 @@ func TestMiddlewareRequests(t *testing.T) { }, }, } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var chatReq api.ChatRequest - if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil { - t.Fatal(err) + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.Code) } - if chatReq.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role) + if req.Messages[0].Role != "user" { + t.Fatalf("expected 'user', got %s", req.Messages[0].Role) } - if chatReq.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content) + if req.Messages[0].Content != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) } img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) - if chatReq.Messages[1].Role != "user" { - t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role) + if req.Messages[1].Role != "user" { + t.Fatalf("expected 'user', got %s", req.Messages[1].Role) } - if !bytes.Equal(chatReq.Messages[1].Images[0], img) { - t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0]) + if !bytes.Equal(req.Messages[1].Images[0], img) { + t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0]) } }, }, { - Name: "embed handler single input", - Method: http.MethodPost, - Path: "/api/embed", - Handler: EmbeddingsMiddleware, + Name: "chat handler with tools", + Setup: func(t *testing.T, req *http.Request) { + body := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + {Role: "user", Content: "What's the weather like in Paris Today?"}, + {Role: "assistant", ToolCalls: []ToolCall{{ + ID: "id", + Type: "function", + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{ + Name: "get_current_weather", + Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}", + }, + }}}, + }, + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != 200 { + t.Fatalf("expected 200, got %d", resp.Code) + } + + if req.Messages[0].Content != "What's the weather like in Paris Today?" { + t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content) + } + + if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" { + t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"]) + } + + if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" { + t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"]) + } + }, + }, + { + Name: "chat handler error forwarding", + Setup: func(t *testing.T, req *http.Request) { + body := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{{Role: "user", Content: 2}}, + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), "invalid message content type") { + t.Fatalf("error was not forwarded") + } + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/chat", endpoint) + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil) + + tc.Setup(t, req) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + tc.Expected(t, capturedRequest, resp) + + capturedRequest = nil + }) + } +} + +func TestCompletionsMiddleware(t *testing.T) { + type testCase struct { + Name string + Setup func(t *testing.T, req *http.Request) + Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) + } + + var capturedRequest *api.GenerateRequest + + testCases := []testCase{ + { + Name: "completions handler", + Setup: func(t *testing.T, req *http.Request) { + temp := float32(0.8) + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: &temp, + Stop: []string{"\n", "stop"}, + Suffix: "suffix", + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { + if req.Prompt != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Prompt) + } + + if req.Options["temperature"] != 1.6 { + t.Fatalf("expected 1.6, got %f", req.Options["temperature"]) + } + + stopTokens, ok := req.Options["stop"].([]any) + + if !ok { + t.Fatalf("expected stop tokens to be a list") + } + + if stopTokens[0] != "\n" || stopTokens[1] != "stop" { + t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens) + } + + if req.Suffix != "suffix" { + t.Fatalf("expected 'suffix', got %s", req.Suffix) + } + }, + }, + { + Name: "completions handler error forwarding", + Setup: func(t *testing.T, req *http.Request) { + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: nil, + Stop: []int{1, 2}, + Suffix: "suffix", + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") { + t.Fatalf("error was not forwarded") + } + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/generate", endpoint) + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil) + + tc.Setup(t, req) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + tc.Expected(t, capturedRequest, resp) + + capturedRequest = nil + }) + } +} + +func TestEmbeddingsMiddleware(t *testing.T) { + type testCase struct { + Name string + Setup func(t *testing.T, req *http.Request) + Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) + } + + var capturedRequest *api.EmbedRequest + + testCases := []testCase{ + { + Name: "embed handler single input", Setup: func(t *testing.T, req *http.Request) { body := EmbedRequest{ Input: "Hello", Model: "test-model", } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var embedReq api.EmbedRequest - if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { - t.Fatal(err) + Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { + if req.Input != "Hello" { + t.Fatalf("expected 'Hello', got %s", req.Input) } - if embedReq.Input != "Hello" { - t.Fatalf("expected 'Hello', got %s", embedReq.Input) - } - - if embedReq.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", embedReq.Model) + if req.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", req.Model) } }, }, { - Name: "embed handler batch input", - Method: http.MethodPost, - Path: "/api/embed", - Handler: EmbeddingsMiddleware, + Name: "embed handler batch input", Setup: func(t *testing.T, req *http.Request) { body := EmbedRequest{ Input: []string{"Hello", "World"}, Model: "test-model", } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") + prepareRequest(req, body) }, - Expected: func(t *testing.T, req *http.Request) { - var embedReq api.EmbedRequest - if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { - t.Fatal(err) - } - - input, ok := embedReq.Input.([]any) + Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { + input, ok := req.Input.([]any) if !ok { t.Fatalf("expected input to be a list") @@ -237,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) { t.Fatalf("expected 'World', got %s", input[1]) } - if embedReq.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", embedReq.Model) + if req.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", req.Model) + } + }, + }, + { + Name: "embed handler error forwarding", + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Model: "test-model", + } + prepareRequest(req, body) + }, + Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), "invalid input") { + t.Fatalf("error was not forwarded") } }, }, } - gin.SetMode(gin.TestMode) - router := gin.New() - endpoint := func(c *gin.Context) { c.Status(http.StatusOK) } + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest)) + router.Handle(http.MethodPost, "/api/embed", endpoint) + for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - router = gin.New() - router.Use(captureRequestMiddleware()) - router.Use(tc.Handler()) - router.Handle(tc.Method, tc.Path, endpoint) - req, _ := http.NewRequest(tc.Method, tc.Path, nil) + req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil) - if tc.Setup != nil { - tc.Setup(t, req) - } + tc.Setup(t, req) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest) + tc.Expected(t, capturedRequest, resp) + + capturedRequest = nil }) } } @@ -284,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) { } testCases := []testCase{ - { - Name: "completions handler error forwarding", - Method: http.MethodPost, - Path: "/api/generate", - TestPath: "/api/generate", - Handler: CompletionsMiddleware, - Endpoint: func(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - }, - Setup: func(t *testing.T, req *http.Request) { - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") - }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), `"invalid request"`) { - t.Fatalf("error was not forwarded") - } - }, - }, { Name: "list handler", Method: http.MethodGet, @@ -330,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) { }) }, Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - assert.Equal(t, http.StatusOK, resp.Code) - var listResp ListCompletion if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { t.Fatal(err) @@ -395,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) { resp := httptest.NewRecorder() router.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + tc.Expected(t, resp) }) }