From 67190976491af3535c7587e08db05a6f2ff2d7ea Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen
Date: Thu, 5 Sep 2024 14:00:08 -0700
Subject: [PATCH] llm: make load time stall duration configurable via
 OLLAMA_LOAD_TIMEOUT

With the new very large parameter models, some users are willing to wait a
very long time for models to load.
---
 cmd/cmd.go               |  1 +
 envconfig/config.go      | 27 +++++++++++++++++++++++----
 envconfig/config_test.go | 34 ++++++++++++++++++++++++++++++++++
 llm/server.go            |  5 ++---
 4 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/cmd/cmd.go b/cmd/cmd.go
index 890b839a..5de1ed1b 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -1422,6 +1422,7 @@ func NewCLI() *cobra.Command {
 				envVars["OLLAMA_FLASH_ATTENTION"],
 				envVars["OLLAMA_LLM_LIBRARY"],
 				envVars["OLLAMA_GPU_OVERHEAD"],
+				envVars["OLLAMA_LOAD_TIMEOUT"],
 			})
 		default:
 			appendEnvDocs(cmd, envs)
diff --git a/envconfig/config.go b/envconfig/config.go
index b47fd8d5..14e3cb0c 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -112,6 +112,26 @@ func KeepAlive() (keepAlive time.Duration) {
 	return keepAlive
 }
 
+// LoadTimeout returns the duration for stall detection during model loads. LoadTimeout can be configured via the OLLAMA_LOAD_TIMEOUT environment variable.
+// Zero or negative values are treated as infinite.
+// Default is 5 minutes.
+func LoadTimeout() (loadTimeout time.Duration) {
+	loadTimeout = 5 * time.Minute
+	if s := Var("OLLAMA_LOAD_TIMEOUT"); s != "" {
+		if d, err := time.ParseDuration(s); err == nil {
+			loadTimeout = d
+		} else if n, err := strconv.ParseInt(s, 10, 64); err == nil {
+			loadTimeout = time.Duration(n) * time.Second
+		}
+	}
+
+	if loadTimeout <= 0 {
+		return time.Duration(math.MaxInt64)
+	}
+
+	return loadTimeout
+}
+
 func Bool(k string) func() bool {
 	return func() bool {
 		if s := Var(k); s != "" {
@@ -245,10 +265,8 @@ func Uint64(key string, defaultValue uint64) func() uint64 {
 	}
 }
 
-var (
-	// Set aside VRAM per GPU
-	GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
-)
+// Set aside VRAM per GPU
+var GpuOverhead = Uint64("OLLAMA_GPU_OVERHEAD", 0)
 
 type EnvVar struct {
 	Name        string
@@ -264,6 +282,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
 		"OLLAMA_LLM_LIBRARY":       {"OLLAMA_LLM_LIBRARY", LLMLibrary(), "Set LLM library to bypass autodetection"},
+		"OLLAMA_LOAD_TIMEOUT":      {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
 		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
 		"OLLAMA_MODELS":            {"OLLAMA_MODELS", Models(), "The path to the models directory"},
diff --git a/envconfig/config_test.go b/envconfig/config_test.go
index d52a98a5..7ac7c53e 100644
--- a/envconfig/config_test.go
+++ b/envconfig/config_test.go
@@ -215,6 +215,40 @@ func TestKeepAlive(t *testing.T) {
 	}
 }
 
+func TestLoadTimeout(t *testing.T) {
+	defaultTimeout := 5 * time.Minute
+	cases := map[string]time.Duration{
+		"":       defaultTimeout,
+		"1s":     time.Second,
+		"1m":     time.Minute,
+		"1h":     time.Hour,
+		"5m0s":   defaultTimeout,
+		"1h2m3s": 1*time.Hour + 2*time.Minute + 3*time.Second,
+		"0":      time.Duration(math.MaxInt64),
+		"60":     60 * time.Second,
+		"120":    2 * time.Minute,
+		"3600":   time.Hour,
+		"-0":     time.Duration(math.MaxInt64),
+		"-1":     time.Duration(math.MaxInt64),
+		"-1m":    time.Duration(math.MaxInt64),
+		// invalid values
+		" ":   defaultTimeout,
+		"???": defaultTimeout,
+		"1d":  defaultTimeout,
+		"1y":  defaultTimeout,
+		"1w":  defaultTimeout,
+	}
+
+	for tt, expect := range cases {
+		t.Run(tt, func(t *testing.T) {
+			t.Setenv("OLLAMA_LOAD_TIMEOUT", tt)
+			if actual := LoadTimeout(); actual != expect {
+				t.Errorf("%s: expected %s, got %s", tt, expect, actual)
+			}
+		})
+	}
+}
+
 func TestVar(t *testing.T) {
 	cases := map[string]string{
 		"value": "value",
diff --git a/llm/server.go b/llm/server.go
index 9c08f1bb..28eb8d6f 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -584,8 +584,7 @@ func (s *llmServer) Ping(ctx context.Context) error {
 
 func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	start := time.Now()
-	stallDuration := 5 * time.Minute            // If no progress happens
-	finalLoadDuration := 5 * time.Minute        // After we hit 100%, give the runner more time to come online
+	stallDuration := envconfig.LoadTimeout()    // If no progress happens
 	stallTimer := time.Now().Add(stallDuration) // give up if we stall
 
 	slog.Info("waiting for llama runner to start responding")
@@ -637,7 +636,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 				stallTimer = time.Now().Add(stallDuration)
 			} else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 {
 				slog.Debug("model load completed, waiting for server to become available", "status", status.ToString())
-				stallTimer = time.Now().Add(finalLoadDuration)
+				stallTimer = time.Now().Add(stallDuration)
 				fullyLoaded = true
 			}
 			time.Sleep(time.Millisecond * 250)
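
Note (not part of the patch): once this change is applied, the variable is set on the server process, e.g. OLLAMA_LOAD_TIMEOUT=30m ollama serve. The standalone program below is a minimal sketch showing how a few representative values are interpreted; it assumes the github.com/ollama/ollama/envconfig import path used in the diff and simply prints what LoadTimeout derives.

package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	// Illustrative only: exercise the parsing rules implemented by LoadTimeout above.
	// Duration strings ("10m") and bare integers treated as seconds ("600") are accepted;
	// zero or negative values print as the "infinite" math.MaxInt64 duration, and anything
	// unparseable falls back to the 5 minute default.
	for _, v := range []string{"", "10m", "600", "0", "-1m", "???"} {
		os.Setenv("OLLAMA_LOAD_TIMEOUT", v)
		fmt.Printf("OLLAMA_LOAD_TIMEOUT=%q -> %s\n", v, envconfig.LoadTimeout())
	}
}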