Skip to content

Commit c051458

Browse files
committed
Extract backend provisioning helpers
1 parent 9b0a11e commit c051458

5 files changed

Lines changed: 302 additions & 241 deletions

File tree

src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py

Lines changed: 9 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import enum
3-
import re
43
import uuid
54
from collections.abc import Iterable
65
from dataclasses import dataclass
@@ -11,7 +10,6 @@
1110
from sqlalchemy.ext.asyncio import AsyncSession
1211
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only
1312

14-
from dstack._internal import settings
1513
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
1614
from dstack._internal.core.errors import GatewayError
1715
from dstack._internal.core.models.backends.base import BackendType
@@ -52,10 +50,15 @@
5250
RunModel,
5351
UserModel,
5452
)
55-
from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus
53+
from dstack._internal.server.schemas.runner import TaskStatus
5654
from dstack._internal.server.services import events, services
5755
from dstack._internal.server.services import files as files_services
5856
from dstack._internal.server.services import logs as logs_services
57+
from dstack._internal.server.services.backends.provisioning import (
58+
get_instance_specific_gpu_devices,
59+
get_instance_specific_mounts,
60+
resolve_provisioning_image_name,
61+
)
5962
from dstack._internal.server.services.instances import (
6063
get_instance_remote_connection_info,
6164
get_instance_ssh_private_keys,
@@ -759,9 +762,9 @@ def _process_provisioning_with_shim(
759762
for volume, volume_mount in zip(volumes, volume_mounts):
760763
volume_mount.name = volume.name
761764

762-
instance_mounts += _get_instance_specific_mounts(jpd.backend, jpd.instance_type.name)
765+
instance_mounts += get_instance_specific_mounts(jpd.backend, jpd.instance_type.name)
763766

764-
gpu_devices = _get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name)
767+
gpu_devices = get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name)
765768

766769
container_user = "root"
767770

@@ -778,7 +781,7 @@ def _process_provisioning_with_shim(
778781
cpu = None
779782
memory = None
780783
network_mode = NetworkMode.HOST
781-
image_name = _patch_base_image_for_aws_efa(job_spec, jpd)
784+
image_name = resolve_provisioning_image_name(job_spec, jpd)
782785
if shim_client.is_api_v2_supported():
783786
shim_client.submit_task(
784787
task_id=job_model.id,
@@ -1301,105 +1304,3 @@ def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec):
13011304
username=interpolate(job_spec.registry_auth.username),
13021305
password=interpolate(job_spec.registry_auth.password),
13031306
)
1304-
1305-
1306-
def _get_instance_specific_mounts(
    backend_type: BackendType, instance_type_name: str
) -> List[InstanceMountPoint]:
    """Return the extra instance mounts hard-coded for certain GCP instance types.

    Args:
        backend_type: Backend of the instance; only ``BackendType.GCP`` yields mounts.
        instance_type_name: Instance type name, compared against exact literals.

    Returns:
        Three mounts for ``a3-megagpu-8g``, four mounts for ``a3-edgegpu-8g`` and
        ``a3-highgpu-8g``, and an empty list for every other combination.
    """
    if backend_type != BackendType.GCP:
        return []

    # Each entry is an (instance_path, path) pair.
    if instance_type_name == "a3-megagpu-8g":
        path_pairs = [
            ("/dev/aperture_devices", "/dev/aperture_devices"),
            ("/var/lib/tcpxo/lib64", "/var/lib/tcpxo/lib64"),
            ("/var/lib/fastrak/lib64", "/var/lib/fastrak/lib64"),
        ]
    elif instance_type_name in ("a3-edgegpu-8g", "a3-highgpu-8g"):
        path_pairs = [
            ("/var/lib/nvidia/lib64", "/usr/local/nvidia/lib64"),
            ("/var/lib/nvidia/bin", "/usr/local/nvidia/bin"),
            ("/var/lib/tcpx/lib64", "/usr/local/tcpx/lib64"),
            ("/run/tcpx", "/run/tcpx"),
        ]
    else:
        return []

    return [
        InstanceMountPoint(instance_path=source_path, path=target_path)
        for source_path, target_path in path_pairs
    ]
1345-
1346-
1347-
def _get_instance_specific_gpu_devices(
    backend_type: BackendType, instance_type_name: str
) -> List[GPUDevice]:
    """Return the GPU device nodes listed explicitly for certain GCP instance types.

    Args:
        backend_type: Backend of the instance; only ``BackendType.GCP`` yields devices.
        instance_type_name: Instance type name, compared against exact literals.

    Returns:
        For ``a3-edgegpu-8g`` and ``a3-highgpu-8g`` on GCP: ``/dev/nvidia0`` through
        ``/dev/nvidia7``, then ``/dev/nvidia-uvm`` and ``/dev/nvidiactl``, each with
        the same path on the host and in the container. Otherwise an empty list.
    """
    is_listed_type = backend_type == BackendType.GCP and instance_type_name in (
        "a3-edgegpu-8g",
        "a3-highgpu-8g",
    )
    if not is_listed_type:
        return []

    device_paths = [f"/dev/nvidia{index}" for index in range(8)]
    device_paths += ["/dev/nvidia-uvm", "/dev/nvidiactl"]
    return [
        GPUDevice(path_on_host=device_path, path_in_container=device_path)
        for device_path in device_paths
    ]
1366-
1367-
1368-
def _patch_base_image_for_aws_efa(
    job_spec: JobSpec, job_provisioning_data: JobProvisioningData
) -> str:
    """Swap the dstack base image for its EFA variant on EFA-enabled AWS instances.

    The image name is returned unchanged unless all of the following hold:
    the backend is AWS, the instance type matches one of the EFA patterns below,
    the image name starts with ``settings.DSTACK_BASE_IMAGE`` followed by ``:``,
    and it ends with the ``-base-ubuntu<version>`` or ``-devel-ubuntu<version>``
    suffix built from ``settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION``.

    Args:
        job_spec: Job spec whose ``image_name`` is the candidate image.
        job_provisioning_data: Provides the backend and the instance type name.

    Returns:
        The image name with its suffix replaced by ``-devel-efa-ubuntu<version>``,
        or the original image name when any condition above is not met.
    """
    image_name = job_spec.image_name

    if job_provisioning_data.backend != BackendType.AWS:
        return image_name

    instance_type = job_provisioning_data.instance_type.name
    efa_enabled_patterns = [
        # TODO: p6-b200 isn't supported yet in gpuhunt
        r"^p6-b200\.(48xlarge)$",
        r"^p5\.(4xlarge|48xlarge)$",
        r"^p5e\.(48xlarge)$",
        r"^p5en\.(48xlarge)$",
        r"^p4d\.(24xlarge)$",
        r"^p4de\.(24xlarge)$",
        r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
        r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
        r"^gr6\.8xlarge$",
        r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
        r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$",
        r"^p3dn\.(24xlarge)$",
    ]

    is_efa_enabled = any(re.match(pattern, instance_type) for pattern in efa_enabled_patterns)
    if not is_efa_enabled:
        return image_name

    if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"):
        return image_name

    ubuntu_tag = f"ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
    efa_suffix = f"-devel-efa-{ubuntu_tag}"
    for replaceable_suffix in (f"-base-{ubuntu_tag}", f"-devel-{ubuntu_tag}"):
        if image_name.endswith(replaceable_suffix):
            # Strip exactly the suffix that matched instead of a hard-coded
            # 17/18 characters, which was only right for a 5-character version.
            return image_name[: -len(replaceable_suffix)] + efa_suffix

    return image_name
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import re
2+
3+
from dstack._internal import settings
4+
from dstack._internal.core.models.backends.base import BackendType
5+
from dstack._internal.core.models.runs import JobProvisioningData, JobSpec
6+
from dstack._internal.core.models.volumes import InstanceMountPoint
7+
from dstack._internal.server.schemas.runner import GPUDevice
8+
9+
# Regexes tried with re.match against an AWS instance type name; a match on any
# of them marks the instance type as EFA-enabled in _patch_base_image_for_aws_efa.
_AWS_EFA_ENABLED_INSTANCE_TYPE_PATTERNS = [
10+
# TODO: p6-b200 isn't supported yet in gpuhunt
11+
r"^p6-b200\.(48xlarge)$",
12+
r"^p5\.(4xlarge|48xlarge)$",
13+
r"^p5e\.(48xlarge)$",
14+
r"^p5en\.(48xlarge)$",
15+
r"^p4d\.(24xlarge)$",
16+
r"^p4de\.(24xlarge)$",
17+
r"^g6\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
18+
r"^g6e\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
19+
r"^gr6\.8xlarge$",
20+
r"^g5\.(8xlarge|12xlarge|16xlarge|24xlarge|48xlarge)$",
21+
r"^g4dn\.(8xlarge|12xlarge|16xlarge|metal)$",
22+
r"^p3dn\.(24xlarge)$",
23+
]
24+
25+
26+
def get_instance_specific_mounts(
    backend_type: BackendType,
    instance_type_name: str,
) -> list[InstanceMountPoint]:
    """Return the extra instance mounts hard-coded for certain GCP instance types.

    Args:
        backend_type: Backend of the instance; only ``BackendType.GCP`` yields mounts.
        instance_type_name: Instance type name, compared against exact literals.

    Returns:
        Three mounts for ``a3-megagpu-8g``, four mounts for ``a3-edgegpu-8g`` and
        ``a3-highgpu-8g``, and an empty list for every other combination.
    """
    if backend_type != BackendType.GCP:
        return []

    # Each entry is an (instance_path, path) pair.
    if instance_type_name == "a3-megagpu-8g":
        path_pairs = [
            ("/dev/aperture_devices", "/dev/aperture_devices"),
            ("/var/lib/tcpxo/lib64", "/var/lib/tcpxo/lib64"),
            ("/var/lib/fastrak/lib64", "/var/lib/fastrak/lib64"),
        ]
    elif instance_type_name in ("a3-edgegpu-8g", "a3-highgpu-8g"):
        path_pairs = [
            ("/var/lib/nvidia/lib64", "/usr/local/nvidia/lib64"),
            ("/var/lib/nvidia/bin", "/usr/local/nvidia/bin"),
            ("/var/lib/tcpx/lib64", "/usr/local/tcpx/lib64"),
            ("/run/tcpx", "/run/tcpx"),
        ]
    else:
        return []

    return [
        InstanceMountPoint(instance_path=source_path, path=target_path)
        for source_path, target_path in path_pairs
    ]
66+
67+
68+
def get_instance_specific_gpu_devices(
    backend_type: BackendType,
    instance_type_name: str,
) -> list[GPUDevice]:
    """Return the GPU device nodes listed explicitly for certain GCP instance types.

    Args:
        backend_type: Backend of the instance; only ``BackendType.GCP`` yields devices.
        instance_type_name: Instance type name, compared against exact literals.

    Returns:
        For ``a3-edgegpu-8g`` and ``a3-highgpu-8g`` on GCP: ``/dev/nvidia0`` through
        ``/dev/nvidia7``, then ``/dev/nvidia-uvm`` and ``/dev/nvidiactl``, each with
        the same path on the host and in the container. Otherwise an empty list.
    """
    is_listed_type = backend_type == BackendType.GCP and instance_type_name in (
        "a3-edgegpu-8g",
        "a3-highgpu-8g",
    )
    if not is_listed_type:
        return []

    device_paths = [f"/dev/nvidia{index}" for index in range(8)]
    device_paths += ["/dev/nvidia-uvm", "/dev/nvidiactl"]
    return [
        GPUDevice(path_on_host=device_path, path_in_container=device_path)
        for device_path in device_paths
    ]
88+
89+
90+
def resolve_provisioning_image_name(
    job_spec: JobSpec,
    job_provisioning_data: JobProvisioningData,
) -> str:
    """Pick the image name to provision with, applying backend-specific patching.

    Args:
        job_spec: Job spec whose ``image_name`` is the starting point.
        job_provisioning_data: Provides the backend and the instance type name.

    Returns:
        For the AWS backend, the result of ``_patch_base_image_for_aws_efa`` for
        the job's image and instance type; for any other backend, the job's
        image name unchanged.
    """
    requested_image = job_spec.image_name
    if job_provisioning_data.backend != BackendType.AWS:
        return requested_image

    instance_type_name = job_provisioning_data.instance_type.name
    return _patch_base_image_for_aws_efa(requested_image, instance_type_name)
101+
102+
103+
def _patch_base_image_for_aws_efa(
    image_name: str,
    instance_type_name: str,
) -> str:
    """Swap the dstack base image for its EFA variant on EFA-enabled instance types.

    The image name is returned unchanged unless all of the following hold:
    ``instance_type_name`` matches one of ``_AWS_EFA_ENABLED_INSTANCE_TYPE_PATTERNS``,
    the image name starts with ``settings.DSTACK_BASE_IMAGE`` followed by ``:``,
    and it ends with the ``-base-ubuntu<version>`` or ``-devel-ubuntu<version>``
    suffix built from ``settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION``.

    Args:
        image_name: Image name requested for the job.
        instance_type_name: Instance type name to test against the EFA patterns.

    Returns:
        The image name with its suffix replaced by ``-devel-efa-ubuntu<version>``,
        or ``image_name`` as-is when any condition above is not met.
    """
    is_efa_enabled = any(
        re.match(pattern, instance_type_name)
        for pattern in _AWS_EFA_ENABLED_INSTANCE_TYPE_PATTERNS
    )
    if not is_efa_enabled:
        return image_name

    if not image_name.startswith(f"{settings.DSTACK_BASE_IMAGE}:"):
        return image_name

    ubuntu_tag = f"ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}"
    efa_suffix = f"-devel-efa-{ubuntu_tag}"
    for replaceable_suffix in (f"-base-{ubuntu_tag}", f"-devel-{ubuntu_tag}"):
        if image_name.endswith(replaceable_suffix):
            # Strip exactly the suffix that matched instead of a hard-coded
            # 17/18 characters, which was only right for a 5-character version.
            return image_name[: -len(replaceable_suffix)] + efa_suffix

    return image_name

src/dstack/_internal/server/services/jobs/configurators/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def get_default_python_verison() -> str:
7777
def get_default_image(nvcc: bool = False) -> str:
7878
"""
7979
Note: May be overridden by dstack (e.g., EFA-enabled version for AWS EFA-capable instances).
80-
See `dstack._internal.server.background.scheduled_tasks.running_jobs._patch_base_image_for_aws_efa` for details.
80+
See `dstack._internal.server.services.backends.provisioning.resolve_provisioning_image_name`
81+
for details.
8182
8283
Args:
8384
nvcc: If True, returns 'devel' variant, otherwise 'base'.

0 commit comments

Comments
 (0)