diff --git a/llm/llama.go b/llm/llama.go
index 9589a417..f23d5d85 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -218,7 +218,6 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
 	if opts.NumGPU != -1 {
 		return opts.NumGPU
 	}
-	n := 1 // default to enable metal on macOS
 	if runtime.GOOS == "linux" {
 		vramMib, err := CheckVRAM()
 		if err != nil {
@@ -235,10 +234,11 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
 		// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
 		bytesPerLayer := fileSizeBytes / numLayer
 
-		// set n to the max number of layers we can fit in VRAM
-		return int(totalVramBytes / bytesPerLayer)
+		// max number of layers we can fit in VRAM
+		layers := int(totalVramBytes / bytesPerLayer)
+		log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)
 
-		log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, n)
+		return layers
 	}
 	// default to enable metal on macOS
 	return 1
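
For context: in the old code the early return preceded the log.Printf, so the log was unreachable, and its message still referenced the leftover variable n. The patch computes the layer count once, logs it, and then returns it. Below is a minimal standalone sketch of the heuristic the patch keeps; the helper name gpuLayers and the example numbers are illustrative only and not part of the patch.

// Sketch of the VRAM heuristic from NumGPU above, restated outside the diff.
// The function name and example inputs are assumptions for illustration.
package main

import "fmt"

// gpuLayers mirrors the patched logic: estimate bytes per layer from the
// model file size, then see how many such layers fit in available VRAM.
func gpuLayers(numLayer, fileSizeBytes, vramMib int64) int {
	totalVramBytes := vramMib * 1024 * 1024 // 1 MiB = 1024 * 1024 bytes
	bytesPerLayer := fileSizeBytes / numLayer
	return int(totalVramBytes / bytesPerLayer)
}

func main() {
	// Example: a 13 GB model file with 40 layers on a GPU reporting 8192 MiB.
	layers := gpuLayers(40, 13_000_000_000, 8192)
	fmt.Printf("loading up to %d GPU layers\n", layers) // prints 26
}

Note that the estimate can exceed the model's actual layer count on large GPUs; like the patched code, the sketch leaves any capping to the caller, which is why the log message says "up to ... GPU layers".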