From c2714fcbfd600c2a13efbc42bab95b49b0b4fa33 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 6 May 2024 16:34:13 -0700 Subject: [PATCH] routes: use Manifests for ListHandler --- server/manifest.go | 11 ++++- server/manifest_test.go | 90 +++++++++++++++++++++++++++++++++++++++++ server/routes.go | 84 +++++++++++++------------------------- 3 files changed, 127 insertions(+), 58 deletions(-) create mode 100644 server/manifest_test.go diff --git a/server/manifest.go b/server/manifest.go index 36ed5b4c..131d4918 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "os" "path/filepath" @@ -16,6 +17,7 @@ type Manifest struct { ManifestV2 filepath string + fi os.FileInfo digest string } @@ -65,6 +67,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { } defer f.Close() + fi, err := f.Stat() + if err != nil { + return nil, err + } + sha256sum := sha256.New() if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil { return nil, err @@ -73,6 +80,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { return &Manifest{ ManifestV2: m, filepath: p, + fi: fi, digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), }, nil } @@ -126,7 +134,8 @@ func Manifests() (map[model.Name]*Manifest, error) { if n.IsValid() { m, err := ParseNamedManifest(n) if err != nil { - return nil, err + slog.Warn("bad manifest", "name", n, "error", err) + continue } ms[n] = m diff --git a/server/manifest_test.go b/server/manifest_test.go new file mode 100644 index 00000000..35c6bc8d --- /dev/null +++ b/server/manifest_test.go @@ -0,0 +1,90 @@ +package server + +import ( + "encoding/json" + "os" + "path/filepath" + "slices" + "testing" + + "github.com/ollama/ollama/types/model" +) + +func createManifest(t *testing.T, path, name string) { + t.Helper() + + p := filepath.Join(path, "manifests", name) + if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil { + t.Fatal(err) + } + + f, err := os.Create(p) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil { + t.Fatal(err) + } +} + +func TestManifests(t *testing.T) { + cases := map[string][]string{ + "empty": {}, + "single": { + filepath.Join("host", "namespace", "model", "tag"), + }, + "multiple": { + filepath.Join("registry.ollama.ai", "library", "llama3", "latest"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"), + }, + "hidden": { + filepath.Join("host", "namespace", "model", "tag"), + filepath.Join("host", "namespace", "model", ".hidden"), + }, + } + + for n, wants := range cases { + t.Run(n, func(t *testing.T) { + d := t.TempDir() + t.Setenv("OLLAMA_MODELS", d) + + for _, want := range wants { + createManifest(t, d, want) + } + + ms, err := Manifests() + if err != nil { + t.Fatal(err) + } + + var ns []model.Name + for k := range ms { + ns = append(ns, k) + } + + for _, want := range wants { + n := model.ParseNameFromFilepath(want) + if !n.IsValid() && slices.Contains(ns, n) { + t.Errorf("unexpected invalid name: %s", want) + } else if n.IsValid() && !slices.Contains(ns, n) { + t.Errorf("missing valid name: %s", want) + } + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index ff888e3c..14853feb 100644 --- a/server/routes.go +++ b/server/routes.go @@ -702,72 +702,42 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } func (s *Server) ListModelsHandler(c *gin.Context) { - manifests, err := GetManifestPath() + ms, err := Manifests() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } models := []api.ModelResponse{} - if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error { - if !info.IsDir() { - rel, err := filepath.Rel(manifests, path) - if err != nil { - return err - } + for n, m := range ms { + f, err := m.Config.Open() + if err != nil { + slog.Warn("bad manifest filepath", "name", n, "error", err) + continue + } + defer f.Close() - if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil { - return err - } else if hidden { - return nil - } - - n := model.ParseNameFromFilepath(rel) - if !n.IsValid() { - slog.Warn("bad manifest filepath", "path", rel) - return nil - } - - m, err := ParseNamedManifest(n) - if err != nil { - slog.Warn("bad manifest", "name", n, "error", err) - return nil - } - - f, err := m.Config.Open() - if err != nil { - slog.Warn("bad manifest config filepath", "name", n, "error", err) - return nil - } - defer f.Close() - - var c ConfigV2 - if err := json.NewDecoder(f).Decode(&c); err != nil { - slog.Warn("bad manifest config", "name", n, "error", err) - return nil - } - - // tag should never be masked - models = append(models, api.ModelResponse{ - Model: n.DisplayShortest(), - Name: n.DisplayShortest(), - Size: m.Size(), - Digest: m.digest, - ModifiedAt: info.ModTime(), - Details: api.ModelDetails{ - Format: c.ModelFormat, - Family: c.ModelFamily, - Families: c.ModelFamilies, - ParameterSize: c.ModelType, - QuantizationLevel: c.FileType, - }, - }) + var cf ConfigV2 + if err := json.NewDecoder(f).Decode(&cf); err != nil { + slog.Warn("bad manifest config", "name", n, "error", err) + continue } - return nil - }); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + // tag should never be masked + models = append(models, api.ModelResponse{ + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + Size: m.Size(), + Digest: m.digest, + ModifiedAt: m.fi.ModTime(), + Details: api.ModelDetails{ + Format: cf.ModelFormat, + Family: cf.ModelFamily, + Families: cf.ModelFamilies, + ParameterSize: cf.ModelType, + QuantizationLevel: cf.FileType, + }, + }) } slices.SortStableFunc(models, func(i, j api.ModelResponse) int {