diff --git a/types/model/name.go b/types/model/name.go index 906b3152..1a820c4f 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -7,6 +7,7 @@ import ( "hash/maphash" "io" "log/slog" + "path/filepath" "slices" "strings" "sync" @@ -54,6 +55,10 @@ const ( PartBuild PartDigest + // NumParts is the number of parts in a Name. In this list, it must + // follow the final part. + NumParts + PartExtraneous = -1 ) @@ -100,7 +105,7 @@ func (k PartKind) String() string { // To check if a Name has at minimum a valid model part, use [Name.IsValid]. type Name struct { _ structs.Incomparable - parts [6]string // host, namespace, model, tag, build, digest + parts [NumParts]string // host, namespace, model, tag, build, digest // TODO(bmizerany): track offsets and hold s (raw string) here? We // could pack the offsets all into a single uint64 since the first @@ -159,7 +164,6 @@ func ParseName(s, fill string) Name { return true }) if r.IsValid() || r.IsResolved() { - fill = cmp.Or(fill, FillDefault) return fillName(r, fill) } return Name{} @@ -195,7 +199,11 @@ func MustParseName(s, defaults string) Name { // // It skipps fill parts that are "?". func fillName(r Name, fill string) Name { + fill = cmp.Or(fill, FillDefault) f := parseMask(fill) + if fill != FillNothing && f.IsZero() { + panic("invalid fill") + } for i := range r.parts { if f.parts[i] == "?" { continue @@ -226,7 +234,7 @@ func (r Name) MapHash() uint64 { // correctly hash the parts with case insensitive comparison var h maphash.Hash h.SetSeed(mapHashSeed) - for _, part := range r.Parts() { + for _, part := range r.parts { // downcase the part for hashing for i := range part { c := part[i] @@ -267,11 +275,8 @@ func (r Name) slice(from, to PartKind) Name { // // If mask is the empty string, then [MaskDefault] is used. // -// # Safety -// -// To avoid unsafe behavior, DisplayShortest will panic if r is the zero -// value to prevent the returns of a "" string. Callers should consult -// [Name.IsValid] before calling this method. +// DisplayShortest panics if the mask is not the empty string, MaskNothing, and +// invalid. // // # Builds // @@ -280,10 +285,7 @@ func (r Name) slice(from, to PartKind) Name { func (r Name) DisplayShortest(mask string) string { mask = cmp.Or(mask, MaskDefault) d := parseMask(mask) - if d.IsZero() { - panic(fmt.Errorf("invalid mask %q", mask)) - } - if r.IsZero() { + if mask != MaskNothing && r.IsZero() { panic("invalid Name") } for i := range PartTag { @@ -298,7 +300,12 @@ func (r Name) DisplayShortest(mask string) string { } r.parts[i] = "" } - return r.slice(PartHost, PartTag).String() + return r.slice(PartHost, PartTag).DisplayLong() +} + +// DisplayLongest returns the result of r.DisplayShortest(MaskNothing). +func (r Name) DisplayLongest() string { + return r.DisplayShortest(MaskNothing) } var seps = [...]string{ @@ -345,15 +352,12 @@ var builderPool = sync.Pool{ }, } -// String returns the fullest possible display string in form: +// DisplayLong returns the fullest possible display string in form: // // //:+ // // If any part is missing, it is omitted from the display string. -// -// For the fullest possible display string without the build, use -// [Name.DisplayFullest]. -func (r Name) String() string { +func (r Name) DisplayLong() string { b := builderPool.Get().(*strings.Builder) defer builderPool.Put(b) b.Reset() @@ -363,14 +367,14 @@ func (r Name) String() string { } // GoString implements fmt.GoStringer. It returns a string suitable for -// debugging and logging. It is similar to [Name.String] but it always +// debugging and logging. It is similar to [Name.DisplayLong] but it always // returns a string that includes all parts of the Name, with missing parts // replaced with a ("?"). func (r Name) GoString() string { for i := range r.parts { r.parts[i] = cmp.Or(r.parts[i], "?") } - return r.String() + return r.DisplayLong() } // LogValue implements slog.Valuer. @@ -435,14 +439,11 @@ func downcase(r rune) rune { return r } -// TODO(bmizerany): driver.Value? (MarshalText etc should be enough) - -// Parts returns the parts of the Name in order of concreteness. -// -// The length of the returned slice is always 5. -func (r Name) Parts() []string { - return slices.Clone(r.parts[:]) -} +func (r Name) Host() string { return r.parts[PartHost] } +func (r Name) Namespace() string { return r.parts[PartNamespace] } +func (r Name) Model() string { return r.parts[PartModel] } +func (r Name) Build() string { return r.parts[PartBuild] } +func (r Name) Tag() string { return r.parts[PartTag] } // iter_Seq2 is a iter.Seq2 defined here to avoid the current build // restrictions in the go1.22 iter package requiring the @@ -461,7 +462,7 @@ type iter_Seq2[A, B any] func(func(A, B) bool) func parts(s string) iter_Seq2[PartKind, string] { return func(yield func(PartKind, string) bool) { if strings.HasPrefix(s, "http://") { - s = s[len("http://"):] + s = strings.TrimPrefix(s, "http://") } else { s = strings.TrimPrefix(s, "https://") } @@ -561,7 +562,7 @@ func parts(s string) iter_Seq2[PartKind, string] { } func (r Name) IsZero() bool { - return r.parts == [6]string{} + return r.parts == [NumParts]string{} } // IsValid reports if a model has at minimum a valid model part. @@ -571,13 +572,92 @@ func (r Name) IsValid() bool { return r.parts[PartModel] != "" } +// ParseNameFromURLPath parses forms of a URL path into a Name. Specifically, +// it trims any leading "/" and then calls [ParseName] with fill. +func ParseNameFromURLPath(s, fill string) Name { + s = strings.TrimPrefix(s, "/") + return ParseName(s, fill) +} + +// URLPath returns a complete, canonicalized, relative URL path using the parts of a +// complete Name. +// +// The parts maintain their original case. +// +// Example: +// +// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag" +func (r Name) URLPath() string { + return r.DisplayShortest(MaskNothing) +} + +// ParseNameFromFilepath parses a file path into a Name. The input string must be a +// valid file path representation of a model name in the form: +// +// host/namespace/model/tag/build +// +// The zero valid is returned if s does not contain all path elements +// leading up to the model part, or if any path element is an invalid part +// for the its corresponding part kind. +// +// The fill string is used to fill in missing parts of any constructed Name. +// See [ParseName] for more information on the fill string. +func ParseNameFromFilepath(s, fill string) Name { + var r Name + for i := range PartBuild + 1 { + part, rest, _ := strings.Cut(s, string(filepath.Separator)) + if !isValidPart(i, part) { + return Name{} + } + r.parts[i] = part + s = rest + if s == "" { + break + } + } + if s != "" { + return Name{} + } + if !r.IsValid() { + return Name{} + } + return fillName(r, fill) +} + +// Filepath returns a complete, canonicalized, relative file path using the +// parts of a complete Name. +// +// Each parts is downcased, except for the build part which is upcased. +// +// Example: +// +// ParseName("example.com/namespace/model:tag+build").Filepath() // returns "example.com/namespace/model/tag/BUILD" +func (r Name) Filepath() string { + for i := range r.parts { + if PartKind(i) == PartBuild { + r.parts[i] = strings.ToUpper(r.parts[i]) + } else { + r.parts[i] = strings.ToLower(r.parts[i]) + } + } + return filepath.Join(r.parts[:]...) +} + // isValidPart reports if s contains all valid characters for the given // part kind. func isValidPart(kind PartKind, s string) bool { if s == "" { return false } + var consecutiveDots int for _, c := range []byte(s) { + if c == '.' { + if consecutiveDots++; consecutiveDots >= 2 { + return false + } + } else { + consecutiveDots = 0 + } if !isValidByteFor(kind, c) { return false } diff --git a/types/model/name_test.go b/types/model/name_test.go index 14d36b64..166e0a57 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -5,6 +5,7 @@ import ( "cmp" "fmt" "log/slog" + "path/filepath" "slices" "strings" "testing" @@ -111,11 +112,11 @@ func TestNameConsecutiveDots(t *testing.T) { for i := 1; i < 10; i++ { s := strings.Repeat(".", i) if i > 1 { - if g := ParseName(s, FillNothing).String(); g != "" { + if g := ParseName(s, FillNothing).DisplayLong(); g != "" { t.Errorf("ParseName(%q) = %q; want empty string", s, g) } } else { - if g := ParseName(s, FillNothing).String(); g != s { + if g := ParseName(s, FillNothing).DisplayLong(); g != s { t.Errorf("ParseName(%q) = %q; want %q", s, g, s) } } @@ -124,7 +125,7 @@ func TestNameConsecutiveDots(t *testing.T) { func TestNameParts(t *testing.T) { var p Name - if w, g := int(PartDigest+1), len(p.Parts()); w != g { + if w, g := int(NumParts), len(p.parts); w != g { t.Errorf("Parts() = %d; want %d", g, w) } } @@ -155,8 +156,8 @@ func TestParseName(t *testing.T) { } // test round-trip - if !ParseName(name.String(), FillNothing).EqualFold(name) { - t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName) + if !ParseName(name.DisplayLong(), FillNothing).EqualFold(name) { + t.Errorf("ParseName(%q).String() = %s; want %s", s, name.DisplayLong(), baseName) } }) } @@ -181,11 +182,20 @@ func TestParseNameFill(t *testing.T) { for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { name := ParseName(tt.in, tt.fill) - if g := name.String(); g != tt.want { + if g := name.DisplayLong(); g != tt.want { t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want) } }) } + + t.Run("invalid fill", func(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic") + } + }() + ParseName("x", "^") + }) } func TestParseNameHTTPDoublePrefixStrip(t *testing.T) { @@ -379,6 +389,22 @@ func BenchmarkParseName(b *testing.B) { } } +func FuzzParseNameFromFilepath(f *testing.F) { + f.Add("example.com/library/mistral/7b/Q4_0") + f.Add("example.com/../mistral/7b/Q4_0") + f.Add("example.com/x/../7b/Q4_0") + f.Add("example.com/x/../7b") + f.Fuzz(func(t *testing.T, s string) { + name := ParseNameFromFilepath(s, FillNothing) + if strings.Contains(s, "..") && !name.IsZero() { + t.Fatalf("non-zero value for path with '..': %q", s) + } + if name.IsValid() == name.IsZero() { + t.Errorf("expected valid path to be non-zero value; got %#v", name) + } + }) +} + func FuzzParseName(f *testing.F) { f.Add("example.com/mistral:7b+Q4_0") f.Add("example.com/mistral:7b+q4_0") @@ -403,17 +429,17 @@ func FuzzParseName(f *testing.F) { t.Skipf("invalid path: %q", s) } - for _, p := range r0.Parts() { + for _, p := range r0.parts { if len(p) > MaxNamePartLen { t.Errorf("part too long: %q", p) } } - if !strings.EqualFold(r0.String(), s) { - t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s) + if !strings.EqualFold(r0.DisplayLong(), s) { + t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.DisplayLong(), s) } - r1 := ParseName(r0.String(), FillNothing) + r1 := ParseName(r0.DisplayLong(), FillNothing) if !r0.EqualFold(r1) { t.Errorf("round-trip mismatch: %+v != %+v", r0, r1) } @@ -423,13 +449,173 @@ func FuzzParseName(f *testing.F) { func TestNameStringAllocs(t *testing.T) { name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing) allocs := testing.AllocsPerRun(1000, func() { - keep(name.String()) + keep(name.DisplayLong()) }) if allocs > 1 { t.Errorf("String allocs = %v; want 0", allocs) } } +func TestNamePath(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest"}, + + // incomplete + {"example.com/library/mistral:latest", "example.com/library/mistral:latest"}, + {"", ""}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + p := ParseName(tt.in, FillNothing) + t.Logf("ParseName(%q) = %#v", tt.in, p) + if g := p.URLPath(); g != tt.want { + t.Errorf("got = %q; want %q", g, tt.want) + } + }) + } +} + +func TestNameFromFilepath(t *testing.T) { + cases := []struct { + in string + want string + }{ + { + in: "example.com/library/mistral:latest+Q4_0", + want: "example.com/library/mistral/latest/Q4_0", + }, + { + in: "Example.Com/Library/Mistral:Latest+Q4_0", + want: "example.com/library/mistral/latest/Q4_0", + }, + { + in: "Example.Com/Library/Mistral:Latest+Q4_0", + want: "example.com/library/mistral/latest/Q4_0", + }, + { + in: "example.com/library/mistral:latest", + want: "example.com/library/mistral/latest", + }, + { + in: "", + want: "", + }, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + p := ParseName(tt.in, FillNothing) + t.Logf("ParseName(%q) = %#v", tt.in, p) + g := p.Filepath() + g = filepath.ToSlash(g) + if g != tt.want { + t.Errorf("got = %q; want %q", g, tt.want) + } + }) + } +} + +func TestParseNameFilepath(t *testing.T) { + cases := []struct { + in string + fill string // default is FillNothing + want string + }{ + { + in: "example.com/library/mistral/latest/Q4_0", + want: "example.com/library/mistral:latest+Q4_0", + }, + { + in: "example.com/library/mistral/latest", + fill: "?/?/?:latest+Q4_0", + want: "example.com/library/mistral:latest+Q4_0", + }, + { + in: "example.com/library/mistral", + fill: "?/?/?:latest+Q4_0", + want: "example.com/library/mistral:latest+Q4_0", + }, + { + in: "example.com/library", + want: "", + }, + { + in: "example.com/", + want: "", + }, + { + in: "example.com/^/mistral/latest/Q4_0", + want: "", + }, + { + in: "example.com/library/mistral/../Q4_0", + want: "", + }, + { + in: "example.com/library/mistral/latest/Q4_0/extra", + want: "", + }, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + in := strings.ReplaceAll(tt.in, "/", string(filepath.Separator)) + fill := cmp.Or(tt.fill, FillNothing) + want := ParseName(tt.want, fill) + if g := ParseNameFromFilepath(in, fill); !g.EqualFold(want) { + t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want) + } + }) + } +} + +func TestParseNameFromPath(t *testing.T) { + cases := []struct { + in string + want string + fill string // default is FillNothing + }{ + { + in: "example.com/library/mistral:latest+Q4_0", + want: "example.com/library/mistral:latest+Q4_0", + }, + { + in: "/example.com/library/mistral:latest+Q4_0", + want: "example.com/library/mistral:latest+Q4_0", + }, + { + in: "/example.com/library/mistral", + want: "example.com/library/mistral", + }, + { + in: "/example.com/library/mistral", + fill: "?/?/?:latest+Q4_0", + want: "example.com/library/mistral:latest+Q4_0", + }, + { + in: "/example.com/library", + want: "", + }, + { + in: "/example.com/", + want: "", + }, + { + in: "/example.com/^/mistral/latest", + want: "", + }, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + fill := cmp.Or(tt.fill, FillNothing) + if g := ParseNameFromURLPath(tt.in, fill); g.DisplayLong() != tt.want { + t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want) + } + }) + } +} + func ExampleName_MapHash() { m := map[uint64]bool{} @@ -456,7 +642,7 @@ func ExampleName_CompareFold_sort() { slices.SortFunc(names, Name.CompareFold) for _, n := range names { - fmt.Println(n) + fmt.Println(n.DisplayLong()) } // Output: