rename partial file

This commit is contained in:
Michael Yang 2023-07-11 13:36:35 -07:00
parent e243329e2e
commit 948323fa78

View file

@ -2,6 +2,7 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -34,6 +35,14 @@ func (m *Model) FullName() string {
return path.Join(home, ".ollama", "models", m.Name+".bin") return path.Join(home, ".ollama", "models", m.Name+".bin")
} }
func (m *Model) TempFile() string {
fullName := m.FullName()
return path.Join(
path.Dir(fullName),
fmt.Sprintf(".%s.part", path.Base(fullName)),
)
}
func getRemote(model string) (*Model, error) { func getRemote(model string) (*Model, error) {
// resolve the model download from our directory // resolve the model download from our directory
resp, err := http.Get(directoryURL) resp, err := http.Get(directoryURL)
@ -66,37 +75,45 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to download model: %w", err) return fmt.Errorf("failed to download model: %w", err)
} }
// check for resume
alreadyDownloaded := int64(0) // check if completed file exists
fileInfo, err := os.Stat(model.FullName()) fi, err := os.Stat(model.FullName())
if err != nil { switch {
if !os.IsNotExist(err) { case errors.Is(err, os.ErrNotExist):
return fmt.Errorf("failed to check resume model file: %w", err) // noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
fn(fi.Size(), fi.Size())
return nil
} }
// file doesn't exist, create it now
} else { var size int64
alreadyDownloaded = fileInfo.Size()
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded)) // completed file doesn't exist, check partial file
fi, err = os.Stat(model.TempFile())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
} }
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to download model: %w", err) return fmt.Errorf("failed to download model: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { if resp.StatusCode >= 400 {
// already downloaded
fn(alreadyDownloaded, alreadyDownloaded)
return nil
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
return fmt.Errorf("failed to download model: %s", resp.Status) return fmt.Errorf("failed to download model: %s", resp.Status)
} }
out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -104,27 +121,23 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
buf := make([]byte, 1024) totalBytes := size
totalBytes := alreadyDownloaded totalSize += size
totalSize += alreadyDownloaded
for { for {
n, err := resp.Body.Read(buf) n, err := io.CopyN(out, resp.Body, 8192)
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
return err return err
} }
if n == 0 { if n == 0 {
break break
} }
if _, err := out.Write(buf[:n]); err != nil {
return err
}
totalBytes += int64(n)
totalBytes += n
fn(totalSize, totalBytes) fn(totalSize, totalBytes)
} }
fn(totalSize, totalSize) fn(totalSize, totalSize)
return nil return os.Rename(model.TempFile(), model.FullName())
} }