Merge pull request #843 from jmorganca/mxyng/request-validation

basic request validation
This commit is contained in:
Michael Yang 2023-10-19 09:30:45 -07:00 committed by GitHub
commit 0a53da03fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 109 additions and 30 deletions

View file

@ -3,10 +3,10 @@ package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net/http"
"os"
"io" "io"
"log" "log"
"net/http"
"os"
) )
func main() { func main() {

View file

@ -252,7 +252,7 @@ func filenameWithPath(path, f string) (string, error) {
return f, nil return f, nil
} }
func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error { func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
var manifest *ManifestV2 var manifest *ManifestV2

View file

@ -137,8 +137,18 @@ func GenerateHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
@ -177,6 +187,12 @@ func GenerateHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
return
}
fn := func(r api.GenerateResponse) { fn := func(r api.GenerateResponse) {
loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration) loaded.expireTimer.Reset(sessionDuration)
@ -191,14 +207,9 @@ func GenerateHandler(c *gin.Context) {
ch <- r ch <- r
} }
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{Model: req.Model, Done: true}
} else {
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}
}() }()
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
@ -226,8 +237,18 @@ func EmbeddingHandler(c *gin.Context) {
defer loaded.mu.Unlock() defer loaded.mu.Unlock()
var req api.EmbeddingRequest var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
@ -263,8 +284,18 @@ func EmbeddingHandler(c *gin.Context) {
func PullModelHandler(c *gin.Context) { func PullModelHandler(c *gin.Context) {
var req api.PullRequest var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return return
} }
@ -297,8 +328,18 @@ func PullModelHandler(c *gin.Context) {
func PushModelHandler(c *gin.Context) { func PushModelHandler(c *gin.Context) {
var req api.PushRequest var req api.PushRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return return
} }
@ -329,12 +370,20 @@ func PushModelHandler(c *gin.Context) {
func CreateModelHandler(c *gin.Context) { func CreateModelHandler(c *gin.Context) {
var req api.CreateRequest var req api.CreateRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
workDir := c.GetString("workDir") if req.Name == "" || req.Path == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"})
return
}
ch := make(chan any) ch := make(chan any)
go func() { go func() {
@ -346,7 +395,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil { if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -361,8 +410,18 @@ func CreateModelHandler(c *gin.Context) {
func DeleteModelHandler(c *gin.Context) { func DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest var req api.DeleteRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return return
} }
@ -391,8 +450,18 @@ func DeleteModelHandler(c *gin.Context) {
func ShowModelHandler(c *gin.Context) { func ShowModelHandler(c *gin.Context) {
var req api.ShowRequest var req api.ShowRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return return
} }
@ -502,8 +571,18 @@ func ListModelsHandler(c *gin.Context) {
func CopyModelHandler(c *gin.Context) { func CopyModelHandler(c *gin.Context) {
var req api.CopyRequest var req api.CopyRequest
if err := c.ShouldBindJSON(&req); err != nil { err := c.ShouldBindJSON(&req)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Source == "" || req.Destination == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
return return
} }