diff --git a/api/client.go b/api/client.go index 4a5b97c9..f308b233 100644 --- a/api/client.go +++ b/api/client.go @@ -127,7 +127,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData return nil } -const maxBufferSize = 512 * 1024 // 512KB +const maxBufferSize = 512 * 1000 // 512KB func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { var buf *bytes.Buffer diff --git a/format/bytes.go b/format/bytes.go new file mode 100644 index 00000000..63cc7b00 --- /dev/null +++ b/format/bytes.go @@ -0,0 +1,16 @@ +package format + +import "fmt" + +func HumanBytes(b int64) string { + switch { + case b > 1000*1000*1000: + return fmt.Sprintf("%d GB", b/1000/1000/1000) + case b > 1000*1000: + return fmt.Sprintf("%d MB", b/1000/1000) + case b > 1000: + return fmt.Sprintf("%d KB", b/1000) + default: + return fmt.Sprintf("%d B", b) + } +} diff --git a/llm/llama.go b/llm/llama.go index 8bd11f53..33468cba 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -454,7 +454,7 @@ type PredictRequest struct { Stop []string `json:"stop,omitempty"` } -const maxBufferSize = 512 * 1024 // 512KB +const maxBufferSize = 512 * 1000 // 512KB func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error { prevConvo, err := llm.Decode(ctx, prevContext) diff --git a/llm/llm.go b/llm/llm.go index ef424b5d..6df2a47c 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -60,33 +60,33 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error totalResidentMemory := memory.TotalMemory() switch ggml.ModelType() { case "3B", "7B": - if ggml.FileType() == "F16" && totalResidentMemory < 16*1024*1024 { - return nil, fmt.Errorf("F16 model requires at least 16GB of memory") - } else if totalResidentMemory < 8*1024*1024 { - return nil, fmt.Errorf("model requires at least 8GB of memory") + if ggml.FileType() == "F16" && totalResidentMemory < 16*1000*1000 { + return nil, fmt.Errorf("F16 model requires at least 16 GB of memory") + } else if totalResidentMemory < 8*1000*1000 { + return nil, fmt.Errorf("model requires at least 8 GB of memory") } case "13B": - if ggml.FileType() == "F16" && totalResidentMemory < 32*1024*1024 { - return nil, fmt.Errorf("F16 model requires at least 32GB of memory") - } else if totalResidentMemory < 16*1024*1024 { - return nil, fmt.Errorf("model requires at least 16GB of memory") + if ggml.FileType() == "F16" && totalResidentMemory < 32*1000*1000 { + return nil, fmt.Errorf("F16 model requires at least 32 GB of memory") + } else if totalResidentMemory < 16*1000*1000 { + return nil, fmt.Errorf("model requires at least 16 GB of memory") } case "30B", "34B", "40B": - if ggml.FileType() == "F16" && totalResidentMemory < 64*1024*1024 { - return nil, fmt.Errorf("F16 model requires at least 64GB of memory") - } else if totalResidentMemory < 32*1024*1024 { - return nil, fmt.Errorf("model requires at least 32GB of memory") + if ggml.FileType() == "F16" && totalResidentMemory < 64*1000*1000 { + return nil, fmt.Errorf("F16 model requires at least 64 GB of memory") + } else if totalResidentMemory < 32*1000*1000 { + return nil, fmt.Errorf("model requires at least 32 GB of memory") } case "65B", "70B": - if ggml.FileType() == "F16" && totalResidentMemory < 128*1024*1024 { - return nil, fmt.Errorf("F16 model requires at least 128GB of memory") - } else if totalResidentMemory < 64*1024*1024 { - return nil, fmt.Errorf("model requires at least 64GB of memory") + if ggml.FileType() == "F16" && totalResidentMemory < 128*1000*1000 { + return nil, fmt.Errorf("F16 model requires at least 128 GB of memory") + } else if totalResidentMemory < 64*1000*1000 { + return nil, fmt.Errorf("model requires at least 64 GB of memory") } case "180B": - if ggml.FileType() == "F16" && totalResidentMemory < 512*1024*1024 { + if ggml.FileType() == "F16" && totalResidentMemory < 512*1000*1000 { return nil, fmt.Errorf("F16 model requires at least 512GB of memory") - } else if totalResidentMemory < 128*1024*1024 { + } else if totalResidentMemory < 128*1000*1000 { return nil, fmt.Errorf("model requires at least 128GB of memory") } } diff --git a/server/download.go b/server/download.go index 80914ae3..0c3beb7e 100644 --- a/server/download.go +++ b/server/download.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/format" ) var blobDownloadManager sync.Map @@ -34,6 +35,9 @@ type blobDownload struct { Parts []*blobDownloadPart context.CancelFunc + + done bool + err error references atomic.Int32 } @@ -46,6 +50,12 @@ type blobDownloadPart struct { *blobDownload `json:"-"` } +const ( + numDownloadParts = 64 + minDownloadPartSize int64 = 32 * 1000 * 1000 + maxDownloadPartSize int64 = 256 * 1000 * 1000 +) + func (p *blobDownloadPart) Name() string { return strings.Join([]string{ p.blobDownload.Name, "partial", strconv.Itoa(p.N), @@ -91,9 +101,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - var offset int64 - var size int64 = 64 * 1024 * 1024 + var size = b.Total / numDownloadParts + switch { + case size < minDownloadPartSize: + size = minDownloadPartSize + case size > maxDownloadPartSize: + size = maxDownloadPartSize + } + var offset int64 for offset < b.Total { if offset+size > b.Total { size = b.Total - offset @@ -107,11 +123,15 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R } } - log.Printf("downloading %s in %d part(s)", b.Digest[7:19], len(b.Parts)) + log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)) return nil } -func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) (err error) { +func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) { + b.err = b.run(ctx, requestURL, opts) +} + +func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { defer blobDownloadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) @@ -124,9 +144,8 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis file.Truncate(b.Total) - g, ctx := errgroup.WithContext(ctx) - // TODO(mxyng): download concurrency should be configurable - g.SetLimit(64) + g, _ := errgroup.WithContext(ctx) + g.SetLimit(numDownloadParts) for i := range b.Parts { part := b.Parts[i] if part.Completed == part.Size { @@ -168,7 +187,12 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis } } - return os.Rename(file.Name(), b.Name) + if err := os.Rename(file.Name(), b.Name); err != nil { + return err + } + + b.done = true + return nil } func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { @@ -267,11 +291,8 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) Completed: b.Completed.Load(), }) - if b.Completed.Load() >= b.Total { - // wait for the file to get renamed - if _, err := os.Stat(b.Name); err == nil { - return nil - } + if b.done || b.err != nil { + return b.err } } }