diff --git a/cmd/cmd.go b/cmd/cmd.go index ca924ae9..c783e6a8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -59,7 +59,7 @@ func pull(model string) error { &api.PullRequest{Model: model}, func(progress api.PullProgress) error { if bar == nil { - if progress.Percent == 100 { + if progress.Percent >= 100 { // already downloaded return nil } @@ -73,10 +73,9 @@ func pull(model string) error { } func RunGenerate(_ *cobra.Command, args []string) error { - // join all args into a single prompt - prompt := strings.Join(args[1:], " ") if len(args) > 1 { - return generate(args[0], prompt) + // join all args into a single prompt + return generate(args[0], strings.Join(args[1:], " ")) } if term.IsTerminal(int(os.Stdin.Fd())) { diff --git a/server/models.go b/server/models.go index fd689ed6..b3eca189 100644 --- a/server/models.go +++ b/server/models.go @@ -119,25 +119,22 @@ func saveModel(model *Model, fn func(total, completed int64)) error { } defer out.Close() - totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + completed := size - totalBytes := size - totalSize += size + total := remaining + completed for { - n, err := io.CopyN(out, resp.Body, 8192) + fn(total, completed) + if completed >= total { + return os.Rename(model.TempFile(), model.FullName()) + } + + n , err := io.CopyN(out, resp.Body, 8192) if err != nil && !errors.Is(err, io.EOF) { return err } - if n == 0 { - break - } - - totalBytes += n - fn(totalSize, totalBytes) + completed += n } - - fn(totalSize, totalSize) - return os.Rename(model.TempFile(), model.FullName()) } diff --git a/server/routes.go b/server/routes.go index ef19f3c2..2d7badb5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -112,7 +112,7 @@ func pull(c *gin.Context) { ch <- api.PullProgress{ Total: total, Completed: completed, - Percent: float64(total) / float64(completed) * 100, + Percent: float64(completed) / float64(total) * 100, } }