From afa8d6e9d56da834a03df7817d065f6c8b46e102 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Tue, 30 Jul 2024 18:06:26 -0700
Subject: [PATCH] patch gemma support

---
 llm/patches/10-params.diff | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)
 create mode 100644 llm/patches/10-params.diff

diff --git a/llm/patches/10-params.diff b/llm/patches/10-params.diff
new file mode 100644
index 00000000..56699b8e
--- /dev/null
+++ b/llm/patches/10-params.diff
@@ -0,0 +1,20 @@
+diff --git a/src/llama.cpp b/src/llama.cpp
+index a207451f..fba6b175 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -4969,6 +4969,7 @@ static void llm_load_hparams(
+                 hparams.attn_soft_cap = true;
+ 
+                 switch (hparams.n_layer) {
++                    case 26: model.type = e_model::MODEL_2B; break;
+                     case 42: model.type = e_model::MODEL_9B; break;
+                     case 46: model.type = e_model::MODEL_27B; break;
+                     default: model.type = e_model::MODEL_UNKNOWN;
+@@ -11736,6 +11737,7 @@ struct llm_build_context {
+ 
+                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                 switch (model.type) {
++                    case e_model::MODEL_2B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
+                     case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
+                     case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                     default: GGML_ABORT("fatal error");
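
Note, as a sketch for context rather than part of the patch: the second hunk
exists because the Gemma 2 sizes do not all scale attention queries the same
way. The 2B and 9B scale by 1/sqrt(n_embd_head_k), the per-head key/query
dimension, while the 27B scales by 1/sqrt(n_embd / n_head); the two quantities
differ whenever head_dim * n_head != n_embd. The standalone C++ below
demonstrates the gap for the 2B. Its hyperparameter values (hidden size 2304,
8 attention heads, head dim 256) are assumptions taken from the published
Gemma 2 2B configuration, not from this diff.

    #include <cmath>
    #include <cstdio>

    int main() {
        // Assumed Gemma 2 2B shape: 26 layers, hidden size 2304,
        // 8 query heads, per-head key/query dimension 256.
        const int n_embd        = 2304;
        const int n_head        = 8;
        const int n_embd_head_k = 256;

        // Branch taken for MODEL_2B and MODEL_9B in the patched switch.
        const float scale_head_dim = 1.0f / sqrtf((float) n_embd_head_k);
        // Branch taken for MODEL_27B; it gives a different value here
        // because 256 * 8 != 2304.
        const float scale_hidden   = 1.0f / sqrtf((float) (n_embd / n_head));

        printf("1/sqrt(n_embd_head_k)   = %f\n", scale_head_dim); // 0.062500
        printf("1/sqrt(n_embd / n_head) = %f\n", scale_hidden);   // ~0.058926
        return 0;
    }

Routing the 2B through the head-dim branch is consistent with the reference
scaling in the google/gemma_pytorch commit linked in the hunk above.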