From 790d24eb7bd1b15192a8acd79b60e225aaa6688e Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Wed, 6 Sep 2023 11:04:17 -0700 Subject: [PATCH] add show command (#474) --- api/client.go | 8 +++ api/types.go | 12 ++++ cmd/cmd.go | 149 ++++++++++++++++++++++++++++++++++------------- server/images.go | 107 +++++++++++++++++++++++++++++++--- server/routes.go | 73 +++++++++++++++++++++++ 5 files changed, 299 insertions(+), 50 deletions(-) diff --git a/api/client.go b/api/client.go index dc69f689..87975a9f 100644 --- a/api/client.go +++ b/api/client.go @@ -255,6 +255,14 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error { return nil } +func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) { + var resp ShowResponse + if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + func (c *Client) Heartbeat(ctx context.Context) error { if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil { return err diff --git a/api/types.go b/api/types.go index 8c97a792..edd333d2 100644 --- a/api/types.go +++ b/api/types.go @@ -61,6 +61,18 @@ type DeleteRequest struct { Name string `json:"name"` } +type ShowRequest struct { + Name string `json:"name"` +} + +type ShowResponse struct { + License string `json:"license,omitempty"` + Modelfile string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + System string `json:"system,omitempty"` +} + type CopyRequest struct { Source string `json:"source"` Destination string `json:"destination"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 5123aa5b..ab5257b2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -230,6 +230,84 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { return nil } +func ShowHandler(cmd *cobra.Command, args []string) error { + client, err := api.FromEnv() + if err != nil { + return err + } + + if len(args) != 1 { + return errors.New("missing model name") + } + + license, errLicense := cmd.Flags().GetBool("license") + modelfile, errModelfile := cmd.Flags().GetBool("modelfile") + parameters, errParams := cmd.Flags().GetBool("parameters") + system, errSystem := cmd.Flags().GetBool("system") + template, errTemplate := cmd.Flags().GetBool("template") + + for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} { + if boolErr != nil { + return errors.New("error retrieving flags") + } + } + + flagsSet := 0 + showType := "" + + if license { + flagsSet++ + showType = "license" + } + + if modelfile { + flagsSet++ + showType = "modelfile" + } + + if parameters { + flagsSet++ + showType = "parameters" + } + + if system { + flagsSet++ + showType = "system" + } + + if template { + flagsSet++ + showType = "template" + } + + if flagsSet > 1 { + return errors.New("only one of 'license', 'modelfile', 'parameters', 'system', or 'template' can be set") + } else if flagsSet == 0 { + return errors.New("one of 'license', 'modelfile', 'parameters', 'system', or 'template' must be set") + } + + req := api.ShowRequest{Name: args[0]} + resp, err := client.Show(context.Background(), &req) + if err != nil { + return err + } + + switch showType { + case "license": + fmt.Println(resp.License) + case "modelfile": + fmt.Println(resp.Modelfile) + case "parameters": + fmt.Println(resp.Parameters) + case "system": + fmt.Println(resp.System) + case "template": + fmt.Println(resp.Template) + } + + return nil +} + func CopyHandler(cmd *cobra.Command, args []string) error { client, err := api.FromEnv() if err != nil { @@ -377,20 +455,6 @@ func generate(cmd *cobra.Command, model, prompt string) error { return nil } -func showLayer(l *server.Layer) { - filename, err := server.GetBlobsPath(l.Digest) - if err != nil { - fmt.Println("Couldn't get layer's path") - return - } - bts, err := os.ReadFile(filename) - if err != nil { - fmt.Println("Couldn't read layer") - return - } - fmt.Println(string(bts)) -} - func generateInteractive(cmd *cobra.Command, model string) error { home, err := os.UserHomeDir() if err != nil { @@ -413,6 +477,8 @@ func generateInteractive(cmd *cobra.Command, model string) error { ), readline.PcItem("/show", readline.PcItem("license"), + readline.PcItem("modelfile"), + readline.PcItem("parameters"), readline.PcItem("system"), readline.PcItem("template"), ), @@ -522,42 +588,28 @@ func generateInteractive(cmd *cobra.Command, model string) error { case strings.HasPrefix(line, "/show"): args := strings.Fields(line) if len(args) > 1 { - mp := server.ParseModelPath(model) + resp, err := server.GetModelInfo(model) if err != nil { - return err + fmt.Println("error: couldn't get model") + continue } - manifest, _, err := server.GetManifest(mp) - if err != nil { - fmt.Println("error: couldn't get a manifest for this model") - continue - } switch args[1] { case "license": - for _, l := range manifest.Layers { - if l.MediaType == "application/vnd.ollama.image.license" { - showLayer(l) - } - } - continue + fmt.Println(resp.License) + case "modelfile": + fmt.Println(resp.Modelfile) + case "parameters": + fmt.Println(resp.Parameters) case "system": - for _, l := range manifest.Layers { - if l.MediaType == "application/vnd.ollama.image.system" { - showLayer(l) - } - } - continue + fmt.Println(resp.System) case "template": - for _, l := range manifest.Layers { - if l.MediaType == "application/vnd.ollama.image.template" { - showLayer(l) - } - } - continue + fmt.Println(resp.Template) default: - usage() - continue + fmt.Println("error: unknown command") } + + continue } else { usage() continue @@ -749,6 +801,20 @@ func NewCLI() *cobra.Command { createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")") + showCmd := &cobra.Command{ + Use: "show MODEL", + Short: "Show information for a model", + Args: cobra.MinimumNArgs(1), + PreRunE: checkServerHeartbeat, + RunE: ShowHandler, + } + + showCmd.Flags().Bool("license", false, "Show license of a model") + showCmd.Flags().Bool("modelfile", false, "Show Modelfile of a model") + showCmd.Flags().Bool("parameters", false, "Show parameters of a model") + showCmd.Flags().Bool("template", false, "Show template of a model") + showCmd.Flags().Bool("system", false, "Show system prompt of a model") + runCmd := &cobra.Command{ Use: "run MODEL [PROMPT]", Short: "Run a model", @@ -814,6 +880,7 @@ func NewCLI() *cobra.Command { rootCmd.AddCommand( serveCmd, createCmd, + showCmd, runCmd, pullCmd, pushCmd, diff --git a/server/images.go b/server/images.go index cc284510..1356c9e9 100644 --- a/server/images.go +++ b/server/images.go @@ -41,15 +41,18 @@ type RegistryOptions struct { } type Model struct { - Name string `json:"name"` - ModelPath string - AdapterPaths []string - Template string - System string - Digest string - ConfigDigest string - Options map[string]interface{} - Embeddings []vector.Embedding + Name string `json:"name"` + ShortName string + ModelPath string + OriginalModel string + AdapterPaths []string + Template string + System string + License []string + Digest string + ConfigDigest string + Options map[string]interface{} + Embeddings []vector.Embedding } func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) { @@ -171,9 +174,11 @@ func GetModel(name string) (*Model, error) { model := &Model{ Name: mp.GetFullTagname(), + ShortName: mp.GetShortTagname(), Digest: digest, ConfigDigest: manifest.Config.Digest, Template: "{{ .Prompt }}", + License: []string{}, } for _, layer := range manifest.Layers { @@ -185,6 +190,7 @@ func GetModel(name string) (*Model, error) { switch layer.MediaType { case "application/vnd.ollama.image.model": model.ModelPath = filename + model.OriginalModel = layer.From case "application/vnd.ollama.image.embed": file, err := os.Open(filename) if err != nil { @@ -229,6 +235,12 @@ func GetModel(name string) (*Model, error) { if err = json.NewDecoder(params).Decode(&model.Options); err != nil { return nil, err } + case "application/vnd.ollama.image.license": + bts, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + model.License = append(model.License, string(bts)) } } @@ -933,6 +945,83 @@ func DeleteModel(name string) error { return nil } +func ShowModelfile(model *Model) (string, error) { + type modelTemplate struct { + *Model + From string + Params string + } + + var params []string + for k, v := range model.Options { + switch val := v.(type) { + case string: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, val)) + case int: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(val))) + case float64: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(val, 'f', 0, 64))) + case bool: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(val))) + case []interface{}: + for _, nv := range val { + switch nval := nv.(type) { + case string: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, nval)) + case int: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(nval))) + case float64: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(nval, 'f', 0, 64))) + case bool: + params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(nval))) + default: + log.Printf("unknown type: %s", reflect.TypeOf(nv).String()) + } + } + default: + log.Printf("unknown type: %s", reflect.TypeOf(v).String()) + } + } + + mt := modelTemplate{ + Model: model, + From: model.OriginalModel, + Params: strings.Join(params, "\n"), + } + + if mt.From == "" { + mt.From = model.ModelPath + } + + modelFile := `# Modelfile generated by "ollama show" +# To build a new Modelfile based on this one, replace the FROM line with: +# FROM {{ .ShortName }} + +FROM {{ .From }} +TEMPLATE """{{ .Template }}""" +SYSTEM """{{ .System }}""" +{{ .Params }} +` + for _, l := range mt.Model.AdapterPaths { + modelFile += fmt.Sprintf("ADAPTER %s\n", l) + } + + tmpl, err := template.New("").Parse(modelFile) + if err != nil { + log.Printf("error parsing template: %q", err) + return "", err + } + + var buf bytes.Buffer + + if err = tmpl.Execute(&buf, mt); err != nil { + log.Printf("error executing template: %q", err) + return "", err + } + + return buf.String(), nil +} + func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) diff --git a/server/routes.go b/server/routes.go index 1a049cbd..44871bba 100644 --- a/server/routes.go +++ b/server/routes.go @@ -12,6 +12,7 @@ import ( "os/signal" "path/filepath" "reflect" + "strconv" "strings" "sync" "syscall" @@ -364,6 +365,77 @@ func DeleteModelHandler(c *gin.Context) { } } +func ShowModelHandler(c *gin.Context) { + var req api.ShowRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + resp, err := GetModelInfo(req.Name) + if err != nil { + if os.IsNotExist(err) { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)}) + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + return + } + + c.JSON(http.StatusOK, resp) +} + +func GetModelInfo(name string) (*api.ShowResponse, error) { + model, err := GetModel(name) + if err != nil { + return nil, err + } + + resp := &api.ShowResponse{ + License: strings.Join(model.License, "\n"), + System: model.System, + Template: model.Template, + } + + mf, err := ShowModelfile(model) + if err != nil { + return nil, err + } + + resp.Modelfile = mf + + var params []string + cs := 30 + for k, v := range model.Options { + switch val := v.(type) { + case string: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, val)) + case int: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val))) + case float64: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64))) + case bool: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val))) + case []interface{}: + for _, nv := range val { + switch nval := nv.(type) { + case string: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval)) + case int: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval))) + case float64: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64))) + case bool: + params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval))) + } + } + } + } + resp.Parameters = strings.Join(params, "\n") + + return resp, nil +} + func ListModelsHandler(c *gin.Context) { var models []api.ModelResponse fp, err := GetManifestPath() @@ -457,6 +529,7 @@ func Serve(ln net.Listener, origins []string) error { r.POST("/api/copy", CopyModelHandler) r.GET("/api/tags", ListModelsHandler) r.DELETE("/api/delete", DeleteModelHandler) + r.POST("/api/show", ShowModelHandler) log.Printf("Listening on %s", ln.Addr()) s := &http.Server{