From 5c59455b59a3a6b53bad0f029b17d90aec173bfc Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 12 Oct 2023 15:56:40 -0700 Subject: [PATCH] cmd: use existing cmd context --- cmd/cmd.go | 54 ++++++++++++++++++++++++++---------------------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index df0d90c8..073b1b4a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -133,7 +133,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)} - if err := client.Create(context.Background(), &request, fn); err != nil { + if err := client.Create(cmd.Context(), &request, fn); err != nil { return err } @@ -148,7 +148,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { name := args[0] // check if the model exists on the server - _, err = client.Show(context.Background(), &api.ShowRequest{Name: name}) + _, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name}) var statusError api.StatusError switch { case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: @@ -208,7 +208,7 @@ func PushHandler(cmd *cobra.Command, args []string) error { } request := api.PushRequest{Name: args[0], Insecure: insecure} - if err := client.Push(context.Background(), &request, fn); err != nil { + if err := client.Push(cmd.Context(), &request, fn); err != nil { return err } @@ -222,7 +222,7 @@ func ListHandler(cmd *cobra.Command, args []string) error { return err } - models, err := client.List(context.Background()) + models, err := client.List(cmd.Context()) if err != nil { return err } @@ -257,7 +257,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { for _, name := range args { req := api.DeleteRequest{Name: name} - if err := client.Delete(context.Background(), &req); err != nil { + if err := client.Delete(cmd.Context(), &req); err != nil { return err } fmt.Printf("deleted '%s'\n", name) @@ -322,7 +322,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error { } req := api.ShowRequest{Name: args[0]} - resp, err := client.Show(context.Background(), &req) + resp, err := client.Show(cmd.Context(), &req) if err != nil { return err } @@ -350,7 +350,7 @@ func CopyHandler(cmd *cobra.Command, args []string) error { } req := api.CopyRequest{Source: args[0], Destination: args[1]} - if err := client.Copy(context.Background(), &req); err != nil { + if err := client.Copy(cmd.Context(), &req); err != nil { return err } fmt.Printf("copied '%s' to '%s'\n", args[0], args[1]) @@ -404,7 +404,7 @@ func PullHandler(cmd *cobra.Command, args []string) error { } request := api.PullRequest{Name: args[0], Insecure: insecure} - if err := client.Pull(context.Background(), &request, fn); err != nil { + if err := client.Pull(cmd.Context(), &request, fn); err != nil { return err } @@ -493,7 +493,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error { opts.WordWrap = false } - cancelCtx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(cmd.Context()) defer cancel() sigChan := make(chan os.Signal, 1) @@ -507,15 +507,6 @@ func generate(cmd *cobra.Command, opts generateOptions) error { var currentLineLength int var wordBuffer string - request := api.GenerateRequest{ - Model: opts.Model, - Prompt: opts.Prompt, - Context: generateContext, - Format: opts.Format, - System: opts.System, - Template: opts.Template, - Options: opts.Options, - } fn := func(response api.GenerateResponse) error { p.StopAndClear() @@ -560,7 +551,17 @@ func generate(cmd *cobra.Command, opts generateOptions) error { return nil } - if err := client.Generate(cancelCtx, &request, fn); err != nil { + request := api.GenerateRequest{ + Model: opts.Model, + Prompt: opts.Prompt, + Context: generateContext, + Format: opts.Format, + System: opts.System, + Template: opts.Template, + Options: opts.Options, + } + + if err := client.Generate(ctx, &request, fn); err != nil { if errors.Is(err, context.Canceled) { return nil } @@ -584,10 +585,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error { latest.Summary() } - ctx := cmd.Context() - ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) - cmd.SetContext(ctx) - + cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)) return nil } @@ -988,7 +986,7 @@ func initializeKeypair() error { return nil } -func startMacApp(client *api.Client) error { +func startMacApp(ctx context.Context, client *api.Client) error { exe, err := os.Executable() if err != nil { return err @@ -1012,24 +1010,24 @@ func startMacApp(client *api.Client) error { case <-timeout: return errors.New("timed out waiting for server to start") case <-tick: - if err := client.Heartbeat(context.Background()); err == nil { + if err := client.Heartbeat(ctx); err == nil { return nil // server has started } } } } -func checkServerHeartbeat(_ *cobra.Command, _ []string) error { +func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { client, err := api.ClientFromEnvironment() if err != nil { return err } - if err := client.Heartbeat(context.Background()); err != nil { + if err := client.Heartbeat(cmd.Context()); err != nil { if !strings.Contains(err.Error(), "connection refused") { return err } if runtime.GOOS == "darwin" { - if err := startMacApp(client); err != nil { + if err := startMacApp(cmd.Context(), client); err != nil { return fmt.Errorf("could not connect to ollama app, is it running?") } } else {