diff --git a/llm/llama.go b/llm/llama.go index 79cd8f7e..8ddaadef 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -191,7 +191,7 @@ var errNoGPU = errors.New("nvidia-smi command failed") // CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs func CheckVRAM() (int64, error) { - cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits") + cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits") var stdout bytes.Buffer cmd.Stdout = &stdout err := cmd.Run() @@ -199,7 +199,7 @@ func CheckVRAM() (int64, error) { return 0, errNoGPU } - var total int64 + var free int64 scanner := bufio.NewScanner(&stdout) for scanner.Scan() { line := scanner.Text() @@ -208,10 +208,10 @@ func CheckVRAM() (int64, error) { return 0, fmt.Errorf("failed to parse available VRAM: %v", err) } - total += vram + free += vram } - return total, nil + return free, nil } func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { @@ -228,14 +228,14 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int { return 0 } - totalVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes + freeVramBytes := int64(vramMib) * 1024 * 1024 // 1 MiB = 1024^2 bytes // Calculate bytes per layer // TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size bytesPerLayer := fileSizeBytes / numLayer - // max number of layers we can fit in VRAM - layers := int(totalVramBytes / bytesPerLayer) + // max number of layers we can fit in VRAM, subtract 5% to prevent consuming all available VRAM and running out of memory + layers := int(freeVramBytes/bytesPerLayer) * 95 / 100 log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers) return layers @@ -367,7 +367,13 @@ func waitForServer(llm *llama) error { } func (llm *llama) Close() { + // signal the sub-process to terminate llm.Cancel() + + // wait for the command to exit to prevent race conditions with the next run + if err := llm.Cmd.Wait(); err != nil { + log.Printf("llama runner exited: %v", err) + } } func (llm *llama) SetOptions(opts api.Options) {