diff --git a/llm/llama.go b/llm/llama.go index 8c5762b6..ce697b33 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -117,7 +117,21 @@ func (llm *llamaModel) ModelFamily() ModelFamily { } func (llm *llamaModel) ModelType() ModelType { - return ModelType30B + switch llm.hyperparameters.NumLayer { + case 26: + return ModelType3B + case 32: + return ModelType7B + case 40: + return ModelType13B + case 60: + return ModelType30B + case 80: + return ModelType65B + } + + // TODO: find a better default + return ModelType7B } func (llm *llamaModel) FileType() FileType { diff --git a/server/images.go b/server/images.go index 5c6aa4d4..8e9c1a13 100644 --- a/server/images.go +++ b/server/images.go @@ -325,7 +325,27 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } if mf != nil { - log.Printf("manifest = %#v", mf) + sourceBlobPath, err := GetBlobsPath(mf.Config.Digest) + if err != nil { + return err + } + + sourceBlob, err := os.Open(sourceBlobPath) + if err != nil { + return err + } + defer sourceBlob.Close() + + var source ConfigV2 + if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil { + return err + } + + // copie the model metadata + config.ModelFamily = source.ModelFamily + config.ModelType = source.ModelType + config.FileType = source.FileType + for _, l := range mf.Layers { newLayer, err := GetLayerWithBufferFromLayer(l) if err != nil {