From abed273de3a6183d734f0f3f0f129d7bd08ac4b4 Mon Sep 17 00:00:00 2001
From: Patrick Devine
Date: Wed, 11 Sep 2024 16:36:21 -0700
Subject: [PATCH] add "stop" command (#6739)

---
 cmd/cmd.go           | 56 ++++++++++++++++++++++++++++++++++++++++++--
 cmd/interactive.go   | 23 +-----------------
 server/routes.go     | 52 ++++++++++++++++++++++++++++++++++++++++
 server/sched.go      | 20 +++++++++++++++-
 server/sched_test.go | 46 ++++++++++++++++++++++++++++++++++++
 5 files changed, 172 insertions(+), 25 deletions(-)

diff --git a/cmd/cmd.go b/cmd/cmd.go
index 1fb721e7..3bb8b06e 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -346,6 +346,39 @@ func (w *progressWriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
+func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
+	p := progress.NewProgress(os.Stderr)
+	defer p.StopAndClear()
+
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	req := &api.GenerateRequest{
+		Model:     opts.Model,
+		KeepAlive: opts.KeepAlive,
+	}
+
+	return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
+}
+
+func StopHandler(cmd *cobra.Command, args []string) error {
+	opts := &runOptions{
+		Model:     args[0],
+		KeepAlive: &api.Duration{Duration: 0},
+	}
+	if err := loadOrUnloadModel(cmd, opts); err != nil {
+		if strings.Contains(err.Error(), "not found") {
+			return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
+		}
+	}
+	return nil
+}
+
 func RunHandler(cmd *cobra.Command, args []string) error {
 	interactive := true
 
@@ -424,7 +457,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	opts.ParentModel = info.Details.ParentModel
 
 	if interactive {
-		if err := loadModel(cmd, &opts); err != nil {
+		if err := loadOrUnloadModel(cmd, &opts); err != nil {
 			return err
 		}
 
@@ -615,7 +648,15 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
 				cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
 				procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
 			}
-			data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
+
+			var until string
+			delta := time.Since(m.ExpiresAt)
+			if delta > 0 {
+				until = "Stopping..."
+			} else {
+				until = format.HumanTime(m.ExpiresAt, "Never")
+			}
+			data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
 		}
 	}
 
@@ -1294,6 +1335,15 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+
+	stopCmd := &cobra.Command{
+		Use:     "stop MODEL",
+		Short:   "Stop a running model",
+		Args:    cobra.ExactArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    StopHandler,
+	}
+
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Aliases: []string{"start"},
@@ -1361,6 +1411,7 @@ func NewCLI() *cobra.Command {
 		createCmd,
 		showCmd,
 		runCmd,
+		stopCmd,
 		pullCmd,
 		pushCmd,
 		listCmd,
@@ -1400,6 +1451,7 @@ func NewCLI() *cobra.Command {
 		createCmd,
 		showCmd,
 		runCmd,
+		stopCmd,
 		pullCmd,
 		pushCmd,
 		listCmd,
diff --git a/cmd/interactive.go b/cmd/interactive.go
index 9fe1ed4c..94578f11 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -18,7 +18,6 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/parser"
-	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
 )
@@ -31,26 +30,6 @@ const (
 	MultilineSystem
 )
 
-func loadModel(cmd *cobra.Command, opts *runOptions) error {
-	p := progress.NewProgress(os.Stderr)
-	defer p.StopAndClear()
-
-	spinner := progress.NewSpinner("")
-	p.Add("", spinner)
-
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
-	chatReq := &api.ChatRequest{
-		Model:     opts.Model,
-		KeepAlive: opts.KeepAlive,
-	}
-
-	return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
-}
-
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			opts.Model = args[1]
 			opts.Messages = []api.Message{}
 			fmt.Printf("Loading model '%s'\n", opts.Model)
-			if err := loadModel(cmd, &opts); err != nil {
+			if err := loadOrUnloadModel(cmd, &opts); err != nil {
 				return err
 			}
 			continue
diff --git a/server/routes.go b/server/routes.go
index 5e9f51e1..f202973e 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -117,6 +117,32 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	// expire the runner
+	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+		model, err := GetModel(req.Model)
+		if err != nil {
+			switch {
+			case os.IsNotExist(err):
+				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+			case err.Error() == "invalid model name":
+				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			default:
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			}
+			return
+		}
+		s.sched.expireRunner(model)
+
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Response:   "",
+			Done:       true,
+			DoneReason: "unload",
+		})
+		return
+	}
+
 	if req.Format != "" && req.Format != "json" {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
 		return
@@ -1322,6 +1348,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	// expire the runner
+	if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+		model, err := GetModel(req.Model)
+		if err != nil {
+			switch {
+			case os.IsNotExist(err):
+				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+			case err.Error() == "invalid model name":
+				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			default:
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			}
+			return
+		}
+		s.sched.expireRunner(model)
+
+		c.JSON(http.StatusOK, api.ChatResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
+			Done:       true,
+			DoneReason: "unload",
+		})
+		return
+	}
+
 	caps := []Capability{CapabilityCompletion}
 	if len(req.Tools) > 0 {
 		caps = append(caps, CapabilityTools)
diff --git a/server/sched.go b/server/sched.go
index 58071bf0..3c8656ad 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -360,7 +360,6 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 			slog.Debug("runner expired event received", "modelPath", runner.modelPath)
 			runner.refMu.Lock()
 			if runner.refCount > 0 {
-				// Shouldn't happen, but safeguard to ensure no leaked runners
 				slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
 				go func(runner *runnerRef) {
 					// We can't unload yet, but want to as soon as the current request completes
@@ -802,6 +801,25 @@ func (s *Scheduler) unloadAllRunners() {
 	}
 }
 
+func (s *Scheduler) expireRunner(model *Model) {
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+	runner, ok := s.loaded[model.ModelPath]
+	if ok {
+		runner.refMu.Lock()
+		runner.expiresAt = time.Now()
+		if runner.expireTimer != nil {
+			runner.expireTimer.Stop()
+			runner.expireTimer = nil
+		}
+		runner.sessionDuration = 0
+		if runner.refCount <= 0 {
+			s.expiredCh <- runner
+		}
+		runner.refMu.Unlock()
+	}
+}
+
 // If other runners are loaded, make sure the pending request will fit in system memory
 // If not, pick a runner to unload, else return nil and the request can be loaded
 func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) *runnerRef {
diff --git a/server/sched_test.go b/server/sched_test.go
index fb049574..be32065a 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -406,6 +406,52 @@ func TestGetRunner(t *testing.T) {
 	b.ctxDone()
 }
 
+func TestExpireRunner(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
+	defer done()
+	s := InitScheduler(ctx)
+	req := &LlmRequest{
+		ctx:             ctx,
+		model:           &Model{ModelPath: "foo"},
+		opts:            api.DefaultOptions(),
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+		sessionDuration: &api.Duration{Duration: 2 * time.Minute},
+	}
+
+	var ggml *llm.GGML
+	gpus := gpu.GpuInfoList{}
+	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+		return server, nil
+	}
+	s.load(req, ggml, gpus, 0)
+
+	select {
+	case err := <-req.errCh:
+		if err != nil {
+			t.Fatalf("expected no errors when loading, got '%s'", err.Error())
+		}
+	case resp := <-req.successCh:
+		s.loadedMu.Lock()
+		if resp.refCount != uint(1) || len(s.loaded) != 1 {
+			t.Fatalf("expected a model to be loaded")
+		}
+		s.loadedMu.Unlock()
+	}
+
+	s.expireRunner(&Model{ModelPath: "foo"})
+
+	s.finishedReqCh <- req
+	s.processCompleted(ctx)
+
+	s.loadedMu.Lock()
+	if len(s.loaded) != 0 {
+		t.Fatalf("expected model to be unloaded")
+	}
+	s.loadedMu.Unlock()
+}
+
 // TODO - add one scenario that triggers the bogus finished event with positive ref count
 func TestPrematureExpired(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
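
For reference, the new subcommand is a thin wrapper over the generate endpoint: "ollama stop MODEL" sends a GenerateRequest with an empty prompt and keep_alive set to 0, which the handler changes above route to sched.expireRunner before replying with done_reason "unload". Below is an illustrative, standalone sketch that drives the same unload path through the api package used by loadOrUnloadModel; it is not part of the patch, and the model name "llama3.1" is a placeholder for whichever model is currently loaded.

    package main

    import (
    	"context"
    	"log"

    	"github.com/ollama/ollama/api"
    )

    func main() {
    	// Honors OLLAMA_HOST, falling back to the default local address.
    	client, err := api.ClientFromEnvironment()
    	if err != nil {
    		log.Fatal(err)
    	}

    	// An empty prompt with keep_alive set to 0 asks the server to
    	// unload the model instead of generating a completion.
    	req := &api.GenerateRequest{
    		Model:     "llama3.1", // placeholder; use a loaded model's name
    		KeepAlive: &api.Duration{Duration: 0},
    	}
    	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error {
    		return nil
    	}); err != nil {
    		log.Fatal(err)
    	}
    }

The chat endpoint behaves the same way: an empty messages list with keep_alive 0 expires the runner and returns an assistant message with done_reason "unload".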