diff --git a/examples/golang-simplegenerate/main.go b/examples/golang-simplegenerate/main.go index 9b60acef..26e3bc6d 100644 --- a/examples/golang-simplegenerate/main.go +++ b/examples/golang-simplegenerate/main.go @@ -3,10 +3,10 @@ package main import ( "bytes" "fmt" - "net/http" - "os" "io" "log" + "net/http" + "os" ) func main() { @@ -16,7 +16,7 @@ func main() { if err != nil { fmt.Print(err.Error()) os.Exit(1) - } + } responseData, err := io.ReadAll(resp.Body) if err != nil { diff --git a/format/time_test.go b/format/time_test.go index b4e2db53..cc6b8930 100644 --- a/format/time_test.go +++ b/format/time_test.go @@ -29,7 +29,7 @@ func TestHumanTime(t *testing.T) { }) t.Run("soon", func(t *testing.T) { - v := now.Add(800*time.Millisecond) + v := now.Add(800 * time.Millisecond) assertEqual(t, HumanTime(v, ""), "Less than a second from now") }) } diff --git a/server/images.go b/server/images.go index 5514c643..88d7a206 100644 --- a/server/images.go +++ b/server/images.go @@ -252,7 +252,7 @@ func filenameWithPath(path, f string) (string, error) { return f, nil } -func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error { +func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { mp := ParseModelPath(name) var manifest *ManifestV2 diff --git a/server/routes.go b/server/routes.go index 3cee381e..3754e9a8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -137,8 +137,18 @@ func GenerateHandler(c *gin.Context) { checkpointStart := time.Now() var req api.GenerateRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Model == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } @@ -177,6 +187,12 @@ func GenerateHandler(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) + // an empty request loads the model + if req.Prompt == "" && req.Template == "" && req.System == "" { + ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true} + return + } + fn := func(r api.GenerateResponse) { loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireTimer.Reset(sessionDuration) @@ -191,13 +207,8 @@ func GenerateHandler(c *gin.Context) { ch <- r } - // an empty request loads the model - if req.Prompt == "" && req.Template == "" && req.System == "" { - ch <- api.GenerateResponse{Model: req.Model, Done: true} - } else { - if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil { - ch <- gin.H{"error": err.Error()} - } + if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil { + ch <- gin.H{"error": err.Error()} } }() @@ -226,8 +237,18 @@ func EmbeddingHandler(c *gin.Context) { defer loaded.mu.Unlock() var req api.EmbeddingRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Model == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } @@ -263,8 +284,18 @@ func EmbeddingHandler(c *gin.Context) { func PullModelHandler(c *gin.Context) { var req api.PullRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Name == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) return } @@ -297,8 +328,18 @@ func PullModelHandler(c *gin.Context) { func PushModelHandler(c *gin.Context) { var req api.PushRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Name == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) return } @@ -329,12 +370,20 @@ func PushModelHandler(c *gin.Context) { func CreateModelHandler(c *gin.Context) { var req api.CreateRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - workDir := c.GetString("workDir") + if req.Name == "" || req.Path == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"}) + return + } ch := make(chan any) go func() { @@ -346,7 +395,7 @@ func CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil { + if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -361,8 +410,18 @@ func CreateModelHandler(c *gin.Context) { func DeleteModelHandler(c *gin.Context) { var req api.DeleteRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Name == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) return } @@ -391,8 +450,18 @@ func DeleteModelHandler(c *gin.Context) { func ShowModelHandler(c *gin.Context) { var req api.ShowRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Name == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) return } @@ -502,8 +571,18 @@ func ListModelsHandler(c *gin.Context) { func CopyModelHandler(c *gin.Context) { var req api.CopyRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Source == "" || req.Destination == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"}) return }