ollama/llm/shim_test.go

62 lines
1.7 KiB
Go
Raw Normal View History

package llm
import (
"testing"
"github.com/jmorganca/ollama/gpu"
"github.com/stretchr/testify/assert"
)
// TestGetShims checks the ordered list of libraries returned by getShims for
// several combinations of available shim libraries and detected GPU info.
// It also checks that rocmShimPresent agrees with the shim map in use.
//
// The test reassigns the package-level availableShims map, so the original
// value is restored on cleanup to avoid leaking state into other tests in
// this package.
func TestGetShims(t *testing.T) {
	saved := availableShims
	t.Cleanup(func() { availableShims = saved })

	onlyCPU := map[string]string{
		"cpu": "X_cpu",
	}
	rocmV5V6 := map[string]string{
		"rocm_v5": "X_rocm_v5",
		"rocm_v6": "X_rocm_v6",
		"cpu":     "X_cpu",
	}
	unversionedRocm := map[string]string{
		"rocm": "X_rocm_v5",
		"cpu":  "X_cpu",
	}

	tests := []struct {
		name        string
		shims       map[string]string
		info        gpu.GpuInfo
		rocmPresent bool
		want        []string
	}{
		{
			name:        "cpu shim only",
			shims:       onlyCPU,
			info:        gpu.GpuInfo{Library: "cpu"},
			rocmPresent: false,
			want:        []string{"X_cpu", "default"},
		},
		{
			name:        "rocm without variant",
			shims:       rocmV5V6,
			info:        gpu.GpuInfo{Library: "rocm"},
			rocmPresent: true,
			want:        []string{"X_rocm_v5", "X_rocm_v6", "X_cpu", "default"},
		},
		{
			name:        "rocm v6 variant preferred first",
			shims:       rocmV5V6,
			info:        gpu.GpuInfo{Library: "rocm", Variant: "v6"},
			rocmPresent: true,
			want:        []string{"X_rocm_v6", "X_rocm_v5", "X_cpu", "default"},
		},
		{
			name:        "cuda skips rocm shims",
			shims:       rocmV5V6,
			info:        gpu.GpuInfo{Library: "cuda"},
			rocmPresent: true,
			want:        []string{"X_cpu", "default"},
		},
		{
			name:        "default library skips rocm shims",
			shims:       rocmV5V6,
			info:        gpu.GpuInfo{Library: "default"},
			rocmPresent: true,
			want:        []string{"X_cpu", "default"},
		},
		{
			name:        "unversioned rocm shim",
			shims:       unversionedRocm,
			info:        gpu.GpuInfo{Library: "rocm", Variant: "v6"},
			rocmPresent: true,
			want:        []string{"X_rocm_v5", "X_cpu", "default"},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			availableShims = tc.shims
			assert.Equal(t, tc.rocmPresent, rocmShimPresent())

			res := getShims(tc.info)
			// assert.Len does not stop the test; bail out on a length
			// mismatch so the indexed checks below cannot panic.
			if !assert.Len(t, res, len(tc.want)) {
				return
			}
			for i, want := range tc.want {
				assert.Equal(t, want, res[i], "shim at index %d", i)
			}
		})
	}
}