From 354967667840a65ff07873fb90b146bd543f8ce4 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 26 Jul 2023 11:50:29 -0700 Subject: [PATCH] embed ggml-metal.metal --- ggml-metal.metal | 1 - llama/llama.go | 4 ++++ llama/llama_darwin.go | 53 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) delete mode 120000 ggml-metal.metal create mode 100644 llama/llama_darwin.go diff --git a/ggml-metal.metal b/ggml-metal.metal deleted file mode 120000 index 9f596334..00000000 --- a/ggml-metal.metal +++ /dev/null @@ -1 +0,0 @@ -llama/ggml-metal.metal \ No newline at end of file diff --git a/llama/llama.go b/llama/llama.go index 04e679a0..9032bf5f 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -86,6 +86,7 @@ llama_token llama_sample( import "C" import ( "bytes" + "embed" "errors" "fmt" "io" @@ -99,6 +100,9 @@ import ( "github.com/jmorganca/ollama/api" ) +//go:embed ggml-metal.metal +var fs embed.FS + type LLM struct { params *C.struct_llama_context_params model *C.struct_llama_model diff --git a/llama/llama_darwin.go b/llama/llama_darwin.go new file mode 100644 index 00000000..8e81ed54 --- /dev/null +++ b/llama/llama_darwin.go @@ -0,0 +1,53 @@ +package llama + +import ( + "errors" + "io" + "log" + "os" + "path/filepath" +) + +func init() { + if err := initBackend(); err != nil { + log.Printf("WARNING: GPU could not be initialized correctly: %v", err) + log.Printf("WARNING: falling back to CPU") + } +} + +func initBackend() error { + exec, err := os.Executable() + if err != nil { + return err + } + + exec, err = filepath.EvalSymlinks(exec) + if err != nil { + return err + } + + metal := filepath.Join(filepath.Dir(exec), "ggml-metal.metal") + if _, err := os.Stat(metal); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return err + } + + dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal")) + if err != nil { + return err + } + defer dst.Close() + + src, err := fs.Open("ggml-metal.metal") + if err != nil { + return err + } + defer src.Close() + + if _, err := io.Copy(dst, src); err != nil { + return err + } + } + + return nil +}