From 38255d2af15932150606e19bea8200b386cfd36d Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Wed, 22 May 2024 21:52:09 -0700 Subject: [PATCH] Use flash attention flag for now (#4580) * put flash attention behind flag for now * add test * remove print * up timeout for scheduler tests --- llm/server.go | 10 +++++----- server/envconfig/config.go | 10 ++++++++++ server/envconfig/config_test.go | 3 +++ server/sched_test.go | 2 +- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llm/server.go b/llm/server.go index ba25fa21..c63a76a4 100644 --- a/llm/server.go +++ b/llm/server.go @@ -200,20 +200,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr params = append(params, "--numa") } - flashAttnSupported := true + flashAttnEnabled := envconfig.FlashAttention // partial offloading does not support flash attention - if uint64(opts.NumGPU) < ggml.KV().BlockCount() + 1 { - flashAttnSupported = false + if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 { + flashAttnEnabled = false } // only cuda (compute capability 7+) and metal support flash attention for _, g := range gpus { if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) { - flashAttnSupported = false + flashAttnEnabled = false } } - if flashAttnSupported { + if flashAttnEnabled { params = append(params, "--flash-attn") } diff --git a/server/envconfig/config.go b/server/envconfig/config.go index 9ad68180..ae7d89b2 100644 --- a/server/envconfig/config.go +++ b/server/envconfig/config.go @@ -31,6 +31,8 @@ var ( RunnersDir string // Set via OLLAMA_TMPDIR in the environment TmpDir string + // Experimental flash attention + FlashAttention bool ) func AsMap() map[string]string { @@ -45,6 +47,7 @@ func AsMap() map[string]string { "OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel), "OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir), "OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir), + "OLLAMA_FLASH_ATTENTION": fmt.Sprintf("%v", FlashAttention), } } @@ -78,6 +81,13 
@@ func LoadConfig() { } } + if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" { + d, err := strconv.ParseBool(fa) + if err == nil { + FlashAttention = d + } + } + RunnersDir = clean("OLLAMA_RUNNERS_DIR") if runtime.GOOS == "windows" && RunnersDir == "" { // On Windows we do not carry the payloads inside the main executable diff --git a/server/envconfig/config_test.go b/server/envconfig/config_test.go index bad7c4a7..429434ae 100644 --- a/server/envconfig/config_test.go +++ b/server/envconfig/config_test.go @@ -17,4 +17,7 @@ func TestConfig(t *testing.T) { t.Setenv("OLLAMA_DEBUG", "1") LoadConfig() require.True(t, Debug) + t.Setenv("OLLAMA_FLASH_ATTENTION", "1") + LoadConfig() + require.True(t, FlashAttention) } diff --git a/server/sched_test.go b/server/sched_test.go index 6a6dd04f..addc1ad8 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -151,7 +151,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV } func TestRequests(t *testing.T) { - ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + ctx, done := context.WithTimeout(context.Background(), time.Second) defer done() // Same model, same request