2023-07-25 17:08:51 -04:00
|
|
|
package server
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2023-09-27 19:22:30 -04:00
|
|
|
"encoding/json"
|
2023-07-25 17:08:51 -04:00
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
"net/http"
|
2023-09-27 19:22:30 -04:00
|
|
|
"net/url"
|
2023-07-25 17:08:51 -04:00
|
|
|
"os"
|
2023-09-19 12:36:30 -04:00
|
|
|
"path/filepath"
|
2023-07-25 17:08:51 -04:00
|
|
|
"strconv"
|
|
|
|
|
|
|
|
"github.com/jmorganca/ollama/api"
|
2023-09-27 19:22:30 -04:00
|
|
|
"golang.org/x/sync/errgroup"
|
2023-07-25 17:08:51 -04:00
|
|
|
)
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
// BlobDownloadPart is the resumable state of one byte range of a blob
// download. downloadBlob persists it as JSON (using the field names as keys)
// in a "<blob>-partial-<index>" file that sits next to the partial blob.
//
// Offset is where the range starts within the blob and Size is its length in
// bytes. Completed counts how many bytes of the range, starting at Offset,
// have already been written; the next transfer resumes at Offset+Completed.
type BlobDownloadPart struct {
	Offset, Size, Completed int64
}
|
|
|
|
|
2023-08-15 14:07:19 -04:00
|
|
|
// downloadOpts carries the inputs of a single downloadBlob call.
type downloadOpts struct {
	// mp supplies the registry base URL and the namespace/repository used to
	// build the blob's URL.
	mp ModelPath
	// digest identifies the blob: it selects the local blobs path and is the
	// last element of the registry URL.
	digest string
	// regOpts is handed to makeRequest for the registry requests made while
	// downloading.
	regOpts *RegistryOptions
	// fn receives download progress updates.
	fn func(api.ProgressResponse)
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
// maxRetries is how many times downloadBlob attempts each part of a blob
// before giving up on the whole download.
const maxRetries = 3
|
2023-08-15 14:07:19 -04:00
|
|
|
|
2023-07-25 17:08:51 -04:00
|
|
|
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
2023-08-15 14:07:19 -04:00
|
|
|
func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
|
|
|
fp, err := GetBlobsPath(opts.digest)
|
2023-07-25 17:08:51 -04:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
fi, err := os.Stat(fp)
|
|
|
|
switch {
|
|
|
|
case errors.Is(err, os.ErrNotExist):
|
|
|
|
case err != nil:
|
|
|
|
return err
|
|
|
|
default:
|
2023-08-15 14:07:19 -04:00
|
|
|
opts.fn(api.ProgressResponse{
|
2023-09-27 19:22:30 -04:00
|
|
|
Status: fmt.Sprintf("downloading %s", opts.digest),
|
2023-08-15 14:07:19 -04:00
|
|
|
Digest: opts.digest,
|
2023-09-28 13:00:34 -04:00
|
|
|
Total: fi.Size(),
|
|
|
|
Completed: fi.Size(),
|
2023-07-25 17:08:51 -04:00
|
|
|
})
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
f, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_RDWR, 0644)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
2023-09-27 19:22:30 -04:00
|
|
|
defer f.Close()
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
partFilePaths, err := filepath.Glob(fp + "-partial-*")
|
|
|
|
if err != nil {
|
2023-08-15 14:07:19 -04:00
|
|
|
return err
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
var total, completed int64
|
|
|
|
var parts []BlobDownloadPart
|
|
|
|
for _, partFilePath := range partFilePaths {
|
|
|
|
var part BlobDownloadPart
|
|
|
|
partFile, err := os.Open(partFilePath)
|
2023-08-01 15:34:52 -04:00
|
|
|
if err != nil {
|
|
|
|
return err
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
2023-09-27 19:22:30 -04:00
|
|
|
defer partFile.Close()
|
|
|
|
|
|
|
|
if err := json.NewDecoder(partFile).Decode(&part); err != nil {
|
|
|
|
return err
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
total += part.Size
|
|
|
|
completed += part.Completed
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
parts = append(parts, part)
|
|
|
|
}
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
requestURL := opts.mp.BaseURL()
|
|
|
|
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
if len(parts) == 0 {
|
|
|
|
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts.regOpts)
|
2023-07-25 17:08:51 -04:00
|
|
|
if err != nil {
|
2023-09-27 19:22:30 -04:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
|
|
|
|
|
|
|
// reserve the file
|
|
|
|
f.Truncate(total)
|
|
|
|
|
|
|
|
var offset int64
|
|
|
|
var size int64 = 64 * 1024 * 1024
|
|
|
|
|
|
|
|
for offset < total {
|
|
|
|
if offset+size > total {
|
|
|
|
size = total - offset
|
|
|
|
}
|
|
|
|
|
|
|
|
parts = append(parts, BlobDownloadPart{
|
|
|
|
Offset: offset,
|
|
|
|
Size: size,
|
|
|
|
})
|
|
|
|
|
|
|
|
offset += size
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
pw := &ProgressWriter{
|
|
|
|
status: fmt.Sprintf("downloading %s", opts.digest),
|
|
|
|
digest: opts.digest,
|
|
|
|
total: total,
|
|
|
|
completed: completed,
|
|
|
|
fn: opts.fn,
|
|
|
|
}
|
2023-08-21 21:24:42 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
g, ctx := errgroup.WithContext(ctx)
|
|
|
|
g.SetLimit(64)
|
|
|
|
for i := range parts {
|
|
|
|
part := parts[i]
|
|
|
|
if part.Completed == part.Size {
|
|
|
|
continue
|
|
|
|
}
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
i := i
|
|
|
|
g.Go(func() error {
|
|
|
|
for try := 0; try < maxRetries; try++ {
|
|
|
|
if err := downloadBlobChunk(ctx, f, requestURL, parts, i, pw, opts); err != nil {
|
|
|
|
log.Printf("%s part %d attempt %d failed: %v, retrying", opts.digest[7:19], i, try, err)
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
return errors.New("max retries exceeded")
|
|
|
|
})
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
if err := g.Wait(); err != nil {
|
|
|
|
return err
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
if err := f.Close(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := range parts {
|
|
|
|
if err := os.Remove(f.Name() + "-" + strconv.Itoa(i)); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
return os.Rename(f.Name(), fp)
|
|
|
|
}
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
func downloadBlobChunk(ctx context.Context, f *os.File, requestURL *url.URL, parts []BlobDownloadPart, i int, pw *ProgressWriter, opts downloadOpts) error {
|
|
|
|
part := &parts[i]
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
partName := f.Name() + "-" + strconv.Itoa(i)
|
|
|
|
if err := flushPart(partName, part); err != nil {
|
|
|
|
return err
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
offset := part.Offset + part.Completed
|
|
|
|
w := io.NewOffsetWriter(f, offset)
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
headers := make(http.Header)
|
|
|
|
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", offset, part.Offset+part.Size-1))
|
|
|
|
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts.regOpts)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
n, err := io.Copy(w, io.TeeReader(resp.Body, pw))
|
|
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
|
|
// rollback progress bar
|
|
|
|
pw.completed -= n
|
|
|
|
return err
|
|
|
|
}
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
part.Completed += n
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
return flushPart(partName, part)
|
|
|
|
}
|
|
|
|
|
|
|
|
func flushPart(name string, part *BlobDownloadPart) error {
|
|
|
|
partFile, err := os.OpenFile(name, os.O_CREATE|os.O_RDWR, 0644)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|
2023-09-27 19:22:30 -04:00
|
|
|
defer partFile.Close()
|
2023-07-25 17:08:51 -04:00
|
|
|
|
2023-09-27 19:22:30 -04:00
|
|
|
return json.NewEncoder(partFile).Encode(part)
|
2023-07-25 17:08:51 -04:00
|
|
|
}
|