@@ -82,11 +82,10 @@ func pullImage(ctx context.Context, client docker.APIClient, imageName string, r
8282}
8383
8484func createContainer (ctx context.Context , client docker.APIClient , params DockerParameters ) (string , error ) {
85- runtime , err := getRuntime (ctx , client )
85+ gpuRequest , err := requestGpuIfAvailable (ctx , client )
8686 if err != nil {
8787 return "" , gerrors .Wrap (err )
8888 }
89-
9089 mounts , err := params .DockerMounts ()
9190 if err != nil {
9291 return "" , gerrors .Wrap (err )
@@ -103,8 +102,10 @@ func createContainer(ctx context.Context, client docker.APIClient, params Docker
103102 PortBindings : bindPorts (params .DockerPorts ()... ),
104103 PublishAllPorts : true ,
105104 Sysctls : map [string ]string {},
106- Runtime : runtime ,
107- Mounts : mounts ,
105+ Resources : container.Resources {
106+ DeviceRequests : gpuRequest ,
107+ },
108+ Mounts : mounts ,
108109 }
109110 resp , err := client .ContainerCreate (ctx , containerConfig , hostConfig , nil , nil , "" )
110111 if err != nil {
@@ -178,17 +179,21 @@ func getNetworkMode() container.NetworkMode {
178179 return "default"
179180}
180181
181- func getRuntime (ctx context.Context , client docker.APIClient ) (string , error ) {
182+ func requestGpuIfAvailable (ctx context.Context , client docker.APIClient ) ([]container. DeviceRequest , error ) {
182183 info , err := client .Info (ctx )
183184 if err != nil {
184- return "" , gerrors .Wrap (err )
185+ return nil , gerrors .Wrap (err )
185186 }
186- for name := range info .Runtimes {
187- if name == consts .NVIDIA_RUNTIME {
188- return name , nil
187+
188+ for runtime := range info .Runtimes {
189+ if runtime == consts .NVIDIA_RUNTIME {
190+ return []container.DeviceRequest {
191+ {Capabilities : [][]string {{"gpu" }}, Count : - 1 }, // --gpus=all
192+ }, nil
189193 }
190194 }
191- return info .DefaultRuntime , nil
195+
196+ return nil , nil
192197}
193198
194199/* DockerParameters interface implementation for CLIArgs */
0 commit comments