From 00b0699c75fc99b998f3394c04e5a16aa4c49eab Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 25 Apr 2024 19:02:40 -0400 Subject: [PATCH] Reload model if `num_gpu` changes (#3920) * reload model if `num_gpu` changes * don't reload on -1 * fix tests --- server/sched.go | 18 ++++++++++++------ server/sched_test.go | 3 +++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/server/sched.go b/server/sched.go index fa034d28..a8f31d98 100644 --- a/server/sched.go +++ b/server/sched.go @@ -421,16 +421,21 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool slog.Debug("evaluating already loaded", "model", req.model.ModelPath) runner.refMu.Lock() defer runner.refMu.Unlock() - // Ignore the NumGPU settings for comparison - optsExisting := runner.Options.Runner - optsExisting.NumGPU = -1 - optsNew := req.opts.Runner - optsNew.NumGPU = -1 + timeout := 10 * time.Second if runner.loading { timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems... } - ctx, cancel := context.WithTimeout(ctx, timeout) // BUG - + + // Don't reload runner if num_gpu=-1 was provided + optsExisting := runner.Options.Runner + optsNew := req.opts.Runner + if optsNew.NumGPU < 0 { + optsExisting.NumGPU = -1 + optsNew.NumGPU = -1 + } + + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed? !reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed? 
@@ -438,6 +443,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool runner.llama.Ping(ctx) != nil { return true } + return false } diff --git a/server/sched_test.go b/server/sched_test.go index 3b06e2ba..86bd7846 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -490,6 +490,9 @@ func TestNeedsReload(t *testing.T) { require.False(t, resp) req.opts.NumGPU = 99 resp = runner.needsReload(ctx, req) + require.True(t, resp) + req.opts.NumGPU = -1 + resp = runner.needsReload(ctx, req) require.False(t, resp) }