Merge pull request #843 from jmorganca/mxyng/request-validation

basic request validation
This commit is contained in:
Michael Yang 2023-10-19 09:30:45 -07:00 committed by GitHub
commit 0a53da03fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 109 additions and 30 deletions

View file

@ -3,10 +3,10 @@ package main
import (
"bytes"
"fmt"
"net/http"
"os"
"io"
"log"
"net/http"
"os"
)
func main() {

View file

@ -29,7 +29,7 @@ func TestHumanTime(t *testing.T) {
})
t.Run("soon", func(t *testing.T) {
v := now.Add(800*time.Millisecond)
v := now.Add(800 * time.Millisecond)
assertEqual(t, HumanTime(v, ""), "Less than a second from now")
})
}

View file

@ -252,7 +252,7 @@ func filenameWithPath(path, f string) (string, error) {
return f, nil
}
func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error {
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2

View file

@ -137,8 +137,18 @@ func GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
@ -177,6 +187,12 @@ func GenerateHandler(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
return
}
fn := func(r api.GenerateResponse) {
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
@ -191,14 +207,9 @@ func GenerateHandler(c *gin.Context) {
ch <- r
}
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{Model: req.Model, Done: true}
} else {
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}
}()
if req.Stream != nil && !*req.Stream {
@ -226,8 +237,18 @@ func EmbeddingHandler(c *gin.Context) {
defer loaded.mu.Unlock()
var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
@ -263,8 +284,18 @@ func EmbeddingHandler(c *gin.Context) {
func PullModelHandler(c *gin.Context) {
var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
@ -297,8 +328,18 @@ func PullModelHandler(c *gin.Context) {
func PushModelHandler(c *gin.Context) {
var req api.PushRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
@ -329,12 +370,20 @@ func PushModelHandler(c *gin.Context) {
func CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
workDir := c.GetString("workDir")
if req.Name == "" || req.Path == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"})
return
}
ch := make(chan any)
go func() {
@ -346,7 +395,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil {
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@ -361,8 +410,18 @@ func CreateModelHandler(c *gin.Context) {
func DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
@ -391,8 +450,18 @@ func DeleteModelHandler(c *gin.Context) {
func ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
@ -502,8 +571,18 @@ func ListModelsHandler(c *gin.Context) {
func CopyModelHandler(c *gin.Context) {
var req api.CopyRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Source == "" || req.Destination == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
return
}