diff --git a/llm/llm.go b/llm/llm.go
index 86dd3346..69ea705f 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -22,6 +22,9 @@ type LLM interface {
 	Close()
 }
 
+// Set to true on linux/windows if we are able to load the shim
+var ShimPresent = false
+
 func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
@@ -79,11 +82,10 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
 	opts.RopeFrequencyBase = 0.0
 	opts.RopeFrequencyScale = 0.0
 	gpuInfo := gpu.GetGPUInfo()
-	switch gpuInfo.Driver {
-	case "ROCM":
+	if gpuInfo.Driver == "ROCM" && ShimPresent {
 		return newRocmShimExtServer(model, adapters, projectors, ggml.NumLayers(), opts)
-	default:
-		// Rely on the built-in CUDA based server which will fall back to CPU
+	} else {
+		// Rely on the built-in CUDA/Metal based server which will fall back to CPU
 		return newLlamaExtServer(model, adapters, projectors, ggml.NumLayers(), opts)
 	}
 }
diff --git a/llm/shim_ext_server.go b/llm/shim_ext_server.go
index 0e7bcfae..7505adaa 100644
--- a/llm/shim_ext_server.go
+++ b/llm/shim_ext_server.go
@@ -30,7 +30,6 @@ import (
 var libEmbed embed.FS
 
 var RocmShimMissing = fmt.Errorf("ROCm shim library not included in this build of ollama. Radeon GPUs are not supported")
-var NoShim = true
 
 type shimExtServer struct {
 	s C.struct_rocm_llama_server
@@ -78,7 +77,7 @@ func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) {
 }
 
 func newRocmShimExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
-	if NoShim {
+	if !ShimPresent {
 		return nil, RocmShimMissing
 	}
 	log.Printf("Loading ROCM llm server")
@@ -207,6 +206,6 @@ func extractLib(workDir string) error {
 	case err != nil:
 		return fmt.Errorf("stat ROCm shim %s: %v", files[0], err)
 	}
-	NoShim = false
+	ShimPresent = true
 	return nil
 }