From 868e3b31c76b7e51599d6f0dac16a6b004f90bca Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 25 Jul 2023 17:08:51 -0400 Subject: [PATCH] allow for concurrent pulls of the same files --- server/download.go | 217 +++++++++++++++++++++++++++++++++++++++++++++ server/images.go | 9 +- server/routes.go | 11 ++- 3 files changed, 231 insertions(+), 6 deletions(-) create mode 100644 server/download.go diff --git a/server/download.go b/server/download.go new file mode 100644 index 00000000..1a696f5f --- /dev/null +++ b/server/download.go @@ -0,0 +1,217 @@ +package server + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "path" + "strconv" + "sync" + "time" + + "github.com/jmorganca/ollama/api" +) + +type FileDownload struct { + Digest string + FilePath string + Total int64 + Completed int64 +} + +var inProgress sync.Map // map of digests currently being downloaded to their current download progress + +// downloadBlob downloads a blob from the registry and stores it in the blobs directory +func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { + fp, err := GetBlobsPath(digest) + if err != nil { + return err + } + + if fi, _ := os.Stat(fp); fi != nil { + // we already have the file, so return + fn(api.ProgressResponse{ + Digest: digest, + Total: int(fi.Size()), + Completed: int(fi.Size()), + }) + + return nil + } + + fileDownload := &FileDownload{ + Digest: digest, + FilePath: fp, + Total: 1, // dummy value to indicate that we don't know the total size yet + Completed: 0, + } + + _, downloading := inProgress.LoadOrStore(digest, fileDownload) + if downloading { + // this is another client requesting the server to download the same blob concurrently + return monitorDownload(ctx, mp, regOpts, fileDownload, fn) + } + resp, err := requestDownload(ctx, mp, regOpts, fileDownload) + if err != nil { + return err + } + return doDownload(ctx, fileDownload, resp, fn) +} + +var downloadMu sync.Mutex // mutex to check to resume a download while monitoring + +// monitorDownload monitors the download progress of a blob and resumes it if it is interrupted +func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error { + tick := time.NewTicker(time.Second) + for range tick.C { + downloadMu.Lock() + val, downloading := inProgress.Load(f.Digest) + if !downloading { + // check once again if the download is complete + if fi, _ := os.Stat(f.FilePath); fi != nil { + downloadMu.Unlock() + // successfull download while monitoring + fn(api.ProgressResponse{ + Digest: f.Digest, + Total: int(fi.Size()), + Completed: int(fi.Size()), + }) + return nil + } + // resume the download + resp, err := requestDownload(ctx, mp, regOpts, f) + if err != nil { + return fmt.Errorf("resume: %w", err) + } + inProgress.Store(f.Digest, f) + downloadMu.Unlock() + return doDownload(ctx, f, resp, fn) + } + downloadMu.Unlock() + f, ok := val.(*FileDownload) + if !ok { + return fmt.Errorf("invalid type for in progress download: %T", val) + } + fn(api.ProgressResponse{ + Status: fmt.Sprintf("downloading %s", f.Digest), + Digest: f.Digest, + Total: int(f.Total), + Completed: int(f.Completed), + }) + } + return nil +} + +var chunkSize = 1024 * 1024 // 1 MiB in bytes + +// requestDownload requests a blob from the registry and returns the response +func requestDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload) (*http.Response, error) { + var size int64 + + fi, err := os.Stat(f.FilePath + "-partial") + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return nil, fmt.Errorf("stat: %w", err) + default: + size = fi.Size() + // Ensure the size is divisible by the chunk size by removing excess bytes + size -= size % int64(chunkSize) + + err := os.Truncate(f.FilePath+"-partial", size) + if err != nil { + return nil, fmt.Errorf("truncate: %w", err) + } + } + + url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), f.Digest) + headers := map[string]string{ + "Range": fmt.Sprintf("bytes=%d-", size), + } + + resp, err := makeRequest("GET", url, headers, nil, regOpts) + if err != nil { + log.Printf("couldn't download blob: %v", err) + return nil, err + } + // resp MUST be closed by doDownload, which should follow this function + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + err = os.MkdirAll(path.Dir(f.FilePath), 0o700) + if err != nil { + return nil, fmt.Errorf("make blobs directory: %w", err) + } + + remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + f.Completed = size + f.Total = remaining + f.Completed + + inProgress.Store(f.Digest, f) + return resp, nil +} + +// doDownload downloads a blob from the registry and stores it in the blobs directory +func doDownload(ctx context.Context, f *FileDownload, resp *http.Response, fn func(api.ProgressResponse)) error { + defer resp.Body.Close() + out, err := os.OpenFile(f.FilePath+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("open file: %w", err) + } + defer out.Close() +outerLoop: + for { + select { + case <-ctx.Done(): + // handle client request cancellation + inProgress.Delete(f.Digest) + return nil + default: + fn(api.ProgressResponse{ + Status: fmt.Sprintf("downloading %s", f.Digest), + Digest: f.Digest, + Total: int(f.Total), + Completed: int(f.Completed), + }) + + if f.Completed >= f.Total { + if err := out.Close(); err != nil { + return err + } + + if err := os.Rename(f.FilePath+"-partial", f.FilePath); err != nil { + fn(api.ProgressResponse{ + Status: fmt.Sprintf("error renaming file: %v", err), + Digest: f.Digest, + Total: int(f.Total), + Completed: int(f.Completed), + }) + return err + } + + break outerLoop + } + } + + n, err := io.CopyN(out, resp.Body, int64(chunkSize)) + if err != nil && !errors.Is(err, io.EOF) { + return err + } + f.Completed += n + + inProgress.Store(f.Digest, f) + } + + inProgress.Delete(f.Digest) + + log.Printf("success getting %s\n", f.Digest) + return nil +} diff --git a/server/images.go b/server/images.go index 5796d2f4..01a71997 100644 --- a/server/images.go +++ b/server/images.go @@ -3,6 +3,7 @@ package server import ( "bufio" "bytes" + "context" "crypto/sha256" "encoding/json" "errors" @@ -232,7 +233,7 @@ func filenameWithPath(path, f string) (string, error) { return f, nil } -func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error { +func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { mf, err := os.Open(path) if err != nil { fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) @@ -265,7 +266,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e // the model file does not exist, try pulling it if errors.Is(err, os.ErrNotExist) { fn(api.ProgressResponse{Status: "pulling model file"}) - if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil { + if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { return err } mf, err = GetManifest(ParseModelPath(c.Args)) @@ -900,7 +901,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon return nil } -func PullModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "pulling manifest"}) @@ -915,7 +916,7 @@ func PullModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon layers = append(layers, &manifest.Config) for _, layer := range layers { - if err := downloadBlob(mp, layer.Digest, regOpts, fn); err != nil { + if err := downloadBlob(ctx, mp, layer.Digest, regOpts, fn); err != nil { return err } } diff --git a/server/routes.go b/server/routes.go index 9182de3b..731e8078 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "errors" "fmt" @@ -200,7 +201,10 @@ func PullModelHandler(c *gin.Context) { Password: req.Password, } - if err := PullModel(req.Name, regOpts, fn); err != nil { + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() + + if err := PullModel(ctx, req.Name, regOpts, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -250,7 +254,10 @@ func CreateModelHandler(c *gin.Context) { ch <- resp } - if err := CreateModel(req.Name, req.Path, fn); err != nil { + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() + + if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { ch <- gin.H{"error": err.Error()} } }()