diff --git a/server/download.go b/server/download.go
new file mode 100644
index 00000000..7aad599d
--- /dev/null
+++ b/server/download.go
@@ -0,0 +1,289 @@
+package server
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/jmorganca/ollama/api"
+)
+
+// FileDownload tracks the progress of a single blob download.
+//
+// Total and Completed are written by the goroutine running doDownload while
+// goroutines running monitorDownload read them, so once a FileDownload has been
+// stored in inProgress those two fields are only accessed through sync/atomic.
+type FileDownload struct {
+	Digest    string
+	FilePath  string
+	Total     int64
+	Completed int64
+}
+
+var inProgress sync.Map // map of digests currently being downloaded to their current download progress
+
+// downloadBlob downloads a blob from the registry and stores it in the blobs directory.
+//
+// If the blob is already on disk, fn is called once with its size and nil is
+// returned. If another request is already downloading the same digest, this
+// call monitors that download instead of starting a second one.
+func downloadBlob(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+	fp, err := GetBlobsPath(digest)
+	if err != nil {
+		return err
+	}
+
+	if fi, _ := os.Stat(fp); fi != nil {
+		// we already have the file, so return
+		fn(api.ProgressResponse{
+			Digest:    digest,
+			Total:     int(fi.Size()),
+			Completed: int(fi.Size()),
+		})
+
+		return nil
+	}
+
+	fileDownload := &FileDownload{
+		Digest:    digest,
+		FilePath:  fp,
+		Total:     1, // dummy value to indicate that we don't know the total size yet
+		Completed: 0,
+	}
+
+	// LoadOrStore claims the digest atomically, so only one caller starts the download.
+	if _, downloading := inProgress.LoadOrStore(digest, fileDownload); downloading {
+		// this is another client requesting the server to download the same blob concurrently
+		return monitorDownload(ctx, mp, regOpts, fileDownload, fn)
+	}
+	return doDownload(ctx, mp, regOpts, fileDownload, fn)
+}
+
+var downloadMu sync.Mutex // mutex to check to resume a download while monitoring
+
+// pollDownload checks the state of the download for f.Digest once.
+//
+// It returns the progress to report (nil when there is nothing to report),
+// done when the finished blob is on disk, and resume when the download stopped
+// early and the caller has claimed the digest and must call doDownload itself.
+func pollDownload(f *FileDownload) (progress *api.ProgressResponse, done, resume bool, err error) {
+	downloadMu.Lock()
+	defer downloadMu.Unlock()
+
+	val, downloading := inProgress.Load(f.Digest)
+	if downloading {
+		cur, ok := val.(*FileDownload)
+		if !ok {
+			return nil, false, false, fmt.Errorf("invalid type for in progress download: %T", val)
+		}
+		return &api.ProgressResponse{
+			Status:    fmt.Sprintf("downloading %s", cur.Digest),
+			Digest:    cur.Digest,
+			Total:     int(atomic.LoadInt64(&cur.Total)),
+			Completed: int(atomic.LoadInt64(&cur.Completed)),
+		}, false, false, nil
+	}
+
+	// check once again if the download is complete
+	if fi, _ := os.Stat(f.FilePath); fi != nil {
+		// successful download while monitoring
+		return &api.ProgressResponse{
+			Digest:    f.Digest,
+			Total:     int(fi.Size()),
+			Completed: int(fi.Size()),
+		}, true, false, nil
+	}
+
+	// The download stopped without producing the blob. LoadOrStore (rather than
+	// Store) is used to claim the resume so that a concurrent downloadBlob call,
+	// which does not hold downloadMu, cannot end up downloading the same digest.
+	if _, taken := inProgress.LoadOrStore(f.Digest, f); taken {
+		return nil, false, false, nil
+	}
+	return nil, false, true, nil
+}
+
+// monitorDownload monitors the download progress of a blob and resumes it if it is interrupted.
+//
+// It polls once per second and reports progress through fn. It returns nil once
+// the blob is on disk, ctx.Err() if ctx is cancelled first, or the result of
+// doDownload if this caller took over the download.
+func monitorDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
+	tick := time.NewTicker(time.Second)
+	defer tick.Stop() // release the ticker on every return path
+
+	for {
+		select {
+		case <-ctx.Done():
+			// the client went away, stop watching a download nobody is waiting for
+			return ctx.Err()
+		case <-tick.C:
+		}
+
+		progress, done, resume, err := pollDownload(f)
+		if err != nil {
+			return err
+		}
+		if progress != nil {
+			// report outside of downloadMu so a slow client cannot stall other monitors
+			fn(*progress)
+		}
+
+		switch {
+		case done:
+			return nil
+		case resume:
+			return doDownload(ctx, mp, regOpts, f, fn)
+		}
+	}
+}
+
+var chunkSize = 1024 * 1024 // 1 MiB in bytes
+
+// doDownload downloads a blob from the registry and stores it in the blobs directory.
+//
+// The caller must have stored f in inProgress under f.Digest; the entry is
+// removed when doDownload returns, whether it succeeded or not, so that
+// monitors can detect an interrupted download and resume it. Data is written to
+// FilePath+"-partial" and renamed to FilePath once complete; an existing
+// partial file is resumed with a Range request.
+//
+// It returns ctx.Err() when ctx is cancelled and io.ErrUnexpectedEOF (wrapped)
+// when the registry closes the body before the announced length was received.
+func doDownload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions, f *FileDownload, fn func(api.ProgressResponse)) error {
+	// Deferred first so it runs last, after the partial file and body are closed.
+	defer inProgress.Delete(f.Digest)
+
+	partial := f.FilePath + "-partial"
+
+	var size int64
+
+	fi, err := os.Stat(partial)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop, file doesn't exist so create it
+	case err != nil:
+		return fmt.Errorf("stat: %w", err)
+	default:
+		size = fi.Size()
+		// Ensure the size is divisible by the chunk size by removing excess bytes
+		size -= size % int64(chunkSize)
+
+		if err := os.Truncate(partial, size); err != nil {
+			return fmt.Errorf("truncate: %w", err)
+		}
+	}
+
+	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), f.Digest)
+	headers := map[string]string{
+		"Range": fmt.Sprintf("bytes=%d-", size),
+	}
+
+	resp, err := makeRequest("GET", url, headers, nil, regOpts)
+	if err != nil {
+		log.Printf("couldn't download blob: %v", err)
+		return err
+	}
+	defer resp.Body.Close()
+
+	switch resp.StatusCode {
+	case http.StatusPartialContent:
+		// the body starts at offset size
+	case http.StatusOK:
+		if size > 0 {
+			// The registry ignored the Range header and is sending the whole
+			// blob, so appending to the existing partial data would corrupt it.
+			if err := os.Truncate(partial, 0); err != nil {
+				return fmt.Errorf("truncate: %w", err)
+			}
+			size = 0
+		}
+	default:
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
+	}
+
+	// Without a valid length the loop below would treat the blob as complete
+	// straight away and rename a truncated file into place.
+	remaining, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+	if err != nil {
+		return fmt.Errorf("parse content length: %w", err)
+	}
+
+	if err := os.MkdirAll(filepath.Dir(f.FilePath), 0o700); err != nil {
+		return fmt.Errorf("make blobs directory: %w", err)
+	}
+
+	completed := size
+	total := size + remaining
+	atomic.StoreInt64(&f.Total, total)
+	atomic.StoreInt64(&f.Completed, completed)
+
+	out, err := os.OpenFile(partial, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
+	if err != nil {
+		return fmt.Errorf("open file: %w", err)
+	}
+	defer out.Close()
+
+	for {
+		select {
+		case <-ctx.Done():
+			// handle client request cancellation; returning the error keeps
+			// the caller from treating the blob as downloaded
+			return ctx.Err()
+		default:
+		}
+
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("downloading %s", f.Digest),
+			Digest:    f.Digest,
+			Total:     int(total),
+			Completed: int(completed),
+		})
+
+		if completed >= total {
+			break
+		}
+
+		n, err := io.CopyN(out, resp.Body, int64(chunkSize))
+		completed += n
+		atomic.StoreInt64(&f.Completed, completed)
+
+		switch {
+		case err == nil:
+		case errors.Is(err, io.EOF):
+			if completed < total {
+				// the body ended early; without this check the loop would spin forever
+				return fmt.Errorf("download %s: %w", f.Digest, io.ErrUnexpectedEOF)
+			}
+		default:
+			return err
+		}
+	}
+
+	if err := out.Close(); err != nil {
+		return err
+	}
+
+	if err := os.Rename(partial, f.FilePath); err != nil {
+		fn(api.ProgressResponse{
+			Status:    fmt.Sprintf("error renaming file: %v", err),
+			Digest:    f.Digest,
+			Total:     int(total),
+			Completed: int(completed),
+		})
+		return err
+	}
+
+	log.Printf("success getting %s\n", f.Digest)
+	return nil
+}
diff --git a/server/images.go b/server/images.go
index 5796d2f4..22b0a74e 100644
--- a/server/images.go
+++ b/server/images.go
@@ -3,6 +3,7 @@ package server
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"crypto/sha256"
 	"encoding/json"
 	"errors"
@@ -13,7 +14,6 @@ import (
 	"math"
 	"net/http"
 	"os"
-	"path"
 	"path/filepath"
 	"reflect"
 	"strconv"
@@ -232,7 +232,7 @@ func filenameWithPath(path, f string) (string, error) {
 	return f, nil
 }
 
-func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error {
+func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
 	mf, err := os.Open(path)
 	if err != nil {
 		fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)})
@@ -265,7 +265,7 @@ func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) e
 					// the model file does not exist, try pulling it
 					if errors.Is(err, os.ErrNotExist) {
 						fn(api.ProgressResponse{Status: "pulling model file"})
-						if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil {
+						if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
 							return err
 						}
 						mf, err = GetManifest(ParseModelPath(c.Args))
@@ -900,7 +900,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
 	return nil
 }
 
-func PullModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 
 	fn(api.ProgressResponse{Status: "pulling manifest"})
@@ -915,7 +915,7 @@ func PullModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
 	layers = append(layers, &manifest.Config)
 
 	for _, layer := range layers {
-		if err := downloadBlob(mp, layer.Digest, regOpts, fn); err != nil {
+		if err := downloadBlob(ctx, mp, layer.Digest, regOpts, fn); err != nil {
 			return err
 		}
 	}
@@ -1142,112 +1142,6 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry
 	return nil
 }
 
-func downloadBlob(mp ModelPath, digest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	fp, err := GetBlobsPath(digest)
-	if err != nil {
-		return err
-	}
-
-	if fi, _ := os.Stat(fp); fi != nil {
-		// we already have the file, so return
-		fn(api.ProgressResponse{
-			Digest:    digest,
-			Total:     int(fi.Size()),
-			Completed: int(fi.Size()),
-		})
-
-		return nil
-	}
-
-	var size int64
-	chunkSize := 1024 * 1024 // 1 MiB in bytes
-
-	fi, err := os.Stat(fp + "-partial")
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-		// noop, file doesn't exist so create it
-	case err != nil:
-		return fmt.Errorf("stat: %w", err)
-	default:
-		size = fi.Size()
-		// Ensure the size is divisible by the chunk size by removing excess bytes
-		size -= size % int64(chunkSize)
-
-		err := os.Truncate(fp+"-partial", size)
-		if err != nil {
-			return fmt.Errorf("truncate: %w", err)
-		}
-	}
-
-	url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
-	headers := map[string]string{
-		"Range": fmt.Sprintf("bytes=%d-", size),
-	}
-
-	resp, err := makeRequest("GET", url, headers, nil, regOpts)
-	if err != nil {
-		log.Printf("couldn't download blob: %v", err)
-		return err
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
-		body, _ := io.ReadAll(resp.Body)
-		return fmt.Errorf("on download registry responded with code %d: %v", resp.StatusCode, string(body))
-	}
-
-	err = os.MkdirAll(path.Dir(fp), 0o700)
-	if err != nil {
-		return fmt.Errorf("make blobs directory: %w", err)
-	}
-
-	out, err := os.OpenFile(fp+"-partial", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
-	if err != nil {
-		return fmt.Errorf("open file: %w", err)
-	}
-	defer out.Close()
-
-	remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
-	completed := size
-	total := remaining + completed
-
-	for {
-		fn(api.ProgressResponse{
-			Status:    fmt.Sprintf("pulling %s...", digest[7:19]),
-			Digest:    digest,
-			Total:     int(total),
-			Completed: int(completed),
-		})
-
-		if completed >= total {
-			if err := out.Close(); err != nil {
-				return err
-			}
-
-			if err := os.Rename(fp+"-partial", fp); err != nil {
-				fn(api.ProgressResponse{
-					Status:    fmt.Sprintf("error renaming file: %v", err),
-					Digest:    digest,
-					Total:     int(total),
-					Completed: int(completed),
-				})
-				return err
-			}
-
-			break
-		}
-
-		n, err := io.CopyN(out, resp.Body, int64(chunkSize))
-		if err != nil && !errors.Is(err, io.EOF) {
-			return err
-		}
-		completed += n
-	}
-
-	log.Printf("success getting %s\n", digest)
-	return nil
-}
-
 func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
 	if !strings.HasPrefix(url, "http") {
 		if regOpts.Insecure {
diff --git a/server/routes.go b/server/routes.go
index 9182de3b..731e8078 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1,6 +1,7 @@
 package server
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -200,7 +201,10 @@ func PullModelHandler(c *gin.Context) {
 			Password: req.Password,
 		}
 
-		if err := PullModel(req.Name, regOpts, fn); err != nil {
+		ctx, cancel := context.WithCancel(c.Request.Context())
+		defer cancel()
+
+		if err := PullModel(ctx, req.Name, regOpts, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -250,7 +254,10 @@ func CreateModelHandler(c *gin.Context) {
 			ch <- resp
 		}
 
-		if err := CreateModel(req.Name, req.Path, fn); err != nil {
+		ctx, cancel := context.WithCancel(c.Request.Context())
+		defer cancel()
+
+		if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()