From fe91d7fff13cc48879b320911a9662f08a686264 Mon Sep 17 00:00:00 2001 From: frob Date: Fri, 6 Sep 2024 10:16:28 +0200 Subject: [PATCH] openai: fix "presence_penalty" typo and add test (#6665) --- openai/openai.go | 2 +- openai/openai_test.go | 46 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index bda42b4d..a4499682 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -79,7 +79,7 @@ type ChatCompletionRequest struct { Stop any `json:"stop"` Temperature *float64 `json:"temperature"` FrequencyPenalty *float64 `json:"frequency_penalty"` - PresencePenalty *float64 `json:"presence_penalty_penalty"` + PresencePenalty *float64 `json:"presence_penalty"` TopP *float64 `json:"top_p"` ResponseFormat *ResponseFormat `json:"response_format"` Tools []api.Tool `json:"tools"` diff --git a/openai/openai_test.go b/openai/openai_test.go index c7e9f384..b34f73c5 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -22,7 +22,10 @@ const ( image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) -var False = false +var ( + False = false + True = true +) func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { return func(c *gin.Context) { @@ -70,6 +73,44 @@ func TestChatMiddleware(t *testing.T) { Stream: &False, }, }, + { + name: "chat handler with options", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "stream": true, + "max_tokens": 999, + "seed": 123, + "stop": ["\n", "stop"], + "temperature": 3.0, + "frequency_penalty": 4.0, + "presence_penalty": 5.0, + "top_p": 6.0, + "response_format": {"type": "json_object"} + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "num_predict": 999.0, // float because JSON doesn't distinguish between float and int + "seed": 123.0, + "stop": []any{"\n", "stop"}, + "temperature": 6.0, + "frequency_penalty": 8.0, + "presence_penalty": 10.0, + "top_p": 6.0, + }, + Format: "json", + Stream: &True, + }, + }, { name: "chat handler with image content", body: `{ @@ -186,6 +227,8 @@ func TestChatMiddleware(t *testing.T) { req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) req.Header.Set("Content-Type", "application/json") + defer func() { capturedRequest = nil }() + resp := httptest.NewRecorder() router.ServeHTTP(resp, req) @@ -202,7 +245,6 @@ func TestChatMiddleware(t *testing.T) { if !reflect.DeepEqual(tc.err, errResp) { t.Fatal("errors did not match") } - capturedRequest = nil }) } }