From 9f6e97865cb31588fd81dc50bf82043e0b09e343 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 21 Jul 2023 15:42:19 -0700 Subject: [PATCH] allow pushing/pulling to insecure registries (#157) --- api/types.go | 2 ++ cmd/cmd.go | 30 ++++++++++++++------ server/images.go | 68 ++++++++++++++++++++++++++++----------------- server/modelpath.go | 9 ++++-- server/routes.go | 16 +++++++++-- 5 files changed, 86 insertions(+), 39 deletions(-) diff --git a/api/types.go b/api/types.go index b0433596..e9e72d83 100644 --- a/api/types.go +++ b/api/types.go @@ -50,6 +50,7 @@ type DeleteRequest struct { type PullRequest struct { Name string `json:"name"` + Insecure bool `json:"insecure,omitempty"` Username string `json:"username"` Password string `json:"password"` } @@ -63,6 +64,7 @@ type ProgressResponse struct { type PushRequest struct { Name string `json:"name"` + Insecure bool `json:"insecure,omitempty"` Username string `json:"username"` Password string `json:"password"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 4ce698a0..ad4e56a9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -69,7 +69,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { _, err = os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): - if err := pull(args[0]); err != nil { + if err := pull(args[0], false); err != nil { var apiStatusError api.StatusError if !errors.As(err, &apiStatusError) { return err @@ -89,7 +89,12 @@ func RunHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error { client := api.NewClient() - request := api.PushRequest{Name: args[0]} + insecure, err := cmd.Flags().GetBool("insecure") + if err != nil { + return err + } + + request := api.PushRequest{Name: args[0], Insecure: insecure} fn := func(resp api.ProgressResponse) error { fmt.Println(resp.Status) return nil @@ -147,16 +152,21 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { } func PullHandler(cmd *cobra.Command, args []string) error { - return pull(args[0]) + insecure, err := cmd.Flags().GetBool("insecure") + if err != nil { + return err + } + + return pull(args[0], insecure) } -func pull(model string) error { +func pull(model string, insecure bool) error { client := api.NewClient() var currentDigest string var bar *progressbar.ProgressBar - request := api.PullRequest{Name: model} + request := api.PullRequest{Name: model, Insecure: insecure} fn := func(resp api.ProgressResponse) error { if resp.Digest != currentDigest && resp.Digest != "" { currentDigest = resp.Digest @@ -430,6 +440,8 @@ func NewCLI() *cobra.Command { RunE: PullHandler, } + pullCmd.Flags().Bool("insecure", false, "Use an insecure registry") + pushCmd := &cobra.Command{ Use: "push MODEL", Short: "Push a model to a registry", @@ -437,11 +449,13 @@ func NewCLI() *cobra.Command { RunE: PushHandler, } + pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") + listCmd := &cobra.Command{ - Use: "list", + Use: "list", Aliases: []string{"ls"}, - Short: "List models", - RunE: ListHandler, + Short: "List models", + RunE: ListHandler, } deleteCmd := &cobra.Command{ diff --git a/server/images.go b/server/images.go index c635d939..861e58f7 100644 --- a/server/images.go +++ b/server/images.go @@ -22,6 +22,12 @@ import ( "github.com/jmorganca/ollama/parser" ) +type RegistryOptions struct { + Insecure bool + Username string + Password string +} + type Model struct { Name string `json:"name"` ModelPath string @@ -564,7 +570,7 @@ func DeleteModel(name string, fn func(api.ProgressResponse)) error { return nil } -func PushModel(name, username, password string, fn func(api.ProgressResponse)) error { +func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) @@ -586,7 +592,7 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e total += manifest.Config.Size for _, layer := range layers { - exists, err := checkBlobExistence(mp, layer.Digest, username, password) + exists, err := checkBlobExistence(mp, layer.Digest, regOpts) if err != nil { return err } @@ -609,13 +615,13 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e Completed: completed, }) - location, err := startUpload(mp, username, password) + location, err := startUpload(mp, regOpts) if err != nil { log.Printf("couldn't start upload: %v", err) return err } - err = uploadBlob(location, layer, username, password) + err = uploadBlob(location, layer, regOpts) if err != nil { log.Printf("error uploading blob: %v", err) return err @@ -634,7 +640,7 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e Total: total, Completed: completed, }) - url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag) + url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag) headers := map[string]string{ "Content-Type": "application/vnd.docker.distribution.manifest.v2+json", } @@ -644,7 +650,7 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e return err } - resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), username, password) + resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts) if err != nil { return err } @@ -665,12 +671,12 @@ func PushModel(name, username, password string, fn func(api.ProgressResponse)) e return nil } -func PullModel(name, username, password string, fn func(api.ProgressResponse)) error { +func PullModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "pulling manifest"}) - manifest, err := pullModelManifest(mp, username, password) + manifest, err := pullModelManifest(mp, regOpts) if err != nil { return fmt.Errorf("pull model manifest: %q", err) } @@ -680,7 +686,7 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e layers = append(layers, &manifest.Config) for _, layer := range layers { - if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil { + if err := downloadBlob(mp, layer.Digest, regOpts, fn); err != nil { return err } } @@ -715,13 +721,13 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e return nil } -func pullModelManifest(mp ModelPath, username, password string) (*ManifestV2, error) { - url := fmt.Sprintf("%s://%s/v2/%s/manifests/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), mp.Tag) +func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { + url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag) headers := map[string]string{ "Accept": "application/vnd.docker.distribution.manifest.v2+json", } - resp, err := makeRequest("GET", url, headers, nil, username, password) + resp, err := makeRequest("GET", url, headers, nil, regOpts) if err != nil { log.Printf("couldn't get manifest: %v", err) return nil, err @@ -782,10 +788,10 @@ func GetSHA256Digest(r io.Reader) (string, int) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) } -func startUpload(mp ModelPath, username string, password string) (string, error) { - url := fmt.Sprintf("%s://%s/v2/%s/blobs/uploads/", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository()) +func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) { + url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository()) - resp, err := makeRequest("POST", url, nil, nil, username, password) + resp, err := makeRequest("POST", url, nil, nil, regOpts) if err != nil { log.Printf("couldn't start upload: %v", err) return "", err @@ -808,10 +814,10 @@ func startUpload(mp ModelPath, username string, password string) (string, error) } // Function to check if a blob already exists in the Docker registry -func checkBlobExistence(mp ModelPath, digest string, username string, password string) (bool, error) { - url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest) +func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) { + url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest) - resp, err := makeRequest("HEAD", url, nil, nil, username, password) + resp, err := makeRequest("HEAD", url, nil, nil, regOpts) if err != nil { log.Printf("couldn't check for blob: %v", err) return false, err @@ -822,7 +828,7 @@ func checkBlobExistence(mp ModelPath, digest string, username string, password s return resp.StatusCode == http.StatusOK, nil } -func uploadBlob(location string, layer *Layer, username string, password string) error { +func uploadBlob(location string, layer *Layer, regOpts *RegistryOptions) error { // Create URL url := fmt.Sprintf("%s&digest=%s", location, layer.Digest) @@ -845,7 +851,7 @@ func uploadBlob(location string, layer *Layer, username string, password string) return err } - resp, err := makeRequest("PUT", url, headers, f, username, password) + resp, err := makeRequest("PUT", url, headers, f, regOpts) if err != nil { log.Printf("couldn't upload blob: %v", err) return err @@ -861,7 +867,7 @@ func uploadBlob(location string, layer *Layer, username string, password string) return nil } -func downloadBlob(mp ModelPath, digest string, username, password string, fn func(api.ProgressResponse)) error { +func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { fp, err := GetBlobsPath(digest) if err != nil { return err @@ -890,12 +896,12 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun size = fi.Size() } - url := fmt.Sprintf("%s://%s/v2/%s/blobs/%s", mp.ProtocolScheme, mp.Registry, mp.GetNamespaceRepository(), digest) + url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest) headers := map[string]string{ "Range": fmt.Sprintf("bytes=%d-", size), } - resp, err := makeRequest("GET", url, headers, nil, username, password) + resp, err := makeRequest("GET", url, headers, nil, regOpts) if err != nil { log.Printf("couldn't download blob: %v", err) return err @@ -959,7 +965,17 @@ func downloadBlob(mp ModelPath, digest string, username, password string, fn fun return nil } -func makeRequest(method, url string, headers map[string]string, body io.Reader, username, password string) (*http.Response, error) { +func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { + if !strings.HasPrefix(url, "http") { + if regOpts.Insecure { + url = "http://" + url + } else { + url = "https://" + url + } + } + + log.Printf("url = %s", url) + req, err := http.NewRequest(method, url, body) if err != nil { return nil, err @@ -970,8 +986,8 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader, } // TODO: better auth - if username != "" && password != "" { - req.SetBasicAuth(username, password) + if regOpts.Username != "" && regOpts.Password != "" { + req.SetBasicAuth(regOpts.Username, regOpts.Password) } client := &http.Client{ diff --git a/server/modelpath.go b/server/modelpath.go index 02a955f6..0bf36945 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -70,10 +70,13 @@ func (mp ModelPath) GetFullTagname() string { } func (mp ModelPath) GetShortTagname() string { - if mp.Registry == DefaultRegistry && mp.Namespace == DefaultNamespace { - return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag) + if mp.Registry == DefaultRegistry { + if mp.Namespace == DefaultNamespace { + return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag) + } + return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag) } - return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag) + return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag) } func (mp ModelPath) GetManifestPath(createDir bool) (string, error) { diff --git a/server/routes.go b/server/routes.go index c1a32209..0d7edaf0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -93,7 +93,13 @@ func PullModelHandler(c *gin.Context) { ch <- r } - if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil { + regOpts := &RegistryOptions{ + Insecure: req.Insecure, + Username: req.Username, + Password: req.Password, + } + + if err := PullModel(req.Name, regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -115,7 +121,13 @@ func PushModelHandler(c *gin.Context) { ch <- r } - if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil { + regOpts := &RegistryOptions{ + Insecure: req.Insecure, + Username: req.Username, + Password: req.Password, + } + + if err := PushModel(req.Name, regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } }()