diff --git a/server/download.go b/server/download.go index dd7753b3..531219fd 100644 --- a/server/download.go +++ b/server/download.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/format" ) var blobDownloadManager sync.Map @@ -47,6 +48,12 @@ type blobDownloadPart struct { *blobDownload `json:"-"` } +const ( + numDownloadParts = 64 + minDownloadPartSize int64 = 32 * 1000 * 1000 + maxDownloadPartSize int64 = 256 * 1000 * 1000 +) + func (p *blobDownloadPart) Name() string { return strings.Join([]string{ p.blobDownload.Name, "partial", strconv.Itoa(p.N), @@ -92,9 +99,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - var offset int64 - var size int64 = 64 * 1024 * 1024 + var size = b.Total / numDownloadParts + switch { + case size < minDownloadPartSize: + size = minDownloadPartSize + case size > maxDownloadPartSize: + size = maxDownloadPartSize + } + var offset int64 for offset < b.Total { if offset+size > b.Total { size = b.Total - offset @@ -108,7 +121,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R } } - log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts)) + log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)) return nil } @@ -126,8 +139,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis file.Truncate(b.Total) g, _ := errgroup.WithContext(ctx) - // TODO(mxyng): download concurrency should be configurable - g.SetLimit(64) + g.SetLimit(numDownloadParts) for i := range b.Parts { part := b.Parts[i] if part.Completed == part.Size {