Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 70 additions & 32 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def __init__(self, config: GCPConfig):
self.resource_policies_client = compute_v1.ResourcePoliciesClient(
credentials=self.credentials
)
self._extra_subnets_cache_lock = threading.Lock()
self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60)
self._usable_subnets_cache_lock = threading.Lock()
self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120)

def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
regions = get_or_error(self.config.regions)
Expand Down Expand Up @@ -203,12 +203,12 @@ def create_instance(
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
# Choose any usable subnet in a VPC.
# Configuring a specific subnet per region is not supported yet.
subnetwork = _get_vpc_subnet(
subnetworks_client=self.subnetworks_client,
config=self.config,
subnetwork = self._get_vpc_subnet(instance_offer.region)
extra_subnets = self._get_extra_subnets(
region=instance_offer.region,
instance_type_name=instance_offer.instance.name,
)
extra_subnets = self._get_extra_subnets(
roce_subnets = self._get_roce_subnets(
region=instance_offer.region,
instance_type_name=instance_offer.instance.name,
)
Expand Down Expand Up @@ -330,6 +330,7 @@ def create_instance(
network=self.config.vpc_resource_name,
subnetwork=subnetwork,
extra_subnetworks=extra_subnets,
roce_subnetworks=roce_subnets,
allocate_public_ip=allocate_public_ip,
placement_policy=placement_policy,
)
Expand All @@ -339,6 +340,13 @@ def create_instance(
# If the request succeeds, we'll probably timeout and update_provisioning_data() will get hostname.
operation = self.instances_client.insert(request=request)
gcp_resources.wait_for_extended_operation(operation, timeout=30)
except google.api_core.exceptions.BadRequest as e:
if "Network profile only allows resource creation in location" in e.message:
# A hack to find the correct RoCE VPC zone by trial and error.
# Could be better to find it via the API.
logger.debug("Got GCP error when provisioning a VM: %s", e)
continue
raise
except (
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.NotFound,
Expand Down Expand Up @@ -487,11 +495,7 @@ def create_gateway(
)
# Choose any usable subnet in a VPC.
# Configuring a specific subnet per region is not supported yet.
subnetwork = _get_vpc_subnet(
subnetworks_client=self.subnetworks_client,
config=self.config,
region=configuration.region,
)
subnetwork = self._get_vpc_subnet(configuration.region)

labels = {
"owner": "dstack",
Expand Down Expand Up @@ -793,10 +797,6 @@ def detach_volume(
instance_id,
)

@cachedmethod(
cache=lambda self: self._extra_subnets_cache,
lock=lambda self: self._extra_subnets_cache_lock,
)
def _get_extra_subnets(
self,
region: str,
Expand All @@ -808,15 +808,16 @@ def _get_extra_subnets(
subnets_num = 8
elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
subnets_num = 4
elif instance_type_name == "a4-highgpu-8g":
subnets_num = 1 # 1 main + 1 extra + 8 RoCE
else:
return []
extra_subnets = []
for vpc_name in self.config.extra_vpcs[:subnets_num]:
subnet = gcp_resources.get_vpc_subnet_or_error(
subnetworks_client=self.subnetworks_client,
vpc_project_id=self.config.vpc_project_id or self.config.project_id,
vpc_name=vpc_name,
region=region,
usable_subnets=self._list_usable_subnets(),
)
vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
project_id=self.config.vpc_project_id or self.config.project_id,
Expand All @@ -825,6 +826,58 @@ def _get_extra_subnets(
extra_subnets.append((vpc_resource_name, subnet))
return extra_subnets

def _get_roce_subnets(
self,
region: str,
instance_type_name: str,
) -> List[Tuple[str, str]]:
if not self.config.roce_vpcs:
return []
if instance_type_name == "a4-highgpu-8g":
nics_num = 8
else:
return []
roce_vpc = self.config.roce_vpcs[0] # roce_vpcs is validated to have at most 1 item
subnets = gcp_resources.get_vpc_subnets(
vpc_name=roce_vpc,
region=region,
usable_subnets=self._list_usable_subnets(),
)
if len(subnets) < nics_num:
raise ComputeError(
f"{instance_type_name} requires {nics_num} RoCE subnets,"
f" but only {len(subnets)} are available in VPC {roce_vpc}"
)
vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name(
project_id=self.config.vpc_project_id or self.config.project_id,
vpc_name=roce_vpc,
)
nic_subnets = []
for subnet in subnets[:nics_num]:
nic_subnets.append((vpc_resource_name, subnet))
return nic_subnets

@cachedmethod(
cache=lambda self: self._usable_subnets_cache,
lock=lambda self: self._usable_subnets_cache_lock,
)
def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]:
# To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets
# at once and cache them
return gcp_resources.list_project_usable_subnets(
subnetworks_client=self.subnetworks_client,
project_id=self.config.vpc_project_id or self.config.project_id,
)

def _get_vpc_subnet(self, region: str) -> Optional[str]:
if self.config.vpc_name is None:
return None
return gcp_resources.get_vpc_subnet_or_error(
vpc_name=self.config.vpc_name,
region=region,
usable_subnets=self._list_usable_subnets(),
)


def _supported_instances_and_zones(
regions: List[str],
Expand Down Expand Up @@ -889,21 +942,6 @@ def _unique_instance_name(instance: InstanceType) -> str:
return f"{name}-{gpu.name}-{gpu.memory_mib}"


def _get_vpc_subnet(
subnetworks_client: compute_v1.SubnetworksClient,
config: GCPConfig,
region: str,
) -> Optional[str]:
if config.vpc_name is None:
return None
return gcp_resources.get_vpc_subnet_or_error(
subnetworks_client=subnetworks_client,
vpc_project_id=config.vpc_project_id or config.project_id,
vpc_name=config.vpc_name,
region=region,
)


@dataclass
class GCPImage:
id: str
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/backends/gcp/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,5 @@ def _check_config_vpc(
)
except BackendError as e:
raise ServerClientError(e.args[0])
# Not checking config.extra_vpc so that users are not required to configure subnets for all regions
# Not checking config.extra_vpcs and config.roce_vpcs so that users are not required to configure subnets for all regions
# but only for regions they intend to use. Validation will be done on provisioning.
15 changes: 14 additions & 1 deletion src/dstack/_internal/core/backends/gcp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,24 @@ class GCPBackendConfig(CoreModel):
Optional[List[str]],
Field(
description=(
"The names of additional VPCs used for GPUDirect. Specify eight VPCs to maximize bandwidth."
"The names of additional VPCs used for multi-NIC instances, such as those that support GPUDirect."
" Specify eight VPCs to maximize bandwidth in clusters with eight-GPU instances."
" Each VPC must have a subnet and a firewall rule allowing internal traffic across all subnets"
)
),
] = None
roce_vpcs: Annotated[
Optional[List[str]],
Field(
description=(
"The names of additional VPCs with the RoCE network profile."
" Used for RDMA on GPU instances that support the MRDMA interface type."
" A VPC should have eight subnets to maximize the bandwidth in clusters"
" with eight-GPU instances."
),
max_items=1, # The currently supported instance types only need one VPC with eight subnets.
),
] = None
vpc_project_id: Annotated[
Optional[str],
Field(description="The shared VPC hosted project ID. Required for shared VPC only"),
Expand Down
45 changes: 33 additions & 12 deletions src/dstack/_internal/core/backends/gcp/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def check_vpc(
)
for region in regions:
get_vpc_subnet_or_error(
subnetworks_client=subnetworks_client,
vpc_project_id=vpc_project_id,
vpc_name=vpc_name,
region=region,
usable_subnets=usable_subnets,
Expand Down Expand Up @@ -122,6 +120,7 @@ def create_instance_struct(
network: str = "global/networks/default",
subnetwork: Optional[str] = None,
extra_subnetworks: Optional[List[Tuple[str, str]]] = None,
roce_subnetworks: Optional[List[Tuple[str, str]]] = None,
allocate_public_ip: bool = True,
placement_policy: Optional[str] = None,
) -> compute_v1.Instance:
Expand All @@ -133,6 +132,7 @@ def create_instance_struct(
subnetwork=subnetwork,
allocate_public_ip=allocate_public_ip,
extra_subnetworks=extra_subnetworks,
roce_subnetworks=roce_subnetworks,
)

disk = compute_v1.AttachedDisk()
Expand Down Expand Up @@ -195,6 +195,7 @@ def _get_network_interfaces(
subnetwork: Optional[str],
allocate_public_ip: bool,
extra_subnetworks: Optional[List[Tuple[str, str]]],
roce_subnetworks: Optional[List[Tuple[str, str]]],
) -> List[compute_v1.NetworkInterface]:
network_interface = compute_v1.NetworkInterface()
network_interface.network = network
Expand Down Expand Up @@ -222,6 +223,14 @@ def _get_network_interfaces(
nic_type=compute_v1.NetworkInterface.NicType.GVNIC.name,
)
)
for network, subnetwork in roce_subnetworks or []:
network_interfaces.append(
compute_v1.NetworkInterface(
network=network,
subnetwork=subnetwork,
nic_type=compute_v1.NetworkInterface.NicType.MRDMA.name,
)
)
return network_interfaces


Expand All @@ -234,29 +243,41 @@ def list_project_usable_subnets(


def get_vpc_subnet_or_error(
subnetworks_client: compute_v1.SubnetworksClient,
vpc_project_id: str,
vpc_name: str,
region: str,
usable_subnets: Optional[List[compute_v1.UsableSubnetwork]] = None,
usable_subnets: list[compute_v1.UsableSubnetwork],
) -> str:
"""
Returns resource name of any usable subnet in a given VPC
(e.g. "projects/example-project/regions/europe-west4/subnetworks/example-subnet")
"""
if usable_subnets is None:
usable_subnets = list_project_usable_subnets(subnetworks_client, vpc_project_id)
vpc_subnets = get_vpc_subnets(vpc_name, region, usable_subnets)
if vpc_subnets:
return vpc_subnets[0]
raise ComputeError(
f"No usable subnetwork found in region {region} for VPC {vpc_name}."
f" Ensure that VPC {vpc_name} exists and has usable subnetworks."
)


def get_vpc_subnets(
vpc_name: str,
region: str,
usable_subnets: list[compute_v1.UsableSubnetwork],
) -> list[str]:
"""
Returns resource names of all usable subnets in a given VPC
(e.g. ["projects/example-project/regions/europe-west4/subnetworks/example-subnet"])
"""
result = []
for subnet in usable_subnets:
network_name = subnet.network.split("/")[-1]
subnet_url = subnet.subnetwork
subnet_resource_name = remove_prefix(subnet_url, "https://www.googleapis.com/compute/v1/")
subnet_region = subnet_resource_name.split("/")[3]
if network_name == vpc_name and subnet_region == region:
return subnet_resource_name
raise ComputeError(
f"No usable subnetwork found in region {region} for VPC {vpc_name} in project {vpc_project_id}."
f" Ensure that VPC {vpc_name} exists and has usable subnetworks."
)
result.append(subnet_resource_name)
return result


def create_runner_firewall_rules(
Expand Down