diff --git a/openai/openai.go b/openai/openai.go index 01da4440..f1e75bf2 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -43,6 +43,12 @@ type ChunkChoice struct { FinishReason *string `json:"finish_reason"` } +type CompleteChunkChoice struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason *string `json:"finish_reason"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -86,6 +92,39 @@ type ChatCompletionChunk struct { Choices []ChunkChoice `json:"choices"` } +// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int +type CompletionRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + FrequencyPenalty float32 `json:"frequency_penalty"` + MaxTokens *int `json:"max_tokens"` + PresencePenalty float32 `json:"presence_penalty"` + Seed *int `json:"seed"` + Stop any `json:"stop"` + Stream bool `json:"stream"` + Temperature *float32 `json:"temperature"` + TopP float32 `json:"top_p"` +} + +type Completion struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []CompleteChunkChoice `json:"choices"` + Usage Usage `json:"usage,omitempty"` +} + +type CompletionChunk struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Choices []CompleteChunkChoice `json:"choices"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` +} + type Model struct { Id string `json:"id"` Object string `json:"object"` @@ -158,6 +197,52 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { } } +func toCompletion(id string, r api.GenerateResponse) Completion { + return Completion{ + Id: id, + Object: "text_completion", + Created: r.CreatedAt.Unix(), + Model: r.Model, + SystemFingerprint: "fp_ollama", + Choices: 
[]CompleteChunkChoice{{ + Text: r.Response, + Index: 0, + FinishReason: func(reason string) *string { + if len(reason) > 0 { + return &reason + } + return nil + }(r.DoneReason), + }}, + Usage: Usage{ + // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count + PromptTokens: r.PromptEvalCount, + CompletionTokens: r.EvalCount, + TotalTokens: r.PromptEvalCount + r.EvalCount, + }, + } +} + +func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { + return CompletionChunk{ + Id: id, + Object: "text_completion", + Created: time.Now().Unix(), + Model: r.Model, + SystemFingerprint: "fp_ollama", + Choices: []CompleteChunkChoice{{ + Text: r.Response, + Index: 0, + FinishReason: func(reason string) *string { + if len(reason) > 0 { + return &reason + } + return nil + }(r.DoneReason), + }}, + } +} + func toListCompletion(r api.ListResponse) ListCompletion { var data []Model for _, m := range r.Models { @@ -195,7 +280,7 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { switch stop := r.Stop.(type) { case string: options["stop"] = []string{stop} - case []interface{}: + case []any: var stops []string for _, s := range stop { if str, ok := s.(string); ok { @@ -247,6 +332,52 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest { } } +func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { + options := make(map[string]any) + + switch stop := r.Stop.(type) { + case string: + options["stop"] = []string{stop} + case []any: + options["stop"] = stop + default: + if r.Stop != nil { + return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop) + } + } + + if r.MaxTokens != nil { + options["num_predict"] = *r.MaxTokens + } + + if r.Temperature != nil { + options["temperature"] = *r.Temperature * 2.0 + } else { + options["temperature"] = 1.0 + } + + if r.Seed != nil { + options["seed"] = *r.Seed + } + + options["frequency_penalty"] = r.FrequencyPenalty 
* 2.0 + + options["presence_penalty"] = r.PresencePenalty * 2.0 + + if r.TopP != 0.0 { + options["top_p"] = r.TopP + } else { + options["top_p"] = 1.0 + } + + return api.GenerateRequest{ + Model: r.Model, + Prompt: r.Prompt, + Options: options, + Stream: &r.Stream, + }, nil +} + type BaseWriter struct { gin.ResponseWriter } @@ -257,6 +388,12 @@ type ChatWriter struct { BaseWriter } +type CompleteWriter struct { + stream bool + id string + BaseWriter +} + type ListWriter struct { BaseWriter } @@ -331,6 +468,55 @@ func (w *ChatWriter) Write(data []byte) (int, error) { return w.writeResponse(data) } +func (w *CompleteWriter) writeResponse(data []byte) (int, error) { + var generateResponse api.GenerateResponse + err := json.Unmarshal(data, &generateResponse) + if err != nil { + return 0, err + } + + // completion chunk + if w.stream { + d, err := json.Marshal(toCompleteChunk(w.id, generateResponse)) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + if generateResponse.Done { + _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return 0, err + } + } + + return len(data), nil + } + + // completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *CompleteWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + func (w *ListWriter) writeResponse(data []byte) (int, error) { var listResponse api.ListResponse err := json.Unmarshal(data, &listResponse) @@ -416,6 +602,41 @@ func RetrieveMiddleware() gin.HandlerFunc { } } +func CompletionsMiddleware() gin.HandlerFunc { + 
return func(c *gin.Context) { + var req CompletionRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + var b bytes.Buffer + genReq, err := fromCompleteRequest(req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + if err := json.NewEncoder(&b).Encode(genReq); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &CompleteWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), + } + + c.Writer = w + + c.Next() + } +} + func ChatMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req ChatCompletionRequest diff --git a/openai/openai_test.go b/openai/openai_test.go index 1f335b96..4d21382c 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -3,9 +3,11 @@ package openai import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -69,6 +71,8 @@ func TestMiddleware(t *testing.T) { req.Header.Set("Content-Type", "application/json") }, Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var chatResp ChatCompletion if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { t.Fatal(err) @@ -83,6 +87,130 @@ func TestMiddleware(t *testing.T) { } }, }, + { + Name: "completions handler", + Method: http.MethodPost, + Path: "/api/generate", + TestPath: "/api/generate", + Handler: CompletionsMiddleware, + Endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.GenerateResponse{ + Response: "Hello!", + }) + }, + Setup: func(t *testing.T, req *http.Request) { + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + } + + bodyBytes, _ 
:= json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var completionResp Completion + if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil { + t.Fatal(err) + } + + if completionResp.Object != "text_completion" { + t.Fatalf("expected text_completion, got %s", completionResp.Object) + } + + if completionResp.Choices[0].Text != "Hello!" { + t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text) + } + }, + }, + { + Name: "completions handler with params", + Method: http.MethodPost, + Path: "/api/generate", + TestPath: "/api/generate", + Handler: CompletionsMiddleware, + Endpoint: func(c *gin.Context) { + var generateReq api.GenerateRequest + if err := c.ShouldBindJSON(&generateReq); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + return + } + + temperature := generateReq.Options["temperature"].(float64) + var assistantMessage string + + switch temperature { + case 1.6: + assistantMessage = "Received temperature of 1.6" + default: + assistantMessage = fmt.Sprintf("Received temperature of %f", temperature) + } + + c.JSON(http.StatusOK, api.GenerateResponse{ + Response: assistantMessage, + }) + }, + Setup: func(t *testing.T, req *http.Request) { + temp := float32(0.8) + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Temperature: &temp, + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var completionResp Completion + if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil { + t.Fatal(err) + } + + if completionResp.Object != "text_completion" { + 
t.Fatalf("expected text_completion, got %s", completionResp.Object) + } + + if completionResp.Choices[0].Text != "Received temperature of 1.6" { + t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text) + } + }, + }, + { + Name: "completions handler with error", + Method: http.MethodPost, + Path: "/api/generate", + TestPath: "/api/generate", + Handler: CompletionsMiddleware, + Endpoint: func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) + }, + Setup: func(t *testing.T, req *http.Request) { + body := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.Code) + } + + if !strings.Contains(resp.Body.String(), `"invalid request"`) { + t.Fatalf("error was not forwarded") + } + }, + }, { Name: "list handler", Method: http.MethodGet, @@ -99,6 +227,8 @@ func TestMiddleware(t *testing.T) { }) }, Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + var listResp ListCompletion if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { t.Fatal(err) @@ -162,8 +292,6 @@ func TestMiddleware(t *testing.T) { resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusOK, resp.Code) - tc.Expected(t, resp) }) } diff --git a/server/routes.go b/server/routes.go index 9fe5fcc4..41c92084 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1054,6 +1054,7 @@ func (s *Server) GenerateRoutes() http.Handler { // Compatibility endpoints r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) + r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.GET("/v1/models", 
openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)