diff --git a/cmd/cmd.go b/cmd/cmd.go index 08254799..d713a35b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -97,7 +97,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } func RunHandler(cmd *cobra.Command, args []string) error { - mp := server.ParseModelPath(args[0]) + insecure, err := cmd.Flags().GetBool("insecure") + if err != nil { + return err + } + + mp, err := server.ParseModelPath(args[0], insecure) + if err != nil { + return err + } + fp, err := mp.GetManifestPath(false) if err != nil { return err @@ -106,7 +115,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { _, err = os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): - if err := pull(args[0], false); err != nil { + if err := pull(args[0], insecure); err != nil { var apiStatusError api.StatusError if !errors.As(err, &apiStatusError) { return err @@ -506,7 +515,11 @@ func generateInteractive(cmd *cobra.Command, model string) error { case strings.HasPrefix(line, "/show"): args := strings.Fields(line) if len(args) > 1 { - mp := server.ParseModelPath(model) + mp, err := server.ParseModelPath(model, false) + if err != nil { + return err + } + manifest, err := server.GetManifest(mp) if err != nil { fmt.Println("error: couldn't get a manifest for this model") @@ -569,7 +582,7 @@ func generateBatch(cmd *cobra.Command, model string) error { } func RunServer(cmd *cobra.Command, _ []string) error { - var host, port = "127.0.0.1", "11434" + host, port := "127.0.0.1", "11434" parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":") if ip := net.ParseIP(parts[0]); ip != nil { @@ -630,7 +643,7 @@ func initializeKeypair() error { return fmt.Errorf("could not create directory %w", err) } - err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600) + err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600) if err != nil { return err } @@ -642,7 +655,7 @@ func initializeKeypair() error { pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey()) - err = os.WriteFile(pubKeyPath, pubKeyData, 0644) + err = os.WriteFile(pubKeyPath, pubKeyData, 0o644) if err != nil { return err } @@ -737,6 +750,7 @@ func NewCLI() *cobra.Command { } runCmd.Flags().Bool("verbose", false, "Show timings for response") + runCmd.Flags().Bool("insecure", false, "Use an insecure registry") serveCmd := &cobra.Command{ Use: "serve", diff --git a/server/images.go b/server/images.go index 2c14ec8d..0c0d428e 100644 --- a/server/images.go +++ b/server/images.go @@ -153,7 +153,10 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) { } func GetModel(name string) (*Model, error) { - mp := ParseModelPath(name) + mp, err := ParseModelPath(name, false) + if err != nil { + return nil, err + } manifest, err := GetManifest(mp) if err != nil { @@ -272,7 +275,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api case "model": fn(api.ProgressResponse{Status: "looking for model"}) embed.model = c.Args - mp := ParseModelPath(c.Args) + + mp, err := ParseModelPath(c.Args, false) + if err != nil { + return err + } + mf, err := GetManifest(mp) if err != nil { modelFile, err := filenameWithPath(path, c.Args) @@ -286,7 +294,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { return err } - mf, err = GetManifest(ParseModelPath(c.Args)) + mf, err = GetManifest(mp) if err != nil { return fmt.Errorf("failed to open file after pull: %v", err) } @@ -674,7 +682,10 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force } func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error { - mp := ParseModelPath(name) + mp, err := ParseModelPath(name, false) + if err != nil { + return err + } manifest := ManifestV2{ SchemaVersion: 2, @@ -806,11 +817,22 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) { } func CopyModel(src, dest string) error { - srcPath, err := ParseModelPath(src).GetManifestPath(false) + srcModelPath, err := ParseModelPath(src, false) if err != nil { return err } - destPath, err := ParseModelPath(dest).GetManifestPath(true) + + srcPath, err := srcModelPath.GetManifestPath(false) + if err != nil { + return err + } + + destModelPath, err := ParseModelPath(dest, false) + if err != nil { + return err + } + + destPath, err := destModelPath.GetManifestPath(true) if err != nil { return err } @@ -832,7 +854,10 @@ func CopyModel(src, dest string) error { } func DeleteModel(name string) error { - mp := ParseModelPath(name) + mp, err := ParseModelPath(name, false) + if err != nil { + return err + } manifest, err := GetManifest(mp) if err != nil { @@ -859,7 +884,10 @@ func DeleteModel(name string) error { return nil } tag := path[:slashIndex] + ":" + path[slashIndex+1:] - fmp := ParseModelPath(tag) + fmp, err := ParseModelPath(tag, false) + if err != nil { + return err + } // skip the manifest we're trying to delete if mp.GetFullTagname() == fmp.GetFullTagname() { @@ -912,7 +940,10 @@ func DeleteModel(name string) error { } func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { - mp := ParseModelPath(name) + mp, err := ParseModelPath(name, regOpts.Insecure) + if err != nil { + return err + } fn(api.ProgressResponse{Status: "retrieving manifest"}) @@ -995,7 +1026,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu } func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { - mp := ParseModelPath(name) + mp, err := ParseModelPath(name, regOpts.Insecure) + if err != nil { + return err + } fn(api.ProgressResponse{Status: "pulling manifest"}) diff --git a/server/modelpath.go b/server/modelpath.go index 49f903ce..e331f1f6 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -1,6 +1,7 @@ package server import ( + "errors" "fmt" "os" "path/filepath" @@ -23,42 +24,54 @@ const ( DefaultProtocolScheme = "https" ) -func ParseModelPath(name string) ModelPath { - slashParts := strings.Split(name, "/") - var registry, namespace, repository, tag string +var ( + ErrInvalidImageFormat = errors.New("invalid image format") + ErrInvalidProtocol = errors.New("invalid protocol scheme") + ErrInsecureProtocol = errors.New("insecure protocol http") +) +func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) { + mp := ModelPath{ + ProtocolScheme: DefaultProtocolScheme, + Registry: DefaultRegistry, + Namespace: DefaultNamespace, + Repository: "", + Tag: DefaultTag, + } + + protocol, rest, didSplit := strings.Cut(name, "://") + if didSplit { + if protocol == "https" || protocol == "http" && allowInsecure { + mp.ProtocolScheme = protocol + name = rest + } else if protocol == "http" && !allowInsecure { + return ModelPath{}, ErrInsecureProtocol + } else { + return ModelPath{}, ErrInvalidProtocol + } + } + + slashParts := strings.Split(name, "/") switch len(slashParts) { case 3: - registry = slashParts[0] - namespace = slashParts[1] - repository = strings.Split(slashParts[2], ":")[0] + mp.Registry = slashParts[0] + mp.Namespace = slashParts[1] + mp.Repository = slashParts[2] case 2: - registry = DefaultRegistry - namespace = slashParts[0] - repository = strings.Split(slashParts[1], ":")[0] + mp.Namespace = slashParts[0] + mp.Repository = slashParts[1] case 1: - registry = DefaultRegistry - namespace = DefaultNamespace - repository = strings.Split(slashParts[0], ":")[0] + mp.Repository = slashParts[0] default: - fmt.Println("Invalid image format.") - return ModelPath{} + return ModelPath{}, ErrInvalidImageFormat } - colonParts := strings.Split(slashParts[len(slashParts)-1], ":") - if len(colonParts) == 2 { - tag = colonParts[1] - } else { - tag = DefaultTag + if repo, tag, didSplit := strings.Cut(mp.Repository, ":"); didSplit { + mp.Repository = repo + mp.Tag = tag } - return ModelPath{ - ProtocolScheme: DefaultProtocolScheme, - Registry: registry, - Namespace: namespace, - Repository: repository, - Tag: tag, - } + return mp, nil } func (mp ModelPath) GetNamespaceRepository() string { diff --git a/server/modelpath_test.go b/server/modelpath_test.go new file mode 100644 index 00000000..2641af90 --- /dev/null +++ b/server/modelpath_test.go @@ -0,0 +1,122 @@ +package server + +import "testing" + +func TestParseModelPath(t *testing.T) { + type input struct { + name string + allowInsecure bool + } + + tests := []struct { + name string + args input + want ModelPath + wantErr error + }{ + { + "full path https", + input{"https://example.com/ns/repo:tag", false}, + ModelPath{ + ProtocolScheme: "https", + Registry: "example.com", + Namespace: "ns", + Repository: "repo", + Tag: "tag", + }, + nil, + }, + { + "full path http without insecure", + input{"http://example.com/ns/repo:tag", false}, + ModelPath{}, + ErrInsecureProtocol, + }, + { + "full path http with insecure", + input{"http://example.com/ns/repo:tag", true}, + ModelPath{ + ProtocolScheme: "http", + Registry: "example.com", + Namespace: "ns", + Repository: "repo", + Tag: "tag", + }, + nil, + }, + { + "full path invalid protocol", + input{"file://example.com/ns/repo:tag", false}, + ModelPath{}, + ErrInvalidProtocol, + }, + { + "no protocol", + input{"example.com/ns/repo:tag", false}, + ModelPath{ + ProtocolScheme: "https", + Registry: "example.com", + Namespace: "ns", + Repository: "repo", + Tag: "tag", + }, + nil, + }, + { + "no registry", + input{"ns/repo:tag", false}, + ModelPath{ + ProtocolScheme: "https", + Registry: DefaultRegistry, + Namespace: "ns", + Repository: "repo", + Tag: "tag", + }, + nil, + }, + { + "no namespace", + input{"repo:tag", false}, + ModelPath{ + ProtocolScheme: "https", + Registry: DefaultRegistry, + Namespace: DefaultNamespace, + Repository: "repo", + Tag: "tag", + }, + nil, + }, + { + "no tag", + input{"repo", false}, + ModelPath{ + ProtocolScheme: "https", + Registry: DefaultRegistry, + Namespace: DefaultNamespace, + Repository: "repo", + Tag: DefaultTag, + }, + nil, + }, + { + "invalid image format", + input{"example.com/a/b/c", false}, + ModelPath{}, + ErrInvalidImageFormat, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseModelPath(tc.args.name, tc.args.allowInsecure) + + if err != tc.wantErr { + t.Errorf("got: %q want: %q", err, tc.wantErr) + } + + if got != tc.want { + t.Errorf("got: %q want: %q", got, tc.want) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 7e78178c..d0dc3d32 100644 --- a/server/routes.go +++ b/server/routes.go @@ -357,7 +357,12 @@ func ListModelsHandler(c *gin.Context) { return nil } tag := path[:slashIndex] + ":" + path[slashIndex+1:] - mp := ParseModelPath(tag) + + mp, err := ParseModelPath(tag, false) + if err != nil { + return err + } + manifest, err := GetManifest(mp) if err != nil { log.Printf("skipping file: %s", fp)