From 0f1910129f0a73c469ce2c012d39c8d98b79ef80 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 3 Jul 2024 19:41:17 -0700
Subject: [PATCH] int

---
 envconfig/config.go           | 66 ++++++++++-------------------
 integration/basic_test.go     |  9 +----
 integration/max_queue_test.go | 14 ++++----
 server/sched.go               | 23 +++++++-----
 server/sched_test.go          |  7 ++--
 5 files changed, 42 insertions(+), 77 deletions(-)

diff --git a/envconfig/config.go b/envconfig/config.go
index 34cc4dac..01abea42 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -213,13 +213,22 @@ func RunnersDir() (p string) {
 	return p
 }
 
+func Int(k string, n int) func() int {
+	return func() int {
+		if s := getenv(k); s != "" {
+			if n, err := strconv.ParseInt(s, 10, 64); err == nil && n >= 0 {
+				return int(n)
+			}
+		}
+
+		return n
+	}
+}
+
 var (
-	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
-	MaxRunners int
-	// Set via OLLAMA_MAX_QUEUE in the environment
-	MaxQueuedRequests int
-	// Set via OLLAMA_NUM_PARALLEL in the environment
-	NumParallel int
+	NumParallel = Int("OLLAMA_NUM_PARALLEL", 0)
+	MaxRunners = Int("OLLAMA_MAX_LOADED_MODELS", 0)
+	MaxQueue = Int("OLLAMA_MAX_QUEUE", 512)
 )
 
 type EnvVar struct {
@@ -235,12 +244,12 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
-		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
-		"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
+		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
+		"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
 		"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
 		"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
-		"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests"},
+		"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
 		"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 		"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"},
 		"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
@@ -269,42 +278,3 @@ func Values() map[string]string {
 func getenv(key string) string {
 	return strings.Trim(os.Getenv(key), "\"' ")
 }
-
-func init() {
-	// default values
-	NumParallel = 0 // Autoselect
-	MaxRunners = 0 // Autoselect
-	MaxQueuedRequests = 512
-
-	LoadConfig()
-}
-
-func LoadConfig() {
-	if onp := getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
-		val, err := strconv.Atoi(onp)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_NUM_PARALLEL", onp, "error", err)
-		} else {
-			NumParallel = val
-		}
-	}
-
-	maxRunners := getenv("OLLAMA_MAX_LOADED_MODELS")
-	if maxRunners != "" {
-		m, err := strconv.Atoi(maxRunners)
-		if err != nil {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
-		} else {
-			MaxRunners = m
-		}
-	}
-
-	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
-		p, err := strconv.Atoi(onp)
-		if err != nil || p <= 0 {
-			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_QUEUE", onp, "error", err)
-		} else {
-			MaxQueuedRequests = p
-		}
-	}
-}
diff --git a/integration/basic_test.go b/integration/basic_test.go
index 6e632a1c..8e35b5c5 100644
--- a/integration/basic_test.go
+++ b/integration/basic_test.go
@@ -45,14 +45,7 @@ func TestUnicodeModelDir(t *testing.T) {
 	defer os.RemoveAll(modelDir)
 	slog.Info("unicode", "OLLAMA_MODELS", modelDir)
 
-	oldModelsDir := os.Getenv("OLLAMA_MODELS")
-	if oldModelsDir == "" {
-		defer os.Unsetenv("OLLAMA_MODELS")
-	} else {
-		defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
-	}
-	err = os.Setenv("OLLAMA_MODELS", modelDir)
-	require.NoError(t, err)
+	t.Setenv("OLLAMA_MODELS", modelDir)
 
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go
index dfa5eae0..b06197e1 100644
--- a/integration/max_queue_test.go
+++ b/integration/max_queue_test.go
@@ -5,7 +5,6 @@ package integration
 import (
 	"context"
 	"errors"
-	"fmt"
 	"log/slog"
 	"os"
 	"strconv"
@@ -14,8 +13,10 @@ import (
 	"testing"
 	"time"
 
-	"github.com/ollama/ollama/api"
 	"github.com/stretchr/testify/require"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 )
 
 func TestMaxQueue(t *testing.T) {
@@ -27,13 +28,10 @@ func TestMaxQueue(t *testing.T) {
 	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
 	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
 	threadCount := 32
-	mq := os.Getenv("OLLAMA_MAX_QUEUE")
-	if mq != "" {
-		var err error
-		threadCount, err = strconv.Atoi(mq)
-		require.NoError(t, err)
+	if maxQueue := envconfig.MaxQueue(); maxQueue != 0 {
+		threadCount = maxQueue
 	} else {
-		os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
+		t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
 	}
 
 	req := api.GenerateRequest{
diff --git a/server/sched.go b/server/sched.go
index ad40c4ef..610a2c50 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -5,9 +5,11 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"os"
 	"reflect"
 	"runtime"
 	"sort"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -59,11 +61,12 @@ var defaultParallel = 4
 var ErrMaxQueue = fmt.Errorf("server busy, please try again.  maximum pending requests exceeded")
 
 func InitScheduler(ctx context.Context) *Scheduler {
+	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
-		finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests),
-		expiredCh: make(chan *runnerRef, envconfig.MaxQueuedRequests),
-		unloadedCh: make(chan interface{}, envconfig.MaxQueuedRequests),
+		pendingReqCh: make(chan *LlmRequest, maxQueue),
+		finishedReqCh: make(chan *LlmRequest, maxQueue),
+		expiredCh: make(chan *runnerRef, maxQueue),
+		unloadedCh: make(chan interface{}, maxQueue),
 		loaded: make(map[string]*runnerRef),
 		newServerFn: llm.NewLlamaServer,
 		getGpuFn: gpu.GetGPUInfo,
@@ -126,7 +129,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				slog.Debug("pending request cancelled or timed out, skipping scheduling")
 				continue
 			}
-			numParallel := envconfig.NumParallel
+			numParallel := envconfig.NumParallel()
 			// TODO (jmorganca): multimodal models don't support parallel yet
 			// see https://github.com/ollama/ollama/issues/4165
 			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
@@ -148,7 +151,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						pending.useLoadedRunner(runner, s.finishedReqCh)
 						break
 					}
-				} else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners {
+				} else if envconfig.MaxRunners() > 0 && loadedCount >= envconfig.MaxRunners() {
 					slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
 					runnerToExpire = s.findRunnerToUnload()
 				} else {
@@ -161,7 +164,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
 						gpus = s.getGpuFn()
 					}
 
-					if envconfig.MaxRunners <= 0 {
+					if envconfig.MaxRunners() <= 0 {
 						// No user specified MaxRunners, so figure out what automatic setting to use
 						// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
 						// if any GPU has unreliable free memory reporting, 1x the number of GPUs
@@ -173,11 +176,13 @@ func (s *Scheduler) processPending(ctx context.Context) {
 							}
 						}
 						if allReliable {
-							envconfig.MaxRunners = defaultModelsPerGPU * len(gpus)
+							// HACK
+							os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
 							slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
 						} else {
+							// HACK
+							os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))
 							slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
-							envconfig.MaxRunners = len(gpus)
 						}
 					}
 
diff --git a/server/sched_test.go b/server/sched_test.go
index 9ddd1fab..3166ff66 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -12,7 +12,6 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/app/lifecycle"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/gpu"
 	"github.com/ollama/ollama/llm"
@@ -272,7 +271,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	c.req.opts.NumGPU = 0 // CPU load, will be allowed
 	d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
 
-	envconfig.MaxRunners = 1
+	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
 	s.newServerFn = a.newServer
 	slog.Info("a")
 	s.pendingReqCh <- a.req
@@ -291,7 +290,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	require.Len(t, s.loaded, 1)
 	s.loadedMu.Unlock()
 
-	envconfig.MaxRunners = 0
+	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
 	s.newServerFn = b.newServer
 	slog.Info("b")
 	s.pendingReqCh <- b.req
@@ -362,7 +361,7 @@ func TestGetRunner(t *testing.T) {
 	a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
 	b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
 	c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
-	envconfig.MaxQueuedRequests = 1
+	t.Setenv("OLLAMA_MAX_QUEUE", "1")
 	s := InitScheduler(ctx)
 	s.getGpuFn = getGpuFn
 	s.getCpuFn = getCpuFn