diff --git a/llm/llama.go b/llm/llama.go index 8c5762b6..ce697b33 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -117,7 +117,21 @@ func (llm *llamaModel) ModelFamily() ModelFamily { } func (llm *llamaModel) ModelType() ModelType { - return ModelType30B + switch llm.hyperparameters.NumLayer { + case 26: + return ModelType3B + case 32: + return ModelType7B + case 40: + return ModelType13B + case 60: + return ModelType30B + case 80: + return ModelType65B + } + + // TODO: find a better default + return ModelType7B } func (llm *llamaModel) FileType() FileType {