From b0618a466e998e8e7bb5b093972fd01013cee2b8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 6 Jul 2023 15:43:04 -0700 Subject: [PATCH] generate progress --- cmd/cmd.go | 61 ++++++++++++++++++++++++++++++++++++++++-------- server/models.go | 23 +++++++++--------- server/routes.go | 4 ++++ 3 files changed, 67 insertions(+), 21 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 4ee74e77..e6cfa7ff 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -3,11 +3,13 @@ package cmd import ( "bufio" "context" + "errors" "fmt" "log" "net" "os" "path" + "time" "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" @@ -27,11 +29,18 @@ func cacheDir() string { } func RunRun(cmd *cobra.Command, args []string) error { - if err := pull(args[0]); err != nil { + _, err := os.Stat(args[0]) + switch { + case errors.Is(err, os.ErrNotExist): + if err := pull(args[0]); err != nil { + return err + } + + fmt.Println("Up to date.") + case err != nil: return err } - fmt.Println("Up to date.") return RunGenerate(cmd, args) } @@ -54,7 +63,7 @@ func pull(model string) error { func RunGenerate(_ *cobra.Command, args []string) error { if len(args) > 1 { - return generate(args[0], args[1:]...) + return generateOneshot(args[0], args[1:]...) } if term.IsTerminal(int(os.Stdin.Fd())) { @@ -64,21 +73,53 @@ func RunGenerate(_ *cobra.Command, args []string) error { return generateBatch(args[0]) } -func generate(model string, prompts ...string) error { +func generate(model, prompt string) error { client := api.NewClient() - for _, prompt := range prompts { - client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error { - fmt.Print(resp.Response) - return nil - }) - } + spinner := progressbar.NewOptions(-1, + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionThrottle(60*time.Millisecond), + progressbar.OptionSpinnerType(14), + progressbar.OptionSetRenderBlankState(true), + progressbar.OptionSetElapsedTime(false), + progressbar.OptionClearOnFinish(), + ) + + go func() { + for range time.Tick(60 * time.Millisecond) { + if spinner.IsFinished() { + break + } + + spinner.Add(1) + } + }() + + client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error { + if !spinner.IsFinished() { + spinner.Finish() + } + + fmt.Print(resp.Response) + return nil + }) fmt.Println() fmt.Println() return nil } +func generateOneshot(model string, prompts ...string) error { + for _, prompt := range prompts { + fmt.Printf(">>> %s\n", prompt) + if err := generate(model, prompt); err != nil { + return err + } + } + + return nil +} + func generateInteractive(model string) error { fmt.Print(">>> ") scanner := bufio.NewScanner(os.Stdin) diff --git a/server/models.go b/server/models.go index b1753909..ac41d11d 100644 --- a/server/models.go +++ b/server/models.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "os" "path" @@ -30,6 +29,15 @@ type Model struct { License string `json:"license"` } +func (m *Model) FullName() string { + home, err := os.UserHomeDir() + if err != nil { + panic(err) + } + + return path.Join(home, ".ollama", "models", m.Name+".bin") +} + func pull(model string, progressCh chan<- api.PullProgress) error { remote, err := getRemote(model) if err != nil { @@ -45,7 +53,7 @@ func getRemote(model string) (*Model, error) { return nil, fmt.Errorf("failed to get directory: %w", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read directory: %w", err) } @@ -64,13 +72,6 @@ func getRemote(model string) (*Model, error) { func saveModel(model *Model, progressCh chan<- api.PullProgress) error { // this models cache directory is created by the server on startup - home, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to get home directory: %w", err) - } - modelsCache := path.Join(home, ".ollama", "models") - - fileName := path.Join(modelsCache, model.Name+".bin") client := &http.Client{} req, err := http.NewRequest("GET", model.URL, nil) @@ -79,7 +80,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { } // check for resume alreadyDownloaded := int64(0) - fileInfo, err := os.Stat(fileName) + fileInfo, err := os.Stat(model.FullName()) if err != nil { if !os.IsNotExist(err) { return fmt.Errorf("failed to check resume model file: %w", err) @@ -111,7 +112,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { return fmt.Errorf("failed to download model: %s", resp.Status) } - out, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { panic(err) } diff --git a/server/routes.go b/server/routes.go index 0ca3d10e..4831b7ad 100644 --- a/server/routes.go +++ b/server/routes.go @@ -37,6 +37,10 @@ func generate(c *gin.Context) { return } + if remoteModel, _ := getRemote(req.Model); remoteModel != nil { + req.Model = remoteModel.FullName() + } + model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers)) if err != nil { fmt.Println("Loading the model failed:", err.Error())