From 6a1de2317582c6faa49e62c5d73ce58e37668dad Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Wed, 10 Apr 2024 16:30:05 -0700 Subject: [PATCH] types/model: init with Name and Digest types (#3541) --- types/model/digest.go | 120 ++++ types/model/digest_test.go | 46 ++ types/model/name.go | 581 ++++++++++++++++++ types/model/name_test.go | 490 +++++++++++++++ .../fuzz/FuzzParseRef/1d43ee52085cb4aa | 2 + .../fuzz/FuzzParseRef/27fd759314f0e6d6 | 2 + .../fuzz/FuzzParseRef/3e3b70dba384074d | 2 + .../fuzz/FuzzParseRef/71f1fdff711b6dab | 2 + .../fuzz/FuzzParseRef/82c2975c430ac608 | 2 + .../fuzz/FuzzParseRef/b51b1c875e61a948 | 2 + types/structs/structs.go | 15 + 11 files changed, 1264 insertions(+) create mode 100644 types/model/digest.go create mode 100644 types/model/digest_test.go create mode 100644 types/model/name.go create mode 100644 types/model/name_test.go create mode 100644 types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa create mode 100644 types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 create mode 100644 types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d create mode 100644 types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab create mode 100644 types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 create mode 100644 types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 create mode 100644 types/structs/structs.go diff --git a/types/model/digest.go b/types/model/digest.go new file mode 100644 index 00000000..f3cefa00 --- /dev/null +++ b/types/model/digest.go @@ -0,0 +1,120 @@ +package model + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "log/slog" + "strings" + "unicode" +) + +// Digest represents a digest of a model Manifest. It is a comparable value +// type and is immutable. +// +// The zero Digest is not a valid digest. +type Digest struct { + s string +} + +// Type returns the digest type of the digest. +// +// Example: +// +// ParseDigest("sha256-1234").Type() // returns "sha256" +func (d Digest) Type() string { + typ, _, _ := strings.Cut(d.s, "-") + return typ +} + +// String returns the digest in the form of "-", or the +// empty string if the digest is invalid. +func (d Digest) String() string { return d.s } + +// IsValid returns true if the digest is valid (not zero). +// +// A valid digest may be created only by ParseDigest, or +// ParseName(name).Digest(). +func (d Digest) IsValid() bool { return d.s != "" } + +// MarshalText implements encoding.TextMarshaler. +func (d Digest) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (d *Digest) UnmarshalText(text []byte) error { + if d.IsValid() { + return errors.New("model.Digest: illegal UnmarshalText on valid Digest") + } + *d = ParseDigest(string(text)) + return nil +} + +// LogValue implements slog.Value. +func (d Digest) LogValue() slog.Value { + return slog.StringValue(d.String()) +} + +var ( + _ driver.Valuer = Digest{} + _ sql.Scanner = (*Digest)(nil) + _ slog.LogValuer = Digest{} +) + +// Scan implements the sql.Scanner interface. +func (d *Digest) Scan(src any) error { + if d.IsValid() { + return errors.New("model.Digest: illegal Scan on valid Digest") + } + switch v := src.(type) { + case string: + *d = ParseDigest(v) + return nil + case []byte: + *d = ParseDigest(string(v)) + return nil + } + return fmt.Errorf("model.Digest: invalid Scan source %T", src) +} + +// Value implements the driver.Valuer interface. +func (d Digest) Value() (driver.Value, error) { + return d.String(), nil +} + +// ParseDigest parses a string in the form of "-" into a +// Digest. +func ParseDigest(s string) Digest { + typ, digest, ok := strings.Cut(s, "-") + if ok && isValidDigestType(typ) && isValidHex(digest) { + return Digest{s: s} + } + return Digest{} +} + +func isValidDigestType(s string) bool { + if len(s) == 0 { + return false + } + for _, r := range s { + if !unicode.IsLower(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func isValidHex(s string) bool { + if len(s) == 0 { + return false + } + for i := range s { + c := s[i] + if c < '0' || c > '9' && c < 'a' || c > 'f' { + return false + } + } + return true +} diff --git a/types/model/digest_test.go b/types/model/digest_test.go new file mode 100644 index 00000000..5096a28a --- /dev/null +++ b/types/model/digest_test.go @@ -0,0 +1,46 @@ +package model + +import "testing" + +var testDigests = map[string]Digest{ + "": {}, + "sha256-1234": {s: "sha256-1234"}, + "sha256-5678": {s: "sha256-5678"}, + "blake2-9abc": {s: "blake2-9abc"}, + "-1234": {}, + "sha256-": {}, + "sha256-1234-5678": {}, + "sha256-P": {}, // invalid hex + "sha256-1234P": {}, + "---": {}, +} + +func TestDigestParse(t *testing.T) { + // Test cases. + for s, want := range testDigests { + got := ParseDigest(s) + t.Logf("ParseDigest(%q) = %#v", s, got) + if got != want { + t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want) + } + } +} + +func TestDigestString(t *testing.T) { + // Test cases. + for s, d := range testDigests { + want := s + if !d.IsValid() { + want = "" + } + got := d.String() + if got != want { + t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want) + } + + got = ParseDigest(s).String() + if got != want { + t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want) + } + } +} diff --git a/types/model/name.go b/types/model/name.go new file mode 100644 index 00000000..3def95bf --- /dev/null +++ b/types/model/name.go @@ -0,0 +1,581 @@ +package model + +import ( + "cmp" + "errors" + "hash/maphash" + "io" + "log/slog" + "slices" + "strings" + "sync" + + "github.com/ollama/ollama/types/structs" +) + +// Errors +var ( + // ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not + // used by this package, but are exported so that other packages can + // use them, instead of defining their own errors for them. + ErrInvalidName = errors.New("invalid model name") + ErrIncompleteName = errors.New("incomplete model name") + ErrInvalidDigest = errors.New("invalid digest") +) + +// Defaults +const ( + // DefaultMask is the default mask used by [Name.DisplayShortest]. + DefaultMask = "registry.ollama.ai/library/_:latest" + + // DefaultFill is the default fill used by [ParseName]. + DefaultFill = "registry.ollama.ai/library/_:latest" +) + +const MaxNamePartLen = 128 + +type PartKind int + +// Levels of concreteness +const ( + // Each value aligns with its index in the Name.parts array. + + PartHost PartKind = iota + PartNamespace + PartModel + PartTag + PartBuild + PartDigest + + // Invalid is a special part that is used to indicate that a part is + // invalid. It is not a valid part of a Name. + // + // It should be kept as the last part in the list. + PartInvalid +) + +var kindNames = map[PartKind]string{ + PartHost: "Host", + PartNamespace: "Namespace", + PartModel: "Name", + PartTag: "Tag", + PartBuild: "Build", + PartDigest: "Digest", + PartInvalid: "Invalid", +} + +func (k PartKind) String() string { + return cmp.Or(kindNames[k], "Unknown") +} + +// Name is an opaque reference to a model. It holds the parts of a model +// with the case preserved, but is not directly comparable with other Names +// since model names can be represented with different casing depending on +// the use case. For instance, "Mistral" and "mistral" are the same model +// but each version may have come from different sources (e.g. copied from a +// Web page, or from a file path). +// +// Valid Names can ONLY be constructed by calling [ParseName]. +// +// A Name is valid if and only if is have a valid Model part. The other parts +// are optional. +// +// A Name is considered "complete" if it has all parts present. To check if a +// Name is complete, use [Name.IsComplete]. +// +// To compare two names in a case-insensitive manner, use [Name.EqualFold]. +// +// The parts of a Name are: +// +// - Host: the domain of the model (optional) +// - Namespace: the namespace of the model (optional) +// - Model: the name of the model (required) +// - Tag: the tag of the model (optional) +// - Build: the build of the model; usually the quantization or "file type" (optional) +// +// The parts can be obtained in their original form by calling [Name.Parts]. +// +// To check if a Name has at minimum a valid model part, use [Name.IsValid]. +// +// To make a Name by filling in missing parts from another Name, use [Fill]. +type Name struct { + _ structs.Incomparable + parts [6]string // host, namespace, model, tag, build, digest + + // TODO(bmizerany): track offsets and hold s (raw string) here? We + // could pack the offsets all into a single uint64 since the first + // parts take less bits since their max offset is less than the max + // offset of the next part. This would save a ton of bytes per Name + // and mean zero allocations for String. +} + +// ParseNameFill parses s into a Name, and returns the result of filling it with +// defaults. The input string must be a valid string +// representation of a model name in the form: +// +// [host/][namespace/][:tag][+build][@-] +// +// The name part is required, all others are optional. If a part is missing, +// it is left empty in the returned Name. If a part is invalid, the zero Ref +// value is returned. +// +// The build part is normalized to uppercase. +// +// Examples of valid paths: +// +// "example.com/library/mistral:7b+x" +// "example.com/eva/mistral:7b+Q4_0" +// "mistral:7b+x" +// "example.com/mike/mistral:latest+Q4_0" +// "example.com/bruce/mistral:latest" +// "example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef" +// +// Examples of invalid paths: +// +// "example.com/mistral:7b+" +// "example.com/mistral:7b+Q4_0+" +// "x/y/z/z:8n+I" +// "" +// +// It returns the zero value if any part is invalid. +// +// As a rule of thumb, an valid name is one that can be round-tripped with +// the [Name.String] method. That means ("x+") is invalid because +// [Name.String] will not print a "+" if the build is empty. +// +// For more about filling in missing parts, see [Fill]. +func ParseNameFill(s, defaults string) Name { + var r Name + parts(s)(func(kind PartKind, part string) bool { + if kind == PartInvalid { + r = Name{} + return false + } + if kind == PartDigest && !ParseDigest(part).IsValid() { + r = Name{} + return false + } + r.parts[kind] = part + return true + }) + if r.IsValid() || r.IsResolved() { + if defaults == "" { + return r + } + return Fill(r, ParseNameFill(defaults, "")) + } + return Name{} +} + +// ParseName is equal to ParseNameFill(s, DefaultFill). +func ParseName(s string) Name { + return ParseNameFill(s, DefaultFill) +} + +func MustParseNameFill(s, defaults string) Name { + r := ParseNameFill(s, "") + if !r.IsValid() { + panic("model.MustParseName: invalid name: " + s) + } + return r +} + +// Fill fills in the missing parts of dst with the parts of src. +// +// The returned Name will only be valid if dst is valid. +func Fill(dst, src Name) Name { + var r Name + for i := range r.parts { + r.parts[i] = cmp.Or(dst.parts[i], src.parts[i]) + } + return r +} + +// WithBuild returns a copy of r with the build set to the given string. +func (r Name) WithBuild(build string) Name { + r.parts[PartBuild] = build + return r +} + +func (r Name) WithDigest(digest Digest) Name { + r.parts[PartDigest] = digest.String() + return r +} + +var mapHashSeed = maphash.MakeSeed() + +// MapHash returns a case insensitive hash for use in maps and equality +// checks. For a convenient way to compare names, use [Name.EqualFold]. +// +//nolint:errcheck +func (r Name) MapHash() uint64 { + // correctly hash the parts with case insensitive comparison + var h maphash.Hash + h.SetSeed(mapHashSeed) + for _, part := range r.Parts() { + // downcase the part for hashing + for i := range part { + c := part[i] + if c >= 'A' && c <= 'Z' { + c = c - 'A' + 'a' + } + h.WriteByte(c) + } + } + return h.Sum64() +} + +func (r Name) slice(from, to PartKind) Name { + var v Name + copy(v.parts[from:to+1], r.parts[from:to+1]) + return v +} + +// DisplayShortest returns the shortest possible display string in form: +// +// [host/][/][:] +// +// The host is omitted if it is the mask host is the same as r. +// The namespace is omitted if the host and the namespace are the same as r. +// The tag is omitted if it is the mask tag is the same as r. +func (r Name) DisplayShortest(mask string) string { + mask = cmp.Or(mask, DefaultMask) + d := ParseName(mask) + if !d.IsValid() { + panic("mask is an invalid Name") + } + equalSlice := func(form, to PartKind) bool { + return r.slice(form, to).EqualFold(d.slice(form, to)) + } + if equalSlice(PartHost, PartNamespace) { + r.parts[PartNamespace] = "" + } + if equalSlice(PartHost, PartHost) { + r.parts[PartHost] = "" + } + if equalSlice(PartTag, PartTag) { + r.parts[PartTag] = "" + } + return r.slice(PartHost, PartTag).String() +} + +// DisplayLong returns the fullest possible display string in form: +// +// /: +// +// If any part is missing, it is omitted from the display string. +func (r Name) DisplayLong() string { + return r.slice(PartNamespace, PartTag).String() +} + +var seps = [...]string{ + PartHost: "/", + PartNamespace: "/", + PartModel: ":", + PartTag: "+", + PartBuild: "@", + PartDigest: "", +} + +// WriteTo implements io.WriterTo. It writes the fullest possible display +// string in form: +// +// //:+@- +// +// Missing parts and their separators are not written. +// +// The full digest is always prefixed with "@". That is if [Name.IsValid] +// reports false and [Name.IsResolved] reports true, then the string is +// returned as "@-". +func (r Name) writeTo(w io.StringWriter) error { + var partsWritten int + for i := range r.parts { + if r.parts[i] == "" { + continue + } + if partsWritten > 0 || i == int(PartDigest) { + if _, err := w.WriteString(seps[i-1]); err != nil { + return err + } + } + if _, err := w.WriteString(r.parts[i]); err != nil { + return err + } + partsWritten++ + } + return nil +} + +var builderPool = sync.Pool{ + New: func() interface{} { + return &strings.Builder{} + }, +} + +// String returns the fullest possible display string in form: +// +// //:+ +// +// If any part is missing, it is omitted from the display string. +// +// For the fullest possible display string without the build, use +// [Name.DisplayFullest]. +func (r Name) String() string { + b := builderPool.Get().(*strings.Builder) + defer builderPool.Put(b) + b.Reset() + b.Grow(50) // arbitrarily long enough for most names + _ = r.writeTo(b) + return b.String() +} + +// GoString implements fmt.GoStringer. It returns a string suitable for +// debugging and logging. It is similar to [Name.String] but it always +// returns a string that includes all parts of the Name, with missing parts +// replaced with a ("?"). +func (r Name) GoString() string { + for i := range r.parts { + r.parts[i] = cmp.Or(r.parts[i], "?") + } + return r.String() +} + +// LogValue implements slog.Valuer. +func (r Name) LogValue() slog.Value { + return slog.StringValue(r.GoString()) +} + +// IsComplete reports whether the Name is fully qualified. That is it has a +// domain, namespace, name, tag, and build. +func (r Name) IsComplete() bool { + return !slices.Contains(r.parts[:PartDigest], "") +} + +// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the +// build part to be present. +func (r Name) IsCompleteNoBuild() bool { + return !slices.Contains(r.parts[:PartBuild], "") +} + +// IsResolved reports true if the Name has a valid digest. +// +// It is possible to have a valid Name, or a complete Name that is not +// resolved. +func (r Name) IsResolved() bool { + return r.Digest().IsValid() +} + +// Digest returns the digest part of the Name, if any. +// +// If Digest returns a non-empty string, then [Name.IsResolved] will return +// true, and digest is considered valid. +func (r Name) Digest() Digest { + // This was already validated by ParseName, so we can just return it. + return Digest{r.parts[PartDigest]} +} + +// EqualFold reports whether r and o are equivalent model names, ignoring +// case. +func (r Name) EqualFold(o Name) bool { + return r.CompareFold(o) == 0 +} + +// CompareFold performs a case-insensitive cmp.Compare on r and o. +// +// This can be used with [slices.SortFunc]. +// +// For simple equality checks, use [Name.EqualFold]. +func (r Name) CompareFold(o Name) int { + return slices.CompareFunc(r.parts[:], o.parts[:], compareFold) +} + +func compareFold(a, b string) int { + return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int { + return cmp.Compare(downcase(a), downcase(b)) + }) +} + +func downcase(r rune) rune { + if r >= 'A' && r <= 'Z' { + return r - 'A' + 'a' + } + return r +} + +// TODO(bmizerany): driver.Value? (MarshalText etc should be enough) + +// Parts returns the parts of the Name in order of concreteness. +// +// The length of the returned slice is always 5. +func (r Name) Parts() []string { + return slices.Clone(r.parts[:]) +} + +// iter_Seq2 is a iter.Seq2 defined here to avoid the current build +// restrictions in the go1.22 iter package requiring the +// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag, +// which we are not yet ready to support. +// +// Once we are ready to support rangefunc, this can be removed and replaced +// with the iter.Seq2 type. +type iter_Seq2[A, B any] func(func(A, B) bool) + +// Parts returns a sequence of the parts of a Name string from most specific +// to least specific. +// +// It normalizes the input string by removing "http://" and "https://" only. +// No other normalizations are performed. +func parts(s string) iter_Seq2[PartKind, string] { + return func(yield func(PartKind, string) bool) { + //nolint:gosimple + if strings.HasPrefix(s, "http://") { + s = s[len("http://"):] + } + //nolint:gosimple + if strings.HasPrefix(s, "https://") { + s = s[len("https://"):] + } + + if len(s) > MaxNamePartLen || len(s) == 0 { + return + } + + yieldValid := func(kind PartKind, part string) bool { + if !isValidPart(kind, part) { + yield(PartInvalid, "") + return false + } + return yield(kind, part) + } + + numConsecutiveDots := 0 + partLen := 0 + state, j := PartDigest, len(s) + for i := len(s) - 1; i >= 0; i-- { + if partLen++; partLen > MaxNamePartLen { + // catch a part that is too long early, so + // we don't keep spinning on it, waiting for + // an isInValidPart check which would scan + // over it again. + yield(PartInvalid, "") + return + } + + switch s[i] { + case '@': + switch state { + case PartDigest: + if !yieldValid(PartDigest, s[i+1:j]) { + return + } + if i == 0 { + // This is the form + // "@" which is valid. + // + // We're done. + return + } + state, j, partLen = PartBuild, i, 0 + default: + yield(PartInvalid, "") + return + } + case '+': + switch state { + case PartBuild, PartDigest: + if !yieldValid(PartBuild, s[i+1:j]) { + return + } + state, j, partLen = PartTag, i, 0 + default: + yield(PartInvalid, "") + return + } + case ':': + switch state { + case PartTag, PartBuild, PartDigest: + if !yieldValid(PartTag, s[i+1:j]) { + return + } + state, j, partLen = PartModel, i, 0 + default: + yield(PartInvalid, "") + return + } + case '/': + switch state { + case PartModel, PartTag, PartBuild, PartDigest: + if !yieldValid(PartModel, s[i+1:j]) { + return + } + state, j = PartNamespace, i + case PartNamespace: + if !yieldValid(PartNamespace, s[i+1:j]) { + return + } + state, j, partLen = PartHost, i, 0 + default: + yield(PartInvalid, "") + return + } + default: + if s[i] == '.' { + if numConsecutiveDots++; numConsecutiveDots > 1 { + yield(PartInvalid, "") + return + } + } else { + numConsecutiveDots = 0 + } + if !isValidByteFor(state, s[i]) { + yield(PartInvalid, "") + return + } + } + } + + if state <= PartNamespace { + yieldValid(state, s[:j]) + } else { + yieldValid(PartModel, s[:j]) + } + } +} + +func (r Name) IsZero() bool { + return r.parts == [6]string{} +} + +// IsValid reports if a model has at minimum a valid model part. +func (r Name) IsValid() bool { + // Parts ensures we only have valid parts, so no need to validate + // them here, only check if we have a name or not. + return r.parts[PartModel] != "" +} + +// isValidPart reports if s contains all valid characters for the given +// part kind. +func isValidPart(kind PartKind, s string) bool { + if s == "" { + return false + } + for _, c := range []byte(s) { + if !isValidByteFor(kind, c) { + return false + } + } + return true +} + +func isValidByteFor(kind PartKind, c byte) bool { + if kind == PartNamespace && c == '.' { + return false + } + if c == '.' || c == '-' { + return true + } + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' { + return true + } + return false +} diff --git a/types/model/name_test.go b/types/model/name_test.go new file mode 100644 index 00000000..78b57fca --- /dev/null +++ b/types/model/name_test.go @@ -0,0 +1,490 @@ +package model + +import ( + "bytes" + "cmp" + "fmt" + "log/slog" + "slices" + "strings" + "testing" +) + +type fields struct { + host, namespace, model, tag, build string + digest string +} + +func fieldsFromName(p Name) fields { + return fields{ + host: p.parts[PartHost], + namespace: p.parts[PartNamespace], + model: p.parts[PartModel], + tag: p.parts[PartTag], + build: p.parts[PartBuild], + digest: p.parts[PartDigest], + } +} + +var testNames = map[string]fields{ + "mistral:latest": {model: "mistral", tag: "latest"}, + "mistral": {model: "mistral"}, + "mistral:30B": {model: "mistral", tag: "30B"}, + "mistral:7b": {model: "mistral", tag: "7b"}, + "mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"}, + "mistral+KQED": {model: "mistral", build: "KQED"}, + "mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"}, + "mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"}, + "llama2": {model: "llama2"}, + "user/model": {namespace: "user", model: "model"}, + "example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"}, + "example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"}, + + // invalid digest + "mistral:latest@invalid256-": {}, + "mistral:latest@-123": {}, + "mistral:latest@!-123": {}, + "mistral:latest@1-!": {}, + "mistral:latest@": {}, + + // resolved + "x@sha123-1": {model: "x", digest: "sha123-1"}, + "@sha456-2": {digest: "sha456-2"}, + + "@@sha123-1": {}, + + // preserves case for build + "x+b": {model: "x", build: "b"}, + + // invalid (includes fuzzing trophies) + " / / : + ": {}, + " / : + ": {}, + " : + ": {}, + " + ": {}, + " : ": {}, + " / ": {}, + " /": {}, + "/ ": {}, + "/": {}, + ":": {}, + "+": {}, + + // (".") in namepsace is not allowed + "invalid.com/7b+x": {}, + + "invalid:7b+Q4_0:latest": {}, + "in valid": {}, + "invalid/y/z/foo": {}, + "/0": {}, + "0 /0": {}, + "0 /": {}, + "0/": {}, + ":/0": {}, + "+0/00000": {}, + "0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {}, + "0//0": {}, + "m+^^^": {}, + "file:///etc/passwd": {}, + "file:///etc/passwd:latest": {}, + "file:///etc/passwd:latest+u": {}, + + ":x": {}, + "+x": {}, + "x+": {}, + + // Disallow ("\.+") in any part to prevent path traversal anywhere + // we convert the name to a path. + "../etc/passwd": {}, + ".../etc/passwd": {}, + "./../passwd": {}, + "./0+..": {}, + + strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)}, + strings.Repeat("a", MaxNamePartLen+1): {}, +} + +// TestConsecutiveDots tests that consecutive dots are not allowed in any +// part, to avoid path traversal. There also are some tests in testNames, but +// this test is more exhaustive and exists to emphasize the importance of +// preventing path traversal. +func TestNameConsecutiveDots(t *testing.T) { + for i := 1; i < 10; i++ { + s := strings.Repeat(".", i) + if i > 1 { + if g := ParseNameFill(s, "").String(); g != "" { + t.Errorf("ParseName(%q) = %q; want empty string", s, g) + } + } else { + if g := ParseNameFill(s, "").String(); g != s { + t.Errorf("ParseName(%q) = %q; want %q", s, g, s) + } + } + } +} + +func TestNameParts(t *testing.T) { + var p Name + if w, g := int(PartDigest+1), len(p.Parts()); w != g { + t.Errorf("Parts() = %d; want %d", g, w) + } +} + +func TestNamePartString(t *testing.T) { + if g := PartKind(-2).String(); g != "Unknown" { + t.Errorf("Unknown part = %q; want %q", g, "Unknown") + } + for kind, name := range kindNames { + if g := kind.String(); g != name { + t.Errorf("%s = %q; want %q", kind, g, name) + } + } +} + +func TestParseName(t *testing.T) { + for baseName, want := range testNames { + for _, prefix := range []string{"", "https://", "http://"} { + // We should get the same results with or without the + // http(s) prefixes + s := prefix + baseName + + t.Run(s, func(t *testing.T) { + name := ParseNameFill(s, "") + got := fieldsFromName(name) + if got != want { + t.Errorf("ParseName(%q) = %q; want %q", s, got, want) + } + + // test round-trip + if !ParseNameFill(name.String(), "").EqualFold(name) { + t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName) + } + }) + } + } +} + +func TestCompleteWithAndWithoutBuild(t *testing.T) { + cases := []struct { + in string + complete bool + completeNoBuild bool + }{ + {"", false, false}, + {"incomplete/mistral:7b+x", false, false}, + {"incomplete/mistral:7b+Q4_0", false, false}, + {"incomplete:7b+x", false, false}, + {"complete.com/x/mistral:latest+Q4_0", true, true}, + {"complete.com/x/mistral:latest", false, true}, + } + + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + p := ParseNameFill(tt.in, "") + t.Logf("ParseName(%q) = %#v", tt.in, p) + if g := p.IsComplete(); g != tt.complete { + t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete) + } + if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild { + t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild) + } + }) + } + + // Complete uses Parts which returns a slice, but it should be + // inlined when used in Complete, preventing any allocations or + // escaping to the heap. + allocs := testing.AllocsPerRun(1000, func() { + keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete()) + }) + if allocs > 0 { + t.Errorf("Complete allocs = %v; want 0", allocs) + } +} + +func TestNameLogValue(t *testing.T) { + cases := []string{ + "example.com/library/mistral:latest+Q4_0", + "mistral:latest", + "mistral:7b+Q4_0", + } + for _, s := range cases { + t.Run(s, func(t *testing.T) { + var b bytes.Buffer + log := slog.New(slog.NewTextHandler(&b, nil)) + name := ParseNameFill(s, "") + log.Info("", "name", name) + want := fmt.Sprintf("name=%s", name.GoString()) + got := b.String() + if !strings.Contains(got, want) { + t.Errorf("expected log output to contain %q; got %q", want, got) + } + }) + } +} + +func TestNameGoString(t *testing.T) { + cases := []struct { + name string + in string + wantString string + wantGoString string // default is tt.in + }{ + { + name: "Complete Name", + in: "example.com/library/mistral:latest+Q4_0", + wantGoString: "example.com/library/mistral:latest+Q4_0@?", + }, + { + name: "Short Name", + in: "mistral:latest", + wantGoString: "?/?/mistral:latest+?@?", + }, + { + name: "Long Name", + in: "library/mistral:latest", + wantGoString: "?/library/mistral:latest+?@?", + }, + { + name: "Case Preserved", + in: "Library/Mistral:Latest", + wantGoString: "?/Library/Mistral:Latest+?@?", + }, + { + name: "With digest", + in: "Library/Mistral:Latest@sha256-123456", + wantGoString: "?/Library/Mistral:Latest+?@sha256-123456", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + p := ParseNameFill(tt.in, "") + tt.wantGoString = cmp.Or(tt.wantGoString, tt.in) + if g := fmt.Sprintf("%#v", p); g != tt.wantGoString { + t.Errorf("GoString() = %q; want %q", g, tt.wantGoString) + } + }) + } +} + +func TestDisplayShortest(t *testing.T) { + cases := []struct { + in string + mask string + want string + wantPanic bool + }{ + {"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false}, + {"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false}, + {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false}, + {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false}, + + // case-insensitive + {"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false}, + {"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false}, + {"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false}, + {"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false}, + {"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false}, + + // invalid mask + {"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true}, + + // DefaultMask + {"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false}, + + // Auto-Fill + {"x", "example.com/library/_:latest", "x", false}, + {"x", "example.com/library/_:latest+Q4_0", "x", false}, + {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false}, + {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false}, + } + + for _, tt := range cases { + t.Run("", func(t *testing.T) { + defer func() { + if tt.wantPanic { + if recover() == nil { + t.Errorf("expected panic") + } + } + }() + + p := ParseNameFill(tt.in, "") + t.Logf("ParseName(%q) = %#v", tt.in, p) + if g := p.DisplayShortest(tt.mask); g != tt.want { + t.Errorf("got = %q; want %q", g, tt.want) + } + }) + } +} + +func TestParseNameAllocs(t *testing.T) { + allocs := testing.AllocsPerRun(1000, func() { + keep(ParseNameFill("example.com/mistral:7b+Q4_0", "")) + }) + if allocs > 0 { + t.Errorf("ParseName allocs = %v; want 0", allocs) + } +} + +func BenchmarkParseName(b *testing.B) { + b.ReportAllocs() + + for range b.N { + keep(ParseNameFill("example.com/mistral:7b+Q4_0", "")) + } +} + +func FuzzParseName(f *testing.F) { + f.Add("example.com/mistral:7b+Q4_0") + f.Add("example.com/mistral:7b+q4_0") + f.Add("example.com/mistral:7b+x") + f.Add("x/y/z:8n+I") + f.Add(":x") + f.Add("@sha256-123456") + f.Add("example.com/mistral:latest+Q4_0@sha256-123456") + f.Add(":@!@") + f.Add("...") + f.Fuzz(func(t *testing.T, s string) { + r0 := ParseNameFill(s, "") + + if strings.Contains(s, "..") && !r0.IsZero() { + t.Fatalf("non-zero value for path with '..': %q", s) + } + + if !r0.IsValid() && !r0.IsResolved() { + if !r0.EqualFold(Name{}) { + t.Errorf("expected invalid path to be zero value; got %#v", r0) + } + t.Skipf("invalid path: %q", s) + } + + for _, p := range r0.Parts() { + if len(p) > MaxNamePartLen { + t.Errorf("part too long: %q", p) + } + } + + if !strings.EqualFold(r0.String(), s) { + t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s) + } + + r1 := ParseNameFill(r0.String(), "") + if !r0.EqualFold(r1) { + t.Errorf("round-trip mismatch: %+v != %+v", r0, r1) + } + }) +} + +func TestFill(t *testing.T) { + cases := []struct { + dst string + src string + want string + }{ + {"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"}, + {"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"}, + {"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"}, + } + + for _, tt := range cases { + t.Run(tt.dst, func(t *testing.T) { + r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, "")) + if r.String() != tt.want { + t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want) + } + }) + } +} + +func TestNameStringAllocs(t *testing.T) { + name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "") + allocs := testing.AllocsPerRun(1000, func() { + keep(name.String()) + }) + if allocs > 1 { + t.Errorf("String allocs = %v; want 0", allocs) + } +} + +func ExampleFill() { + defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "") + r := Fill(ParseNameFill("mistral", ""), defaults) + fmt.Println(r) + + // Output: + // registry.ollama.com/library/mistral:latest+Q4_0 +} + +func ExampleName_MapHash() { + m := map[uint64]bool{} + + // key 1 + m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true + m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true + m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true + + // key 2 + m[ParseNameFill("mistral:LATest", "").MapHash()] = true + + fmt.Println(len(m)) + // Output: + // 2 +} + +func ExampleName_CompareFold_sort() { + names := []Name{ + ParseNameFill("mistral:latest", ""), + ParseNameFill("mistRal:7b+q4", ""), + ParseNameFill("MIstral:7b", ""), + } + + slices.SortFunc(names, Name.CompareFold) + + for _, n := range names { + fmt.Println(n) + } + + // Output: + // MIstral:7b + // mistRal:7b+q4 + // mistral:latest +} + +func ExampleName_completeAndResolved() { + for _, s := range []string{ + "x/y/z:latest+q4_0@sha123-1", + "x/y/z:latest+q4_0", + "@sha123-1", + } { + name := ParseNameFill(s, "") + fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest()) + } + + // Output: + // complete:true resolved:true digest:sha123-1 + // complete:true resolved:false digest: + // complete:false resolved:true digest:sha123-1 +} + +func ExampleName_DisplayShortest() { + name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "") + + fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest")) + fmt.Println(name.DisplayShortest("example.com/_/_:latest")) + fmt.Println(name.DisplayShortest("example.com/_/_:_")) + fmt.Println(name.DisplayShortest("_/_/_:_")) + + // Default + name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "") + fmt.Println(name.DisplayShortest("")) + + // Output: + // mistral + // jmorganca/mistral + // jmorganca/mistral:latest + // example.com/jmorganca/mistral:latest + // mistral +} + +func keep[T any](v T) T { return v } diff --git a/types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa b/types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa new file mode 100644 index 00000000..0cdf1eac --- /dev/null +++ b/types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa @@ -0,0 +1,2 @@ +go test fuzz v1 +string("/0") diff --git a/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 b/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 new file mode 100644 index 00000000..c5d09a4c --- /dev/null +++ b/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0//0") diff --git a/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d b/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d new file mode 100644 index 00000000..880ce7a3 --- /dev/null +++ b/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0 /0") diff --git a/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab b/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab new file mode 100644 index 00000000..fa981c52 --- /dev/null +++ b/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab @@ -0,0 +1,2 @@ +go test fuzz v1 +string("+0/00000") diff --git a/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 b/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 new file mode 100644 index 00000000..0a66beb8 --- /dev/null +++ b/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 @@ -0,0 +1,2 @@ +go test fuzz v1 +string(":") diff --git a/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 b/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 new file mode 100644 index 00000000..db07727d --- /dev/null +++ b/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91") diff --git a/types/structs/structs.go b/types/structs/structs.go new file mode 100644 index 00000000..52929ebf --- /dev/null +++ b/types/structs/structs.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package structs contains the Incomparable type. +package structs + +// Incomparable is a zero-width incomparable type. If added as the +// first field in a struct, it marks that struct as not comparable +// (can't do == or be a map key) and usually doesn't add any width to +// the struct (unless the struct has only small fields). +// +// By making a struct incomparable, you can prevent misuse (prevent +// people from using ==), but also you can shrink generated binaries, +// as the compiler can omit equality funcs from the binary. +type Incomparable [0]func()