diff --git a/server/download.go b/server/download.go
index d31797da..4985fb53 100644
--- a/server/download.go
+++ b/server/download.go
@@ -20,6 +20,7 @@ import (
 	"time"
 
 	"golang.org/x/sync/errgroup"
+	"golang.org/x/sync/semaphore"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/format"
@@ -150,8 +151,7 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 
 	_ = file.Truncate(b.Total)
 
-	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numDownloadParts)
+	g, inner := NewLimitGroup(ctx, numDownloadParts)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed == part.Size {
@@ -378,3 +378,71 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
 
 	return download.Wait(ctx, opts.fn)
 }
+
+// LimitGroup is an errgroup.Group whose concurrency is bounded by a weighted
+// semaphore. Every task started with Go acquires the group's current weight
+// from Semaphore before it runs and releases the same amount when it returns,
+// so lowering the weight (see SetLimit) lets more tasks run at once.
+type LimitGroup struct {
+	*errgroup.Group
+	context.Context
+	Semaphore *semaphore.Weighted
+
+	// weight is the amount each newly started task acquires; maxWeight is
+	// the total capacity Semaphore was created with.
+	weight, maxWeight int64
+}
+
+// NewLimitGroup returns a LimitGroup with a semaphore of capacity n, together
+// with the context derived by errgroup.WithContext. The initial weight is n,
+// so for n > 0 tasks run one at a time until SetLimit is called.
+func NewLimitGroup(ctx context.Context, n int64) (*LimitGroup, context.Context) {
+	g, ctx := errgroup.WithContext(ctx)
+	return &LimitGroup{
+		Group:     g,
+		Context:   ctx,
+		Semaphore: semaphore.NewWeighted(n),
+		weight:    n,
+		maxWeight: n,
+	}, ctx
+}
+
+// Go waits until the current weight can be acquired from Semaphore and then
+// runs fn in the underlying errgroup. If the group's context is done first,
+// fn is never started.
+func (g *LimitGroup) Go(fn func() error) {
+	weight := g.weight
+	if err := g.Semaphore.Acquire(g.Context, weight); err != nil {
+		// Acquire fails only when the context is done, and in that case
+		// nothing was taken from the semaphore.
+		return
+	}
+
+	if g.Context.Err() != nil {
+		// Acquire may succeed on an already-done context; hand the weight
+		// back rather than leaking it.
+		g.Semaphore.Release(weight)
+		return
+	}
+
+	g.Group.Go(func() error {
+		defer g.Semaphore.Release(weight)
+		return fn()
+	})
+}
+
+// SetLimit changes the weight charged to tasks started afterwards to
+// maxWeight / n (integer division). Non-positive n is ignored, and n is
+// capped at maxWeight so the weight never rounds down to zero, which would
+// disable the limit entirely.
+//
+// NOTE(review): weight is a plain field shared with Go; confirm callers do
+// not call SetLimit and Go concurrently.
+func (g *LimitGroup) SetLimit(n int64) {
+	if n > g.maxWeight {
+		n = g.maxWeight
+	}
+	if n > 0 {
+		g.weight = g.maxWeight / n
+	}
+}