diff --git a/cmd/cmd.go b/cmd/cmd.go index d47db65b..2356110e 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -22,6 +22,7 @@ import ( "runtime" "slices" "strings" + "sync/atomic" "syscall" "time" @@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { status := "transferring model data" spinner := progress.NewSpinner(status) p.Add(status, spinner) + defer p.Stop() for i := range modelfile.Commands { switch modelfile.Commands[i].Name { @@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { path = tempfile } - digest, err := createBlob(cmd, client, path) + digest, err := createBlob(cmd, client, path, spinner) if err != nil { return err } @@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) { return tempfile.Name(), nil } -func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { +func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) { bin, err := os.Open(path) if err != nil { return "", err } defer bin.Close() + // Get file info to retrieve the size + fileInfo, err := bin.Stat() + if err != nil { + return "", err + } + fileSize := fileInfo.Size() + hash := sha256.New() if _, err := io.Copy(hash, bin); err != nil { return "", err @@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er return "", err } + var pw progressWriter + status := "transferring model data 0%" + spinner.SetMessage(status) + + done := make(chan struct{}) + defer close(done) + + go func() { + ticker := time.NewTicker(60 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize))) + case <-done: + spinner.SetMessage("transferring model data 100%") + return + } + } + }() + digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) - if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { return "", err } return digest, nil } +type progressWriter struct { + n atomic.Int64 +} + +func (w *progressWriter) Write(p []byte) (n int, err error) { + w.n.Add(int64(len(p))) + return len(p), nil +} + func RunHandler(cmd *cobra.Command, args []string) error { interactive := true diff --git a/progress/spinner.go b/progress/spinner.go index 02f3f9fb..e39a45ee 100644 --- a/progress/spinner.go +++ b/progress/spinner.go @@ -3,11 +3,12 @@ package progress import ( "fmt" "strings" + "sync/atomic" "time" ) type Spinner struct { - message string + message atomic.Value messageWidth int parts []string @@ -21,20 +22,25 @@ type Spinner struct { func NewSpinner(message string) *Spinner { s := &Spinner{ - message: message, parts: []string{ "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", }, started: time.Now(), } + s.SetMessage(message) go s.start() return s } +func (s *Spinner) SetMessage(message string) { + s.message.Store(message) +} + func (s *Spinner) String() string { var sb strings.Builder - if len(s.message) > 0 { - message := strings.TrimSpace(s.message) + + if message, ok := s.message.Load().(string); ok && len(message) > 0 { + message := strings.TrimSpace(message) if s.messageWidth > 0 && len(message) > s.messageWidth { message = message[:s.messageWidth] }