diff --git a/server/download.go b/server/download.go
index 2eadbce4..6023de31 100644
--- a/server/download.go
+++ b/server/download.go
@@ -12,17 +12,238 @@ import (
 	"os"
 	"path/filepath"
 	"strconv"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"golang.org/x/sync/errgroup"
 
 	"github.com/jmorganca/ollama/api"
-	"golang.org/x/sync/errgroup"
 )
 
-type BlobDownloadPart struct {
+var blobDownloadManager sync.Map
+
+type blobDownload struct {
+	Name   string
+	Digest string
+
+	Total     int64
+	Completed atomic.Int64
+
+	*os.File
+	Parts []*blobDownloadPart
+
+	done chan struct{}
+	context.CancelFunc
+	refCount atomic.Int32
+}
+
+type blobDownloadPart struct {
 	Offset    int64
 	Size      int64
 	Completed int64
 }
 
+func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+	b.done = make(chan struct{}, 1)
+
+	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
+	if err != nil {
+		return err
+	}
+
+	for _, partFilePath := range partFilePaths {
+		part, err := b.readPart(partFilePath)
+		if err != nil {
+			return err
+		}
+
+		b.Total += part.Size
+		b.Completed.Add(part.Completed)
+		b.Parts = append(b.Parts, part)
+	}
+
+	if len(b.Parts) == 0 {
+		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts)
+		if err != nil {
+			return err
+		}
+		defer resp.Body.Close()
+
+		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+
+		var offset int64
+		var size int64 = 64 * 1024 * 1024
+
+		for offset < b.Total {
+			if offset+size > b.Total {
+				size = b.Total - offset
+			}
+
+			partName := b.Name + "-partial-" + strconv.Itoa(len(b.Parts))
+			part := blobDownloadPart{Offset: offset, Size: size}
+			if err := b.writePart(partName, &part); err != nil {
+				return err
+			}
+
+			b.Parts = append(b.Parts, &part)
+
+			offset += size
+		}
+	}
+
+	log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts))
+	return nil
+}
+
+func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) {
+	defer blobDownloadManager.Delete(b.Digest)
+
+	ctx, b.CancelFunc = context.WithCancel(ctx)
+
+	b.File, err = os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
+	if err != nil {
+		return err
+	}
+	defer b.Close()
+
+	b.Truncate(b.Total)
+
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(64)
+	for i := range b.Parts {
+		part := b.Parts[i]
+		if part.Completed == part.Size {
+			continue
+		}
+
+		i := i
+		g.Go(func() error {
+			for try := 0; try < maxRetries; try++ {
+				err := b.downloadChunk(ctx, requestURL, i, opts)
+				switch {
+				case errors.Is(err, context.Canceled):
+					return err
+				case err != nil:
+					log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], i, try, err)
+					continue
+				default:
+					return nil
+				}
+			}
+
+			return errors.New("max retries exceeded")
+		})
+	}
+
+	if err := g.Wait(); err != nil {
+		return err
+	}
+
+	if err := b.Close(); err != nil {
+		return err
+	}
+
+	for i := range b.Parts {
+		if err := os.Remove(b.File.Name() + "-" + strconv.Itoa(i)); err != nil {
+			return err
+		}
+	}
+
+	if err := os.Rename(b.File.Name(), b.Name); err != nil {
+		return err
+	}
+
+	close(b.done)
+	return nil
+}
+
+func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i int, opts *RegistryOptions) error {
+	part := b.Parts[i]
+
+	partName := b.File.Name() + "-" + strconv.Itoa(i)
+	offset := part.Offset + part.Completed
+	w := io.NewOffsetWriter(b.File, offset)
+
+	headers := make(http.Header)
+	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
+	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	n, err := io.Copy(w, io.TeeReader(resp.Body, b))
+	if err != nil && !errors.Is(err, io.EOF) {
+		// rollback progress
+		b.Completed.Add(-n)
+		return err
+	}
+
+	part.Completed += n
+	return b.writePart(partName, part)
+}
+
+func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
+	var part blobDownloadPart
+	partFile, err := os.Open(partName)
+	if err != nil {
+		return nil, err
+	}
+	defer partFile.Close()
+
+	if err := json.NewDecoder(partFile).Decode(&part); err != nil {
+		return nil, err
+	}
+
+	return &part, nil
+}
+
+func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
+	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
+	if err != nil {
+		return err
+	}
+	defer partFile.Close()
+
+	return json.NewEncoder(partFile).Encode(part)
+}
+
+func (b *blobDownload) Write(p []byte) (n int, err error) {
+	n = len(p)
+	b.Completed.Add(int64(n))
+	return n, nil
+}
+
+func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
+	b.refCount.Add(1)
+
+	ticker := time.NewTicker(60 * time.Millisecond)
+	for {
+		select {
+		case <-ticker.C:
+		case <-ctx.Done():
+			if b.refCount.Add(-1) == 0 {
+				b.CancelFunc()
+			}
+
+			return ctx.Err()
+		}
+
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("downloading %s", b.Digest),
+			Digest:    b.Digest,
+			Total:     b.Total,
+			Completed: b.Completed.Load(),
+		})
+
+		if b.Completed.Load() >= b.Total {
+			<-b.done
+			return nil
+		}
+	}
+}
+
 type downloadOpts struct {
 	mp      ModelPath
 	digest  string
@@ -55,154 +276,17 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 		return nil
 	}
 
-	f, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_RDWR, 0644)
-	if err != nil {
-		return err
-	}
-	defer f.Close()
-
-	partFilePaths, err := filepath.Glob(fp + "-partial-*")
-	if err != nil {
-		return err
-	}
-
-	var total, completed int64
-	var parts []BlobDownloadPart
-	for _, partFilePath := range partFilePaths {
-		var part BlobDownloadPart
-		partFile, err := os.Open(partFilePath)
-		if err != nil {
-			return err
-		}
-		defer partFile.Close()
-
-		if err := json.NewDecoder(partFile).Decode(&part); err != nil {
+	value, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
+	blobDownload := value.(*blobDownload)
+	if !ok {
+		requestURL := opts.mp.BaseURL()
+		requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
+		if err := blobDownload.Prepare(ctx, requestURL, opts.regOpts); err != nil {
 			return err
 		}
 
-		total += part.Size
-		completed += part.Completed
-
-		parts = append(parts, part)
+		go blobDownload.Run(context.Background(), requestURL, opts.regOpts)
 	}
 
-	requestURL := opts.mp.BaseURL()
-	requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
-
-	if len(parts) == 0 {
-		resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts.regOpts)
-		if err != nil {
-			return err
-		}
-		defer resp.Body.Close()
-
-		total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
-
-		// reserve the file
-		f.Truncate(total)
-
-		var offset int64
-		var size int64 = 64 * 1024 * 1024
-
-		for offset < total {
-			if offset+size > total {
-				size = total - offset
-			}
-
-			parts = append(parts, BlobDownloadPart{
-				Offset: offset,
-				Size:   size,
-			})
-
-			offset += size
-		}
-	}
-
-	pw := &ProgressWriter{
-		status:    fmt.Sprintf("downloading %s", opts.digest),
-		digest:    opts.digest,
-		total:     total,
-		completed: completed,
-		fn:        opts.fn,
-	}
-
-	g, ctx := errgroup.WithContext(ctx)
-	g.SetLimit(64)
-	for i := range parts {
-		part := parts[i]
-		if part.Completed == part.Size {
-			continue
-		}
-
-		i := i
-		g.Go(func() error {
-			for try := 0; try < maxRetries; try++ {
-				if err := downloadBlobChunk(ctx, f, requestURL, parts, i, pw, opts); err != nil {
-					log.Printf("%s part %d attempt %d failed: %v, retrying", opts.digest[7:19], i, try, err)
-					continue
-				}
-
-				return nil
-			}
-
-			return errors.New("max retries exceeded")
-		})
-	}
-
-	if err := g.Wait(); err != nil {
-		return err
-	}
-
-	if err := f.Close(); err != nil {
-		return err
-	}
-
-	for i := range parts {
-		if err := os.Remove(f.Name() + "-" + strconv.Itoa(i)); err != nil {
-			return err
-		}
-	}
-
-	return os.Rename(f.Name(), fp)
-}
-
-func downloadBlobChunk(ctx context.Context, f *os.File, requestURL *url.URL, parts []BlobDownloadPart, i int, pw *ProgressWriter, opts downloadOpts) error {
-	part := &parts[i]
-
-	partName := f.Name() + "-" + strconv.Itoa(i)
-	if err := flushPart(partName, part); err != nil {
-		return err
-	}
-
-	offset := part.Offset + part.Completed
-	w := io.NewOffsetWriter(f, offset)
-
-	headers := make(http.Header)
-	headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
-	resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
-	if err != nil {
-		return err
-	}
-	defer resp.Body.Close()
-
-	n, err := io.Copy(w, io.TeeReader(resp.Body, pw))
-	if err != nil && !errors.Is(err, io.EOF) {
-		// rollback progress bar
-		pw.completed -= n
-		return err
-	}
-
-	part.Completed += n
-
-	return flushPart(partName, part)
-}
-
-func flushPart(name string, part *BlobDownloadPart) error {
-	partFile, err := os.OpenFile(name, os.O_CREATE|os.O_RDWR, 0644)
-	if err != nil {
-		return err
-	}
-	defer partFile.Close()
-
-	return json.NewEncoder(partFile).Encode(part)
+	return blobDownload.Wait(ctx, opts.fn)
 }
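
Note on the pattern this patch introduces: concurrent requests for the same digest are deduplicated with sync.Map.LoadOrStore, the transfer runs once in a background goroutine, and every caller observes progress through a reference-counted Wait that cancels the transfer when the last waiter gives up. The sketch below is only an illustration of that pattern under simplified assumptions, not code from the patch: the download type, its run/wait methods, the fixed "digest" key, and the simulated progress numbers are hypothetical stand-ins for blobDownload, Run, and Wait above.

package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// download mirrors the shape of blobDownload: shared progress counters,
// a done channel, an embedded cancel func, and a count of active waiters.
type download struct {
	total     int64
	completed atomic.Int64

	done chan struct{}
	context.CancelFunc
	refCount atomic.Int32
}

// downloads plays the role of blobDownloadManager: digest -> *download.
var downloads sync.Map

// run simulates the background transfer (the role Run plays in the patch).
func (d *download) run(ctx context.Context) {
	defer downloads.Delete("digest") // placeholder key for the sketch
	for d.completed.Load() < d.total {
		select {
		case <-ctx.Done():
			return
		case <-time.After(10 * time.Millisecond):
			d.completed.Add(d.total / 10) // pretend a chunk arrived
		}
	}
	close(d.done)
}

// wait reports progress to one caller; the last caller to leave cancels the transfer.
func (d *download) wait(ctx context.Context, fn func(completed, total int64)) error {
	d.refCount.Add(1)
	ticker := time.NewTicker(20 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
		case <-ctx.Done():
			if d.refCount.Add(-1) == 0 {
				d.CancelFunc()
			}
			return ctx.Err()
		}

		fn(d.completed.Load(), d.total)
		if d.completed.Load() >= d.total {
			<-d.done
			return nil
		}
	}
}

func main() {
	// LoadOrStore ensures only the first caller starts the transfer;
	// later callers attach to the same *download and just wait on it.
	value, loaded := downloads.LoadOrStore("digest", &download{total: 100, done: make(chan struct{})})
	d := value.(*download)
	if !loaded {
		ctx, cancel := context.WithCancel(context.Background())
		d.CancelFunc = cancel
		go d.run(ctx)
	}

	if err := d.wait(context.Background(), func(completed, total int64) {
		fmt.Printf("progress: %d/%d\n", completed, total)
	}); err != nil {
		fmt.Println("download interrupted:", err)
	}
}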