diff --git a/gpu/gpu_test.go b/gpu/gpu_test.go index 46d3201e..13a3f544 100644 --- a/gpu/gpu_test.go +++ b/gpu/gpu_test.go @@ -32,4 +32,29 @@ func TestCPUMemInfo(t *testing.T) { } } +func TestByLibrary(t *testing.T) { + type testCase struct { + input []GpuInfo + expect int + } + + testCases := map[string]*testCase{ + "empty": {input: []GpuInfo{}, expect: 0}, + "cpu": {input: []GpuInfo{{Library: "cpu"}}, expect: 1}, + "cpu + GPU": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2}, + "cpu + 2 GPU no variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2}, + "cpu + 2 GPU same variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2}, + "cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3}, + } + + for k, v := range testCases { + t.Run(k, func(t *testing.T) { + resp := (GpuInfoList)(v.input).ByLibrary() + if len(resp) != v.expect { + t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp) + } + }) + } +} + // TODO - add some logic to figure out card type through other means and actually verify we got back what we expected diff --git a/gpu/types.go b/gpu/types.go index 4cbbeb84..a30e5fb3 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -94,7 +94,7 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList { } } if !found { - libs = append(libs, info.Library) + libs = append(libs, requested) resp = append(resp, []GpuInfo{info}) } }