From 8bbff2df986629e5481547e913ab4de0245afb37 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 28 Aug 2023 20:50:24 -0700 Subject: [PATCH] add model IDs (#439) --- api/types.go | 1 + cmd/cmd.go | 6 +++--- server/images.go | 36 +++++++++++++++++++++--------------- server/routes.go | 3 ++- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/api/types.go b/api/types.go index 7ed102bf..b42a3626 100644 --- a/api/types.go +++ b/api/types.go @@ -96,6 +96,7 @@ type ListResponseModel struct { Name string `json:"name"` ModifiedAt time.Time `json:"modified_at"` Size int `json:"size"` + Digest string `json:"digest"` } type TokenResponse struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index 5ce83326..5123aa5b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -196,12 +196,12 @@ func ListHandler(cmd *cobra.Command, args []string) error { for _, m := range models.Models { if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) { - data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")}) + data = append(data, []string{m.Name, m.Digest[:12], humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")}) } } table := tablewriter.NewWriter(os.Stdout) - table.SetHeader([]string{"NAME", "SIZE", "MODIFIED"}) + table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"}) table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT) table.SetHeaderLine(false) @@ -527,7 +527,7 @@ func generateInteractive(cmd *cobra.Command, model string) error { return err } - manifest, err := server.GetManifest(mp) + manifest, _, err := server.GetManifest(mp) if err != nil { fmt.Println("error: couldn't get a manifest for this model") continue diff --git a/server/images.go b/server/images.go index f847e09e..5598d6c8 100644 --- a/server/images.go +++ b/server/images.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -44,6 +45,7 @@ type Model struct { Template string System string Digest string + ConfigDigest string Options map[string]interface{} Embeddings []vector.Embedding } @@ -131,41 +133,45 @@ func (m *ManifestV2) GetTotalSize() int { return total } -func GetManifest(mp ModelPath) (*ManifestV2, error) { +func GetManifest(mp ModelPath) (*ManifestV2, string, error) { fp, err := mp.GetManifestPath(false) if err != nil { - return nil, err + return nil, "", err } if _, err = os.Stat(fp); err != nil { - return nil, err + return nil, "", err } var manifest *ManifestV2 bts, err := os.ReadFile(fp) if err != nil { - return nil, fmt.Errorf("couldn't open file '%s'", fp) + return nil, "", fmt.Errorf("couldn't open file '%s'", fp) } + shaSum := sha256.Sum256(bts) + shaStr := hex.EncodeToString(shaSum[:]) + if err := json.Unmarshal(bts, &manifest); err != nil { - return nil, err + return nil, "", err } - return manifest, nil + return manifest, shaStr, nil } func GetModel(name string) (*Model, error) { mp := ParseModelPath(name) - manifest, err := GetManifest(mp) + manifest, digest, err := GetManifest(mp) if err != nil { return nil, err } model := &Model{ - Name: mp.GetFullTagname(), - Digest: manifest.Config.Digest, - Template: "{{ .Prompt }}", + Name: mp.GetFullTagname(), + Digest: digest, + ConfigDigest: manifest.Config.Digest, + Template: "{{ .Prompt }}", } for _, layer := range manifest.Layers { @@ -277,7 +283,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api embed.model = c.Args mp := ParseModelPath(c.Args) - mf, err := GetManifest(mp) + mf, _, err := GetManifest(mp) if err != nil { modelFile, err := filenameWithPath(path, c.Args) if err != nil { @@ -290,7 +296,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { return err } - mf, err = GetManifest(mp) + mf, _, err = GetManifest(mp) if err != nil { return fmt.Errorf("failed to open file after pull: %v", err) } @@ -839,7 +845,7 @@ func CopyModel(src, dest string) error { func DeleteModel(name string) error { mp := ParseModelPath(name) - manifest, err := GetManifest(mp) + manifest, _, err := GetManifest(mp) if err != nil { return err } @@ -872,7 +878,7 @@ func DeleteModel(name string) error { } // save (i.e. delete from the deleteMap) any files used in other manifests - manifest, err := GetManifest(fmp) + manifest, _, err := GetManifest(fmp) if err != nil { log.Printf("skipping file: %s", fp) return nil @@ -924,7 +930,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return fmt.Errorf("insecure protocol http") } - manifest, err := GetManifest(mp) + manifest, _, err := GetManifest(mp) if err != nil { fn(api.ProgressResponse{Status: "couldn't retrieve manifest"}) return err diff --git a/server/routes.go b/server/routes.go index dbe0a6ad..dd1846f9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -373,7 +373,7 @@ func ListModelsHandler(c *gin.Context) { tag := path[:slashIndex] + ":" + path[slashIndex+1:] mp := ParseModelPath(tag) - manifest, err := GetManifest(mp) + manifest, digest, err := GetManifest(mp) if err != nil { log.Printf("skipping file: %s", fp) return nil @@ -381,6 +381,7 @@ func ListModelsHandler(c *gin.Context) { model := api.ListResponseModel{ Name: mp.GetShortTagname(), Size: manifest.GetTotalSize(), + Digest: digest, ModifiedAt: fi.ModTime(), } models = append(models, model)