diff --git a/llm/ggml.go b/llm/ggml.go index ac3a39d8..3fb0539c 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -83,6 +83,7 @@ type model interface { NumEmbed() uint32 NumHead() uint32 NumHeadKv() uint32 + NumCtx() uint32 } type container interface { diff --git a/llm/gguf.go b/llm/gguf.go index 96c3c7b9..cfcab758 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -308,6 +308,15 @@ func (llm *ggufModel) NumHeadKv() uint32 { return value.(uint32) } +func (llm *ggufModel) NumCtx() uint32 { + value, exists := llm.kv[fmt.Sprintf("%s.context_length", llm.ModelFamily())] + if !exists { + return 0 + } + + return value.(uint32) +} + func (llm *ggufModel) NumGQA() uint32 { numHeadKv := llm.NumHeadKv() if numHeadKv == 0 { diff --git a/llm/llm.go b/llm/llm.go index 613519f0..fa352df4 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -35,6 +35,11 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options) return nil, err } + if opts.NumCtx > int(ggml.NumCtx()) { + log.Printf("WARNING: requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx()) + opts.NumCtx = int(ggml.NumCtx()) + } + if opts.NumCtx < 4 { opts.NumCtx = 4 }