diff --git a/runner/internal/common/gpu.go b/runner/internal/common/gpu.go index 19c46ea661..045cc773be 100644 --- a/runner/internal/common/gpu.go +++ b/runner/internal/common/gpu.go @@ -16,6 +16,8 @@ const ( ) func GetGpuVendor() GpuVendor { + // FIXME: There might be errors other than os.ErrNotExist that are ignored silently. + // Propagate and log. if _, err := os.Stat("/dev/kfd"); !errors.Is(err, os.ErrNotExist) { return GpuVendorAmd } @@ -28,5 +30,11 @@ func GetGpuVendor() GpuVendor { if _, err := os.Stat("/dev/tenstorrent"); !errors.Is(err, os.ErrNotExist) { return GpuVendorTenstorrent } + if _, err := os.Stat("/dev/dxg"); !errors.Is(err, os.ErrNotExist) { + // WSL2 + if _, err := os.Stat("/usr/lib/wsl/lib/nvidia-smi"); !errors.Is(err, os.ErrNotExist) { + return GpuVendorNvidia + } + } return GpuVendorNone }