add max context length check
This commit is contained in:
parent
565f8a3c44
commit
eaed6f8c45
|
@ -83,6 +83,7 @@ type model interface {
|
||||||
NumEmbed() uint32
|
NumEmbed() uint32
|
||||||
NumHead() uint32
|
NumHead() uint32
|
||||||
NumHeadKv() uint32
|
NumHeadKv() uint32
|
||||||
|
NumCtx() uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
type container interface {
|
type container interface {
|
||||||
|
|
|
@ -308,6 +308,15 @@ func (llm *ggufModel) NumHeadKv() uint32 {
|
||||||
return value.(uint32)
|
return value.(uint32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm *ggufModel) NumCtx() uint32 {
|
||||||
|
value, exists := llm.kv[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
|
||||||
|
if !exists {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.(uint32)
|
||||||
|
}
|
||||||
|
|
||||||
func (llm *ggufModel) NumGQA() uint32 {
|
func (llm *ggufModel) NumGQA() uint32 {
|
||||||
numHeadKv := llm.NumHeadKv()
|
numHeadKv := llm.NumHeadKv()
|
||||||
if numHeadKv == 0 {
|
if numHeadKv == 0 {
|
||||||
|
|
|
@ -35,6 +35,11 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.NumCtx > int(ggml.NumCtx()) {
|
||||||
|
log.Printf("WARNING: requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx())
|
||||||
|
opts.NumCtx = int(ggml.NumCtx())
|
||||||
|
}
|
||||||
|
|
||||||
if opts.NumCtx < 4 {
|
if opts.NumCtx < 4 {
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue