Merge pull request #843 from jmorganca/mxyng/request-validation
basic request validation
This commit is contained in:
commit
0a53da03fd
|
@ -3,10 +3,10 @@ package main
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -252,7 +252,7 @@ func filenameWithPath(path, f string) (string, error) {
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error {
|
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
mp := ParseModelPath(name)
|
||||||
|
|
||||||
var manifest *ManifestV2
|
var manifest *ManifestV2
|
||||||
|
|
125
server/routes.go
125
server/routes.go
|
@ -137,8 +137,18 @@ func GenerateHandler(c *gin.Context) {
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Model == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,6 +187,12 @@ func GenerateHandler(c *gin.Context) {
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
// an empty request loads the model
|
||||||
|
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||||
|
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
fn := func(r api.GenerateResponse) {
|
fn := func(r api.GenerateResponse) {
|
||||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||||
loaded.expireTimer.Reset(sessionDuration)
|
loaded.expireTimer.Reset(sessionDuration)
|
||||||
|
@ -191,14 +207,9 @@ func GenerateHandler(c *gin.Context) {
|
||||||
ch <- r
|
ch <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
// an empty request loads the model
|
|
||||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
|
||||||
ch <- api.GenerateResponse{Model: req.Model, Done: true}
|
|
||||||
} else {
|
|
||||||
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
|
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if req.Stream != nil && !*req.Stream {
|
if req.Stream != nil && !*req.Stream {
|
||||||
|
@ -226,8 +237,18 @@ func EmbeddingHandler(c *gin.Context) {
|
||||||
defer loaded.mu.Unlock()
|
defer loaded.mu.Unlock()
|
||||||
|
|
||||||
var req api.EmbeddingRequest
|
var req api.EmbeddingRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Model == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,8 +284,18 @@ func EmbeddingHandler(c *gin.Context) {
|
||||||
|
|
||||||
func PullModelHandler(c *gin.Context) {
|
func PullModelHandler(c *gin.Context) {
|
||||||
var req api.PullRequest
|
var req api.PullRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -297,8 +328,18 @@ func PullModelHandler(c *gin.Context) {
|
||||||
|
|
||||||
func PushModelHandler(c *gin.Context) {
|
func PushModelHandler(c *gin.Context) {
|
||||||
var req api.PushRequest
|
var req api.PushRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -329,12 +370,20 @@ func PushModelHandler(c *gin.Context) {
|
||||||
|
|
||||||
func CreateModelHandler(c *gin.Context) {
|
func CreateModelHandler(c *gin.Context) {
|
||||||
var req api.CreateRequest
|
var req api.CreateRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
workDir := c.GetString("workDir")
|
if req.Name == "" || req.Path == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -346,7 +395,7 @@ func CreateModelHandler(c *gin.Context) {
|
||||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil {
|
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -361,8 +410,18 @@ func CreateModelHandler(c *gin.Context) {
|
||||||
|
|
||||||
func DeleteModelHandler(c *gin.Context) {
|
func DeleteModelHandler(c *gin.Context) {
|
||||||
var req api.DeleteRequest
|
var req api.DeleteRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -391,8 +450,18 @@ func DeleteModelHandler(c *gin.Context) {
|
||||||
|
|
||||||
func ShowModelHandler(c *gin.Context) {
|
func ShowModelHandler(c *gin.Context) {
|
||||||
var req api.ShowRequest
|
var req api.ShowRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -502,8 +571,18 @@ func ListModelsHandler(c *gin.Context) {
|
||||||
|
|
||||||
func CopyModelHandler(c *gin.Context) {
|
func CopyModelHandler(c *gin.Context) {
|
||||||
var req api.CopyRequest
|
var req api.CopyRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
err := c.ShouldBindJSON(&req)
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Source == "" || req.Destination == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue