ollama/server/download.go

209 lines
4.1 KiB
Go
Raw Normal View History

package server
import (
"context"
2023-09-27 19:22:30 -04:00
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
2023-09-27 19:22:30 -04:00
"net/url"
"os"
2023-09-19 12:36:30 -04:00
"path/filepath"
"strconv"
"github.com/jmorganca/ollama/api"
2023-09-27 19:22:30 -04:00
"golang.org/x/sync/errgroup"
)
2023-09-27 19:22:30 -04:00
// BlobDownloadPart records the download state of one contiguous byte range
// of a blob. It is encoded to and decoded from JSON part files, so the field
// names and their order define the on-disk format.
type BlobDownloadPart struct {
	// Offset is the position of the part's first byte within the blob and
	// Size is the part's length in bytes.
	Offset, Size int64
	// Completed is the number of bytes of this part already written.
	Completed int64
}
2023-08-15 14:07:19 -04:00
// downloadOpts bundles the inputs of a single blob download.
type downloadOpts struct {
	// mp supplies the registry base URL and the namespace/repository used
	// to build the blob request URL.
	mp ModelPath
	// digest identifies the blob; it selects both the local blob path and
	// the final segment of the request URL.
	digest string
	// regOpts is passed through unchanged to every registry request.
	regOpts *RegistryOptions
	// fn receives progress updates while the blob is downloaded.
	fn func(api.ProgressResponse)
}
2023-09-27 19:22:30 -04:00
// maxRetries is the number of attempts made to download a single blob part
// before the whole download fails.
const maxRetries = 3
2023-08-15 14:07:19 -04:00
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
2023-08-15 14:07:19 -04:00
func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return err
}
2023-09-27 19:22:30 -04:00
fi, err := os.Stat(fp)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return err
default:
2023-08-15 14:07:19 -04:00
opts.fn(api.ProgressResponse{
2023-09-27 19:22:30 -04:00
Status: fmt.Sprintf("downloading %s", opts.digest),
2023-08-15 14:07:19 -04:00
Digest: opts.digest,
2023-09-28 13:00:34 -04:00
Total: fi.Size(),
Completed: fi.Size(),
})
return nil
}
2023-09-27 19:22:30 -04:00
f, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return err
}
2023-09-27 19:22:30 -04:00
defer f.Close()
2023-09-27 19:22:30 -04:00
partFilePaths, err := filepath.Glob(fp + "-partial-*")
if err != nil {
2023-08-15 14:07:19 -04:00
return err
}
2023-09-27 19:22:30 -04:00
var total, completed int64
var parts []BlobDownloadPart
for _, partFilePath := range partFilePaths {
var part BlobDownloadPart
partFile, err := os.Open(partFilePath)
2023-08-01 15:34:52 -04:00
if err != nil {
return err
}
2023-09-27 19:22:30 -04:00
defer partFile.Close()
if err := json.NewDecoder(partFile).Decode(&part); err != nil {
return err
}
2023-09-27 19:22:30 -04:00
total += part.Size
completed += part.Completed
2023-09-27 19:22:30 -04:00
parts = append(parts, part)
}
2023-09-27 19:22:30 -04:00
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
2023-09-27 19:22:30 -04:00
if len(parts) == 0 {
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts.regOpts)
if err != nil {
2023-09-27 19:22:30 -04:00
return err
}
defer resp.Body.Close()
total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
// reserve the file
f.Truncate(total)
var offset int64
var size int64 = 64 * 1024 * 1024
for offset < total {
if offset+size > total {
size = total - offset
}
parts = append(parts, BlobDownloadPart{
Offset: offset,
Size: size,
})
offset += size
}
}
2023-09-27 19:22:30 -04:00
pw := &ProgressWriter{
status: fmt.Sprintf("downloading %s", opts.digest),
digest: opts.digest,
total: total,
completed: completed,
fn: opts.fn,
}
2023-08-21 21:24:42 -04:00
2023-09-27 19:22:30 -04:00
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(64)
for i := range parts {
part := parts[i]
if part.Completed == part.Size {
continue
}
2023-09-27 19:22:30 -04:00
i := i
g.Go(func() error {
for try := 0; try < maxRetries; try++ {
if err := downloadBlobChunk(ctx, f, requestURL, parts, i, pw, opts); err != nil {
log.Printf("%s part %d attempt %d failed: %v, retrying", opts.digest[7:19], i, try, err)
continue
}
return nil
}
return errors.New("max retries exceeded")
})
}
2023-09-27 19:22:30 -04:00
if err := g.Wait(); err != nil {
return err
}
2023-09-27 19:22:30 -04:00
if err := f.Close(); err != nil {
return err
}
for i := range parts {
if err := os.Remove(f.Name() + "-" + strconv.Itoa(i)); err != nil {
return err
}
}
2023-09-27 19:22:30 -04:00
return os.Rename(f.Name(), fp)
}
2023-09-27 19:22:30 -04:00
func downloadBlobChunk(ctx context.Context, f *os.File, requestURL *url.URL, parts []BlobDownloadPart, i int, pw *ProgressWriter, opts downloadOpts) error {
part := &parts[i]
2023-09-27 19:22:30 -04:00
partName := f.Name() + "-" + strconv.Itoa(i)
if err := flushPart(partName, part); err != nil {
return err
}
2023-09-27 19:22:30 -04:00
offset := part.Offset + part.Completed
w := io.NewOffsetWriter(f, offset)
2023-09-27 19:22:30 -04:00
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
if err != nil {
return err
}
defer resp.Body.Close()
2023-09-27 19:22:30 -04:00
n, err := io.Copy(w, io.TeeReader(resp.Body, pw))
if err != nil && !errors.Is(err, io.EOF) {
// rollback progress bar
pw.completed -= n
return err
}
2023-09-27 19:22:30 -04:00
part.Completed += n
2023-09-27 19:22:30 -04:00
return flushPart(partName, part)
}
func flushPart(name string, part *BlobDownloadPart) error {
partFile, err := os.OpenFile(name, os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return err
}
2023-09-27 19:22:30 -04:00
defer partFile.Close()
2023-09-27 19:22:30 -04:00
return json.NewEncoder(partFile).Encode(part)
}