api: add model for all requests

prefer using req.Model and fallback to req.Name
This commit is contained in:
Michael Yang 2024-01-11 14:07:54 -08:00
parent abec7f06e5
commit a38d88d828
2 changed files with 59 additions and 26 deletions

View file

@ -137,23 +137,31 @@ type EmbeddingResponse struct {
} }
type CreateRequest struct { type CreateRequest struct {
Name string `json:"name"` Model string `json:"model"`
Path string `json:"path"` Path string `json:"path"`
Modelfile string `json:"modelfile"` Modelfile string `json:"modelfile"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
} }
type DeleteRequest struct { type DeleteRequest struct {
Model string `json:"model"`
// Name is deprecated, see Model
Name string `json:"name"` Name string `json:"name"`
} }
type ShowRequest struct { type ShowRequest struct {
Name string `json:"name"`
Model string `json:"model"` Model string `json:"model"`
System string `json:"system"` System string `json:"system"`
Template string `json:"template"` Template string `json:"template"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
// Name is deprecated, see Model
Name string `json:"name"`
} }
type ShowResponse struct { type ShowResponse struct {
@ -171,11 +179,14 @@ type CopyRequest struct {
} }
type PullRequest struct { type PullRequest struct {
Name string `json:"name"` Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
} }
type ProgressResponse struct { type ProgressResponse struct {
@ -186,11 +197,14 @@ type ProgressResponse struct {
} }
type PushRequest struct { type PushRequest struct {
Name string `json:"name"` Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
} }
type ListResponse struct { type ListResponse struct {

View file

@ -414,8 +414,13 @@ func PullModelHandler(c *gin.Context) {
return return
} }
if req.Name == "" { var model string
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
@ -433,7 +438,7 @@ func PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PullModel(ctx, req.Name, regOpts, fn); err != nil { if err := PullModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -458,8 +463,13 @@ func PushModelHandler(c *gin.Context) {
return return
} }
if req.Name == "" { var model string
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
@ -477,7 +487,7 @@ func PushModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil { if err := PushModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -502,12 +512,17 @@ func CreateModelHandler(c *gin.Context) {
return return
} }
if req.Name == "" { var model string
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
if err := ParseModelPath(req.Name).Validate(); err != nil { if err := ParseModelPath(model).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@ -545,7 +560,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, req.Name, filepath.Dir(req.Path), commands, fn); err != nil { if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@ -570,14 +585,19 @@ func DeleteModelHandler(c *gin.Context) {
return return
} }
if req.Name == "" { var model string
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
if err := DeleteModel(req.Name); err != nil { if err := DeleteModel(model); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else { } else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }
@ -610,21 +630,20 @@ func ShowModelHandler(c *gin.Context) {
return return
} }
switch { var model string
case req.Model == "" && req.Name == "": if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
case req.Model != "" && req.Name != "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "both model and name are set"})
return
case req.Model == "" && req.Name != "":
req.Model = req.Name
} }
resp, err := GetModelInfo(req) resp, err := GetModelInfo(req)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else { } else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }