diff --git a/docs/faq.md b/docs/faq.md
index 6da8c545..109a1144 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -232,3 +232,15 @@ curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
 Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
 
 If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
+
+## How do I manage the maximum number of requests the server can queue?
+
+If too many requests are sent to the server, it will respond with a 503 error
+indicating the server is overloaded. You can adjust how many requests may be
+queued by setting `OLLAMA_MAX_QUEUE`.
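+
+For example, if you start the server manually from a shell, something like the following raises the limit for that session (the default is 512):
+
+```shell
+OLLAMA_MAX_QUEUE=1024 ollama serve
+```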
\ No newline at end of file
diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go
new file mode 100644
index 00000000..43b15c6c
--- /dev/null
+++ b/integration/max_queue_test.go
@@ -0,0 +1,117 @@
+//go:build integration
+
+package integration
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log/slog"
+	"os"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMaxQueue(t *testing.T) {
+	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless you're on a GPU
+	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
+	threadCount := 32
+	mq := os.Getenv("OLLAMA_MAX_QUEUE")
+	if mq != "" {
+		var err error
+		threadCount, err = strconv.Atoi(mq)
+		require.NoError(t, err)
+	} else {
+		os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
+	}
+
+	req := api.GenerateRequest{
+		Model:   "orca-mini",
+		Prompt:  "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
+	}
+	resp := []string{"explore", "discover", "ocean"}
+
+	// CPU mode takes much longer at the limit with a large queue setting
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+
+	// Context for the worker threads so we can shut them down
+	// embedCtx, embedCancel := context.WithCancel(ctx)
+	embedCtx := ctx
+
+	var genwg sync.WaitGroup
+	genwg.Add(1) // register before starting the goroutine so Wait can't race the Add
+	go func() {
+		defer genwg.Done()
+		slog.Info("Starting generate request")
+		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
+		slog.Info("generate completed")
+	}()
+
+	// Give the generate a chance to get started before we start hammering on embed requests
+	time.Sleep(5 * time.Millisecond)
+
+	threadCount += 10 // Add a few extra to ensure we push the queue past its limit
+	busyCount := 0
+	resetByPeerCount := 0
+	canceledCount := 0
+	successCount := 0
+	counterMu := sync.Mutex{}
+	var embedwg sync.WaitGroup
+	for i := 0; i < threadCount; i++ {
+		embedwg.Add(1)
+		go func(i int) {
+			defer embedwg.Done()
+			slog.Info("embed started", "id", i)
+			embedReq := api.EmbeddingRequest{
+				Model:   req.Model,
+				Prompt:  req.Prompt,
+				Options: req.Options,
+			}
+			// Fresh client for every request
+			embedClient, _ := GetTestEndpoint()
+
+			resp, genErr := embedClient.Embeddings(embedCtx, &embedReq)
+			counterMu.Lock()
+			defer counterMu.Unlock()
+			switch {
+			case genErr == nil:
+				successCount++
+				require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
+			case errors.Is(genErr, context.Canceled):
+				canceledCount++
+			case strings.Contains(genErr.Error(), "busy"):
+				busyCount++
+			case strings.Contains(genErr.Error(), "connection reset by peer"):
+				resetByPeerCount++
+			default:
+				require.NoError(t, genErr, "%d request failed", i)
+			}
+
+			slog.Info("embed finished", "id", i)
+		}(i)
+	}
+	genwg.Wait()
+	slog.Info("generate done, waiting for embeds")
+	embedwg.Wait()
+
+	require.Equal(t, 0, resetByPeerCount, "Connections reset by peer, have you updated your fd and socket limits?")
+	require.True(t, busyCount > 0, "no requests hit busy error but some should have")
+	require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
+
+	slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
+}
diff --git a/server/routes.go b/server/routes.go
index 1e7c80e7..3b24735f 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -146,12 +146,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	select {
 	case runner = <-rCh:
 	case err = <-eCh:
-		if errors.Is(err, context.Canceled) {
-			c.JSON(499, gin.H{"error": "request canceled"})
-			return
-		}
-
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		handleErrorResponse(c, err)
 		return
 	}
 
@@ -394,12 +389,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	select {
 	case runner = <-rCh:
 	case err = <-eCh:
-		if errors.Is(err, context.Canceled) {
-			c.JSON(499, gin.H{"error": "request canceled"})
-			return
-		}
-
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		handleErrorResponse(c, err)
 		return
 	}
 
@@ -1212,12 +1202,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	select {
 	case runner = <-rCh:
 	case err = <-eCh:
-		if errors.Is(err, context.Canceled) {
-			c.JSON(499, gin.H{"error": "request canceled"})
-			return
-		}
-
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		handleErrorResponse(c, err)
 		return
 	}
 
@@ -1338,3 +1323,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 	streamResponse(c, ch)
 }
+
+// handleErrorResponse maps scheduler errors onto appropriate HTTP status codes.
+func handleErrorResponse(c *gin.Context, err error) {
+	if errors.Is(err, context.Canceled) {
+		c.JSON(499, gin.H{"error": "request canceled"})
+		return
+	}
+	if errors.Is(err, ErrMaxQueue) {
+		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
+		return
+	}
+	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+}
diff --git a/server/sched.go b/server/sched.go
index 61c5e1b3..d40a45ad 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -43,10 +43,13 @@ type Scheduler struct {
 	getGpuFn func() gpu.GpuInfoList
 }
 
-// TODO set this to zero after a release or two, to enable multiple models by default
-var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
-var maxQueuedRequests = 10 // TODO configurable
-var numParallel = 1
+var (
+	// TODO set this to zero after a release or two, to enable multiple models by default
+	loadedMax         = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
+	maxQueuedRequests = 512
+	numParallel       = 1
+	ErrMaxQueue       = fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
+)
 
 func InitScheduler(ctx context.Context) *Scheduler {
 	maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
@@ -66,6 +69,14 @@ func InitScheduler(ctx context.Context) *Scheduler {
 			numParallel = p
 		}
 	}
+	if omq := os.Getenv("OLLAMA_MAX_QUEUE"); omq != "" {
+		p, err := strconv.Atoi(omq)
+		if err != nil || p <= 0 {
+			slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", omq, "error", err)
+		} else {
+			maxQueuedRequests = p
+		}
+	}
 
 	sched := &Scheduler{
 		pendingReqCh:  make(chan *LlmRequest, maxQueuedRequests),
@@ -95,7 +106,7 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options,
 	select {
 	case s.pendingReqCh <- req:
 	default:
-		req.errCh <- fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
+		req.errCh <- ErrMaxQueue
 	}
 	return req.successCh, req.errCh
 }
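For readers skimming the patch, here is a minimal standalone sketch of the admission-control pattern the scheduler change relies on: a channel buffered to the queue limit plus a non-blocking send that rejects work once the buffer is full. The names below (`queue`, `errBusy`, `enqueue`) are illustrative only, not from the codebase; in the server the rejection surfaces as `ErrMaxQueue`, which `handleErrorResponse` turns into an HTTP 503.

```go
package main

import (
	"errors"
	"fmt"
)

// errBusy plays the role of ErrMaxQueue: returned when the pending queue is full.
var errBusy = errors.New("server busy, please try again")

// queue holds pending work in a channel whose buffer size acts like
// OLLAMA_MAX_QUEUE / maxQueuedRequests.
type queue struct {
	pending chan string
}

func newQueue(max int) *queue {
	return &queue{pending: make(chan string, max)}
}

// enqueue does a non-blocking send: if the buffer is full, the caller gets an
// immediate busy error instead of blocking the HTTP handler.
func (q *queue) enqueue(req string) error {
	select {
	case q.pending <- req:
		return nil
	default:
		return errBusy
	}
}

func main() {
	q := newQueue(2) // as if OLLAMA_MAX_QUEUE=2
	for i := 0; i < 4; i++ {
		if err := q.enqueue(fmt.Sprintf("req-%d", i)); err != nil {
			fmt.Println(i, "rejected:", err) // the real server maps this to HTTP 503
		} else {
			fmt.Println(i, "queued")
		}
	}
}
```

In this sketch the first two requests are queued and the rest are rejected immediately, which is the behavior the integration test asserts through its busyCount check.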