diff --git a/server/models.go b/server/models.go index 496b2c45..fd689ed6 100644 --- a/server/models.go +++ b/server/models.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -34,6 +35,14 @@ func (m *Model) FullName() string { return path.Join(home, ".ollama", "models", m.Name+".bin") } +func (m *Model) TempFile() string { + fullName := m.FullName() + return path.Join( + path.Dir(fullName), + fmt.Sprintf(".%s.part", path.Base(fullName)), + ) +} + func getRemote(model string) (*Model, error) { // resolve the model download from our directory resp, err := http.Get(directoryURL) @@ -66,37 +75,45 @@ func saveModel(model *Model, fn func(total, completed int64)) error { if err != nil { return fmt.Errorf("failed to download model: %w", err) } - // check for resume - alreadyDownloaded := int64(0) - fileInfo, err := os.Stat(model.FullName()) - if err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("failed to check resume model file: %w", err) - } - // file doesn't exist, create it now - } else { - alreadyDownloaded = fileInfo.Size() - req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded)) + + // check if completed file exists + fi, err := os.Stat(model.FullName()) + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + fn(fi.Size(), fi.Size()) + return nil } + var size int64 + + // completed file doesn't exist, check partial file + fi, err = os.Stat(model.TempFile()) + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + size = fi.Size() + } + + req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size)) + resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to download model: %w", err) } - defer resp.Body.Close() - if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { - // already downloaded - fn(alreadyDownloaded, alreadyDownloaded) - return nil - } - - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + if resp.StatusCode >= 400 { return fmt.Errorf("failed to download model: %s", resp.Status) } - out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { panic(err) } @@ -104,27 +121,23 @@ func saveModel(model *Model, fn func(total, completed int64)) error { totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - buf := make([]byte, 1024) - totalBytes := alreadyDownloaded - totalSize += alreadyDownloaded + totalBytes := size + totalSize += size for { - n, err := resp.Body.Read(buf) - if err != nil && err != io.EOF { + n, err := io.CopyN(out, resp.Body, 8192) + if err != nil && !errors.Is(err, io.EOF) { return err } + if n == 0 { break } - if _, err := out.Write(buf[:n]); err != nil { - return err - } - - totalBytes += int64(n) + totalBytes += n fn(totalSize, totalBytes) } fn(totalSize, totalSize) - return nil + return os.Rename(model.TempFile(), model.FullName()) }