diff --git a/llm/ggml.go b/llm/ggml.go
index d877acd1..352c095f 100644
--- a/llm/ggml.go
+++ b/llm/ggml.go
@@ -303,3 +303,50 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 		model: model,
 	}, offset, nil
 }
+
+func (llm GGML) GraphSize(context, batch int) (int64, bool) {
+	embeddingLength := llm.KV().EmbeddingLength()
+	headCount := llm.KV().HeadCount()
+	headCountKV := llm.KV().HeadCountKV()
+	vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))
+
+	var attnQKVWeight1 uint64 = 0
+	for _, t := range llm.Tensors() {
+		if strings.HasSuffix(t.Name, ".attn_qkv.weight") && len(t.Shape) >= 2 {
+			attnQKVWeight1 = t.Shape[1]
+			break
+		}
+	}
+
+	var ffnGate1 uint64 = 0
+	for _, t := range llm.Tensors() {
+		if strings.Index(t.Name, ".ffn_gate") > 0 && len(t.Shape) >= 2 {
+			ffnGate1 = t.Shape[1]
+			break
+		}
+	}
+
+	switch llm.KV().Architecture() {
+	case "gemma":
+		return 4 * int64(batch) * int64(embeddingLength+uint64(vocabLength)), true
+	case "phi2":
+		return max(
+			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
+			4*int64(batch)*int64(1+4*embeddingLength+uint64(context)+attnQKVWeight1+uint64(context)*headCount),
+		), true
+	case "qwen2":
+		return max(
+			4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
+			4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
+		), true
+	case "llama":
+		if ffnGate1 > 0 {
+			// moe
+			return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate1), true
+		}
+
+		return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
+	}
+
+	return 0, false
+}
diff --git a/llm/server.go b/llm/server.go
index 82cd268d..2994f9a6 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -79,10 +79,11 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
 	kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.KV().BlockCount()) * int64(ggml.KV().EmbeddingLength()) / int64(ggml.KV().HeadCount()) * int64(ggml.KV().HeadCountKV())
 
-	// this amount is the overhead + tensors in memory
-	// TODO: get this from the llama.cpp's graph calculations instead of
-	// estimating it's 1/6 * kv_cache_size * num_gqa
-	graph := int64(ggml.KV().GQA()) * kv / 6
+	graph, ok := ggml.GraphSize(opts.NumCtx, min(opts.NumCtx, opts.NumBatch))
+	if !ok {
+		graph = int64(ggml.KV().GQA()) * kv / 6
+	}
+
 	usedMemory += graph
 
 	if (usedMemory > availableMemory || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture())) && info.Library != "metal" {
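
Back-of-the-envelope check (not part of the patch): the sketch below plugs hypothetical Llama-2-7B-style hyperparameters into the server.go fp16 KV-cache formula, the old 1/6 * kv_cache_size * num_gqa fallback, and the new non-MoE "llama" branch of GraphSize. The values (32 layers, 4096 embedding width, 32 heads, 32 KV heads, context 2048, batch 512) are assumptions for illustration, not taken from the diff.

package main

import "fmt"

func main() {
	// Hypothetical dense 7B hyperparameters (assumed, not from the patch).
	var numCtx, numBatch int64 = 2048, 512
	var blockCount, embeddingLength, headCount, headCountKV int64 = 32, 4096, 32, 32

	// fp16 KV cache, as computed in server.go:
	// 2 (k+v) * 2 bytes (float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
	kv := 2 * 2 * numCtx * blockCount * embeddingLength / headCount * headCountKV
	fmt.Printf("kv cache:       %d bytes (~%.1f MiB)\n", kv, float64(kv)/(1<<20))

	// previous fallback estimate: 1/6 * kv_cache_size * num_gqa
	gqa := headCount / headCountKV
	fallback := gqa * kv / 6
	fmt.Printf("fallback graph: %d bytes (~%.1f MiB)\n", fallback, float64(fallback)/(1<<20))

	// new non-MoE "llama" estimate from GraphSize:
	// 4 * batch * (1 + 4*n_embd + n_ctx + n_ctx*n_head)
	graph := 4 * numBatch * (1 + 4*embeddingLength + numCtx + numCtx*headCount)
	fmt.Printf("llama graph:    %d bytes (~%.1f MiB)\n", graph, float64(graph)/(1<<20))
}

With these assumed numbers the KV cache comes out to 1 GiB, the old fallback to roughly 171 MiB, and the new "llama" estimate to roughly 164 MiB, so the architecture-specific estimate lands in the same ballpark as the old heuristic for a dense 7B model while differing for the other architectures handled above.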