11import asyncio
22import enum
3- import re
43import uuid
54from collections .abc import Iterable
65from dataclasses import dataclass
1110from sqlalchemy .ext .asyncio import AsyncSession
1211from sqlalchemy .orm import aliased , contains_eager , joinedload , load_only
1312
14- from dstack ._internal import settings
1513from dstack ._internal .core .consts import DSTACK_RUNNER_HTTP_PORT , DSTACK_SHIM_HTTP_PORT
1614from dstack ._internal .core .errors import GatewayError
1715from dstack ._internal .core .models .backends .base import BackendType
5250 RunModel ,
5351 UserModel ,
5452)
55- from dstack ._internal .server .schemas .runner import GPUDevice , TaskStatus
53+ from dstack ._internal .server .schemas .runner import TaskStatus
5654from dstack ._internal .server .services import events , services
5755from dstack ._internal .server .services import files as files_services
5856from dstack ._internal .server .services import logs as logs_services
57+ from dstack ._internal .server .services .backends .provisioning import (
58+ get_instance_specific_gpu_devices ,
59+ get_instance_specific_mounts ,
60+ resolve_provisioning_image_name ,
61+ )
5962from dstack ._internal .server .services .instances import (
6063 get_instance_remote_connection_info ,
6164 get_instance_ssh_private_keys ,
@@ -759,9 +762,9 @@ def _process_provisioning_with_shim(
759762 for volume , volume_mount in zip (volumes , volume_mounts ):
760763 volume_mount .name = volume .name
761764
762- instance_mounts += _get_instance_specific_mounts (jpd .backend , jpd .instance_type .name )
765+ instance_mounts += get_instance_specific_mounts (jpd .backend , jpd .instance_type .name )
763766
764- gpu_devices = _get_instance_specific_gpu_devices (jpd .backend , jpd .instance_type .name )
767+ gpu_devices = get_instance_specific_gpu_devices (jpd .backend , jpd .instance_type .name )
765768
766769 container_user = "root"
767770
@@ -778,7 +781,7 @@ def _process_provisioning_with_shim(
778781 cpu = None
779782 memory = None
780783 network_mode = NetworkMode .HOST
781- image_name = _patch_base_image_for_aws_efa (job_spec , jpd )
784+ image_name = resolve_provisioning_image_name (job_spec , jpd )
782785 if shim_client .is_api_v2_supported ():
783786 shim_client .submit_task (
784787 task_id = job_model .id ,
@@ -1301,105 +1304,3 @@ def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec):
13011304 username = interpolate (job_spec .registry_auth .username ),
13021305 password = interpolate (job_spec .registry_auth .password ),
13031306 )
1304-
1305-
def _get_instance_specific_mounts(
    backend_type: BackendType, instance_type_name: str
) -> List[InstanceMountPoint]:
    """Return extra host-path mounts required by specific cloud instance types.

    Only GCP A3 machine types currently need additional mounts: the TCPXO/
    fastrak library paths for ``a3-megagpu-8g`` and the NVIDIA + TCPX paths
    for ``a3-edgegpu-8g`` / ``a3-highgpu-8g``. Every other backend/instance
    combination gets an empty list.
    """
    if backend_type != BackendType.GCP:
        return []
    # (host path, container path) pairs, in the order they must be mounted.
    if instance_type_name == "a3-megagpu-8g":
        mount_pairs = [
            ("/dev/aperture_devices", "/dev/aperture_devices"),
            ("/var/lib/tcpxo/lib64", "/var/lib/tcpxo/lib64"),
            ("/var/lib/fastrak/lib64", "/var/lib/fastrak/lib64"),
        ]
    elif instance_type_name in ("a3-edgegpu-8g", "a3-highgpu-8g"):
        mount_pairs = [
            ("/var/lib/nvidia/lib64", "/usr/local/nvidia/lib64"),
            ("/var/lib/nvidia/bin", "/usr/local/nvidia/bin"),
            ("/var/lib/tcpx/lib64", "/usr/local/tcpx/lib64"),
            ("/run/tcpx", "/run/tcpx"),
        ]
    else:
        return []
    return [
        InstanceMountPoint(instance_path=host_path, path=container_path)
        for host_path, container_path in mount_pairs
    ]
1345-
1346-
def _get_instance_specific_gpu_devices(
    backend_type: BackendType, instance_type_name: str
) -> List[GPUDevice]:
    """Return GPU device nodes to expose in the container for specific instance types.

    Only GCP ``a3-edgegpu-8g`` / ``a3-highgpu-8g`` need explicit device
    passthrough: the eight ``/dev/nvidia0..7`` nodes plus the UVM and control
    devices, each mapped to the same path inside the container. All other
    combinations yield an empty list.
    """
    a3_types = ("a3-edgegpu-8g", "a3-highgpu-8g")
    if backend_type != BackendType.GCP or instance_type_name not in a3_types:
        return []
    device_paths = [f"/dev/nvidia{i}" for i in range(8)]
    device_paths.append("/dev/nvidia-uvm")
    device_paths.append("/dev/nvidiactl")
    return [
        GPUDevice(path_on_host=path, path_in_container=path)
        for path in device_paths
    ]
1366-
1367-
def _patch_base_image_for_aws_efa(
    job_spec: JobSpec, job_provisioning_data: JobProvisioningData
) -> str:
    """Return the image name to use, switching dstack base images to the
    EFA-enabled variant on AWS instance types that support EFA.

    The swap happens only when all of the following hold:
      * the backend is AWS,
      * the instance type matches a known EFA-capable size, and
      * the job uses an official dstack base image
        (``{settings.DSTACK_BASE_IMAGE}:...`` with a ``-base-ubuntu*`` or
        ``-devel-ubuntu*`` tag suffix).
    In every other case the job's original image name is returned unchanged.
    """
    image_name = job_spec.image_name

    if job_provisioning_data.backend != BackendType.AWS:
        return image_name

    instance_type = job_provisioning_data.instance_type.name
    efa_enabled_patterns = [
        # TODO: p6-b200 isn't supported yet in gpuhunt
        r"^p6-b200\.(48xlarge)$",
        r"^p5\.(4xlarge|48xlarge)$",
        r"^p5e\.(48xlarge)$",
        r"^p5en\.(48xlarge)$",
        r"^p4d\.(24xlarge)$",
        r"^p4de\.(24xlarge)$",
        r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
        r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
        r"^gr6\.8xlarge$",
        r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
        r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$",
        r"^p3dn\.(24xlarge)$",
    ]

    is_efa_enabled = any(re.match(pattern, instance_type) for pattern in efa_enabled_patterns)
    if not is_efa_enabled:
        return image_name

    if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"):
        return image_name

    ubuntu_version = settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION
    efa_suffix = f"-devel-efa-ubuntu{ubuntu_version}"
    # Strip by computed suffix length instead of the previous hard-coded
    # [:-17] / [:-18] offsets, which silently corrupt the tag if the Ubuntu
    # version string is ever not exactly 5 characters long.
    for suffix in (f"-base-ubuntu{ubuntu_version}", f"-devel-ubuntu{ubuntu_version}"):
        if image_name.endswith(suffix):
            return image_name[: -len(suffix)] + efa_suffix

    return image_name
0 commit comments