add max context length check

This commit is contained in:
Michael Yang 2024-01-12 14:54:01 -08:00
parent 565f8a3c44
commit eaed6f8c45
3 changed files with 15 additions and 0 deletions

View file

@ -83,6 +83,7 @@ type model interface {
NumEmbed() uint32
NumHead() uint32
NumHeadKv() uint32
NumCtx() uint32
}
type container interface {

View file

@ -308,6 +308,15 @@ func (llm *ggufModel) NumHeadKv() uint32 {
return value.(uint32)
}
func (llm *ggufModel) NumCtx() uint32 {
value, exists := llm.kv[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
if !exists {
return 0
}
return value.(uint32)
}
func (llm *ggufModel) NumGQA() uint32 {
numHeadKv := llm.NumHeadKv()
if numHeadKv == 0 {

View file

@ -35,6 +35,11 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
return nil, err
}
if opts.NumCtx > int(ggml.NumCtx()) {
log.Printf("WARNING: requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx())
opts.NumCtx = int(ggml.NumCtx())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}