Skip to content

Commit 5e74a49

Browse files
Egor-S, peterschmidt85
authored and committed
Request GPU without deprecated --runtime=nvidia (#913)
(cherry picked from commit 730e8b0)
1 parent 4d60c80 commit 5e74a49

1 file changed

Lines changed: 15 additions & 10 deletions

File tree

runner/internal/shim/docker.go

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -82,11 +82,10 @@ func pullImage(ctx context.Context, client docker.APIClient, imageName string, r
8282
}
8383

8484
func createContainer(ctx context.Context, client docker.APIClient, params DockerParameters) (string, error) {
85-
runtime, err := getRuntime(ctx, client)
85+
gpuRequest, err := requestGpuIfAvailable(ctx, client)
8686
if err != nil {
8787
return "", gerrors.Wrap(err)
8888
}
89-
9089
mounts, err := params.DockerMounts()
9190
if err != nil {
9291
return "", gerrors.Wrap(err)
@@ -103,8 +102,10 @@ func createContainer(ctx context.Context, client docker.APIClient, params Docker
103102
PortBindings: bindPorts(params.DockerPorts()...),
104103
PublishAllPorts: true,
105104
Sysctls: map[string]string{},
106-
Runtime: runtime,
107-
Mounts: mounts,
105+
Resources: container.Resources{
106+
DeviceRequests: gpuRequest,
107+
},
108+
Mounts: mounts,
108109
}
109110
resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, "")
110111
if err != nil {
@@ -178,17 +179,21 @@ func getNetworkMode() container.NetworkMode {
178179
return "default"
179180
}
180181

181-
func getRuntime(ctx context.Context, client docker.APIClient) (string, error) {
182+
func requestGpuIfAvailable(ctx context.Context, client docker.APIClient) ([]container.DeviceRequest, error) {
182183
info, err := client.Info(ctx)
183184
if err != nil {
184-
return "", gerrors.Wrap(err)
185+
return nil, gerrors.Wrap(err)
185186
}
186-
for name := range info.Runtimes {
187-
if name == consts.NVIDIA_RUNTIME {
188-
return name, nil
187+
188+
for runtime := range info.Runtimes {
189+
if runtime == consts.NVIDIA_RUNTIME {
190+
return []container.DeviceRequest{
191+
{Capabilities: [][]string{{"gpu"}}, Count: -1}, // --gpus=all
192+
}, nil
189193
}
190194
}
191-
return info.DefaultRuntime, nil
195+
196+
return nil, nil
192197
}
193198

194199
/* DockerParameters interface implementation for CLIArgs */

0 commit comments

Comments
 (0)