From e15307fdf4217f87a80fba3c9cd72d0f3d325848 Mon Sep 17 00:00:00 2001
From: Sam
Date: Tue, 21 May 2024 06:36:03 +1000
Subject: [PATCH] feat: add support for flash_attn (#4120)

* feat: enable flash attention if supported

* feat: enable flash attention if supported

* feat: enable flash attention if supported

* feat: add flash_attn support
---
 llm/ext_server/server.cpp | 14 +++++++++++---
 llm/server.go             | 17 +++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index 0c339989..3e03bb34 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -2104,6 +2104,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --embedding               enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf("  -np N, --parallel N       number of slots for process requests (default: %d)\n", params.n_parallel);
     printf("  -cb, --cont-batching      enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+    printf("  -fa, --flash-attn         enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf("  -spf FNAME, --system-prompt-file FNAME\n");
     printf("                            set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
     printf("  -ctk TYPE, --cache-type-k TYPE\n");
@@ -2501,7 +2502,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         {
             params.use_mmap = false;
         }
-        else if (arg == "--numa") {
+        else if (arg == "--numa")
+        {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -2521,6 +2523,10 @@
         {
             params.cont_batching = true;
         }
+        else if (arg == "-fa" || arg == "--flash-attn")
+        {
+            params.flash_attn = true;
+        }
         else if (arg == "-np" || arg == "--parallel")
         {
             if (++i >= argc)
@@ -2529,7 +2535,8 @@
             {
                 invalid_param = true;
                 break;
             }
             params.n_parallel = std::stoi(argv[i]);
-        } else if (arg == "-n" || arg == "--n-predict")
+        }
+        else if (arg == "-n" || arg == "--n-predict")
         {
             if (++i >= argc)
@@ -2537,7 +2544,8 @@
             {
                 invalid_param = true;
                 break;
             }
             params.n_predict = std::stoi(argv[i]);
-        } else if (arg == "-spf" || arg == "--system-prompt-file")
+        }
+        else if (arg == "-spf" || arg == "--system-prompt-file")
         {
             if (++i >= argc)
             {
diff --git a/llm/server.go b/llm/server.go
index ccb1e419..ba25fa21 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -200,6 +200,23 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	params = append(params, "--numa")
 }
 
+	flashAttnSupported := true
+
+	// partial offloading does not support flash attention
+	if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
+		flashAttnSupported = false
+	}
+
+	// only cuda (compute capability 7+) and metal support flash attention
+	for _, g := range gpus {
+		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
+			flashAttnSupported = false
+		}
+	}
+
+	if flashAttnSupported {
+		params = append(params, "--flash-attn")
+	}
 
 	numParallel := envconfig.NumParallel
 
 	// TODO (jmorganca): multimodal models don't support parallel yet
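
For reviewers, a minimal sketch of the gating logic this patch adds to NewLlamaServer, as a self-contained Go program. The GpuInfo struct and flashAttnSupported helper below are illustrative stand-ins, not ollama's real types (gpu.GpuInfo carries more fields); the checks themselves mirror the diff: flash attention is enabled only when the model is fully offloaded (NumGPU must cover BlockCount()+1 layers, the +1 accounting for the output layer) and every GPU is Metal, or CUDA with driver major version 7 or later.

package main

import "fmt"

// GpuInfo is a stand-in for ollama's gpu.GpuInfo; only the fields the
// check reads are included here.
type GpuInfo struct {
	Library     string // e.g. "cuda", "metal", "rocm"
	DriverMajor int
}

// flashAttnSupported reproduces the two checks from the diff:
//  1. partial offloading does not support flash attention, so the GPU
//     layer count must cover every block plus the output layer
//  2. only cuda (compute capability 7+) and metal support flash attention
func flashAttnSupported(numGPU, blockCount uint64, gpus []GpuInfo) bool {
	if numGPU < blockCount+1 {
		return false
	}
	for _, g := range gpus {
		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
			return false
		}
	}
	return true
}

func main() {
	cuda := []GpuInfo{{Library: "cuda", DriverMajor: 12}}
	fmt.Println(flashAttnSupported(33, 32, cuda)) // true: fully offloaded
	fmt.Println(flashAttnSupported(16, 32, cuda)) // false: partial offload
	rocm := []GpuInfo{{Library: "rocm", DriverMajor: 6}}
	fmt.Println(flashAttnSupported(33, 32, rocm)) // false: unsupported backend
}

When the check passes, the Go side appends --flash-attn to the runner's arguments, and server.cpp parses it (short form -fa) the same way as its other boolean flags such as --cont-batching.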