diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 965d04af18..638a778545 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -111,8 +111,8 @@ def __init__(self, config: GCPConfig): self.resource_policies_client = compute_v1.ResourcePoliciesClient( credentials=self.credentials ) - self._extra_subnets_cache_lock = threading.Lock() - self._extra_subnets_cache = TTLCache(maxsize=30, ttl=60) + self._usable_subnets_cache_lock = threading.Lock() + self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120) def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: regions = get_or_error(self.config.regions) @@ -203,12 +203,12 @@ def create_instance( disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) # Choose any usable subnet in a VPC. # Configuring a specific subnet per region is not supported yet. - subnetwork = _get_vpc_subnet( - subnetworks_client=self.subnetworks_client, - config=self.config, + subnetwork = self._get_vpc_subnet(instance_offer.region) + extra_subnets = self._get_extra_subnets( region=instance_offer.region, + instance_type_name=instance_offer.instance.name, ) - extra_subnets = self._get_extra_subnets( + roce_subnets = self._get_roce_subnets( region=instance_offer.region, instance_type_name=instance_offer.instance.name, ) @@ -330,6 +330,7 @@ def create_instance( network=self.config.vpc_resource_name, subnetwork=subnetwork, extra_subnetworks=extra_subnets, + roce_subnetworks=roce_subnets, allocate_public_ip=allocate_public_ip, placement_policy=placement_policy, ) @@ -339,6 +340,13 @@ def create_instance( # If the request succeeds, we'll probably timeout and update_provisioning_data() will get hostname. operation = self.instances_client.insert(request=request) gcp_resources.wait_for_extended_operation(operation, timeout=30) + except google.api_core.exceptions.BadRequest as e: + if "Network profile only allows resource creation in location" in e.message: + # A hack to find the correct RoCE VPC zone by trial and error. + # Could be better to find it via the API. + logger.debug("Got GCP error when provisioning a VM: %s", e) + continue + raise except ( google.api_core.exceptions.ServiceUnavailable, google.api_core.exceptions.NotFound, @@ -487,11 +495,7 @@ def create_gateway( ) # Choose any usable subnet in a VPC. # Configuring a specific subnet per region is not supported yet. - subnetwork = _get_vpc_subnet( - subnetworks_client=self.subnetworks_client, - config=self.config, - region=configuration.region, - ) + subnetwork = self._get_vpc_subnet(configuration.region) labels = { "owner": "dstack", @@ -793,10 +797,6 @@ def detach_volume( instance_id, ) - @cachedmethod( - cache=lambda self: self._extra_subnets_cache, - lock=lambda self: self._extra_subnets_cache_lock, - ) def _get_extra_subnets( self, region: str, @@ -808,15 +808,16 @@ def _get_extra_subnets( subnets_num = 8 elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: subnets_num = 4 + elif instance_type_name == "a4-highgpu-8g": + subnets_num = 1 # 1 main + 1 extra + 8 RoCE else: return [] extra_subnets = [] for vpc_name in self.config.extra_vpcs[:subnets_num]: subnet = gcp_resources.get_vpc_subnet_or_error( - subnetworks_client=self.subnetworks_client, - vpc_project_id=self.config.vpc_project_id or self.config.project_id, vpc_name=vpc_name, region=region, + usable_subnets=self._list_usable_subnets(), ) vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name( project_id=self.config.vpc_project_id or self.config.project_id, @@ -825,6 +826,58 @@ def _get_extra_subnets( extra_subnets.append((vpc_resource_name, subnet)) return extra_subnets + def _get_roce_subnets( + self, + region: str, + instance_type_name: str, + ) -> List[Tuple[str, str]]: + if not self.config.roce_vpcs: + return [] + if instance_type_name == "a4-highgpu-8g": + nics_num = 8 + else: + return [] + roce_vpc = self.config.roce_vpcs[0] # roce_vpcs is validated to have at most 1 item + subnets = gcp_resources.get_vpc_subnets( + vpc_name=roce_vpc, + region=region, + usable_subnets=self._list_usable_subnets(), + ) + if len(subnets) < nics_num: + raise ComputeError( + f"{instance_type_name} requires {nics_num} RoCE subnets," + f" but only {len(subnets)} are available in VPC {roce_vpc}" + ) + vpc_resource_name = gcp_resources.vpc_name_to_vpc_resource_name( + project_id=self.config.vpc_project_id or self.config.project_id, + vpc_name=roce_vpc, + ) + nic_subnets = [] + for subnet in subnets[:nics_num]: + nic_subnets.append((vpc_resource_name, subnet)) + return nic_subnets + + @cachedmethod( + cache=lambda self: self._usable_subnets_cache, + lock=lambda self: self._usable_subnets_cache_lock, + ) + def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]: + # To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets + # at once and cache them + return gcp_resources.list_project_usable_subnets( + subnetworks_client=self.subnetworks_client, + project_id=self.config.vpc_project_id or self.config.project_id, + ) + + def _get_vpc_subnet(self, region: str) -> Optional[str]: + if self.config.vpc_name is None: + return None + return gcp_resources.get_vpc_subnet_or_error( + vpc_name=self.config.vpc_name, + region=region, + usable_subnets=self._list_usable_subnets(), + ) + def _supported_instances_and_zones( regions: List[str], @@ -889,21 +942,6 @@ def _unique_instance_name(instance: InstanceType) -> str: return f"{name}-{gpu.name}-{gpu.memory_mib}" -def _get_vpc_subnet( - subnetworks_client: compute_v1.SubnetworksClient, - config: GCPConfig, - region: str, -) -> Optional[str]: - if config.vpc_name is None: - return None - return gcp_resources.get_vpc_subnet_or_error( - subnetworks_client=subnetworks_client, - vpc_project_id=config.vpc_project_id or config.project_id, - vpc_name=config.vpc_name, - region=region, - ) - - @dataclass class GCPImage: id: str diff --git a/src/dstack/_internal/core/backends/gcp/configurator.py b/src/dstack/_internal/core/backends/gcp/configurator.py index f59decd122..c40aa2b7d3 100644 --- a/src/dstack/_internal/core/backends/gcp/configurator.py +++ b/src/dstack/_internal/core/backends/gcp/configurator.py @@ -202,5 +202,5 @@ def _check_config_vpc( ) except BackendError as e: raise ServerClientError(e.args[0]) - # Not checking config.extra_vpc so that users are not required to configure subnets for all regions + # Not checking config.extra_vpcs and config.roce_vpcs so that users are not required to configure subnets for all regions # but only for regions they intend to use. Validation will be done on provisioning. diff --git a/src/dstack/_internal/core/backends/gcp/models.py b/src/dstack/_internal/core/backends/gcp/models.py index 15807f9f8f..00c9492be3 100644 --- a/src/dstack/_internal/core/backends/gcp/models.py +++ b/src/dstack/_internal/core/backends/gcp/models.py @@ -41,11 +41,24 @@ class GCPBackendConfig(CoreModel): Optional[List[str]], Field( description=( - "The names of additional VPCs used for GPUDirect. Specify eight VPCs to maximize bandwidth." + "The names of additional VPCs used for multi-NIC instances, such as those that support GPUDirect." + " Specify eight VPCs to maximize bandwidth in clusters with eight-GPU instances." " Each VPC must have a subnet and a firewall rule allowing internal traffic across all subnets" ) ), ] = None + roce_vpcs: Annotated[ + Optional[List[str]], + Field( + description=( + "The names of additional VPCs with the RoCE network profile." + " Used for RDMA on GPU instances that support the MRDMA interface type." + " A VPC should have eight subnets to maximize the bandwidth in clusters" + " with eight-GPU instances." + ), + max_items=1, # The currently supported instance types only need one VPC with eight subnets. + ), + ] = None vpc_project_id: Annotated[ Optional[str], Field(description="The shared VPC hosted project ID. Required for shared VPC only"), diff --git a/src/dstack/_internal/core/backends/gcp/resources.py b/src/dstack/_internal/core/backends/gcp/resources.py index b22012f63f..57f2f8548c 100644 --- a/src/dstack/_internal/core/backends/gcp/resources.py +++ b/src/dstack/_internal/core/backends/gcp/resources.py @@ -59,8 +59,6 @@ def check_vpc( ) for region in regions: get_vpc_subnet_or_error( - subnetworks_client=subnetworks_client, - vpc_project_id=vpc_project_id, vpc_name=vpc_name, region=region, usable_subnets=usable_subnets, @@ -122,6 +120,7 @@ def create_instance_struct( network: str = "global/networks/default", subnetwork: Optional[str] = None, extra_subnetworks: Optional[List[Tuple[str, str]]] = None, + roce_subnetworks: Optional[List[Tuple[str, str]]] = None, allocate_public_ip: bool = True, placement_policy: Optional[str] = None, ) -> compute_v1.Instance: @@ -133,6 +132,7 @@ def create_instance_struct( subnetwork=subnetwork, allocate_public_ip=allocate_public_ip, extra_subnetworks=extra_subnetworks, + roce_subnetworks=roce_subnetworks, ) disk = compute_v1.AttachedDisk() @@ -195,6 +195,7 @@ def _get_network_interfaces( subnetwork: Optional[str], allocate_public_ip: bool, extra_subnetworks: Optional[List[Tuple[str, str]]], + roce_subnetworks: Optional[List[Tuple[str, str]]], ) -> List[compute_v1.NetworkInterface]: network_interface = compute_v1.NetworkInterface() network_interface.network = network @@ -222,6 +223,14 @@ def _get_network_interfaces( nic_type=compute_v1.NetworkInterface.NicType.GVNIC.name, ) ) + for network, subnetwork in roce_subnetworks or []: + network_interfaces.append( + compute_v1.NetworkInterface( + network=network, + subnetwork=subnetwork, + nic_type=compute_v1.NetworkInterface.NicType.MRDMA.name, + ) + ) return network_interfaces @@ -234,29 +243,41 @@ def list_project_usable_subnets( def get_vpc_subnet_or_error( - subnetworks_client: compute_v1.SubnetworksClient, - vpc_project_id: str, vpc_name: str, region: str, - usable_subnets: Optional[List[compute_v1.UsableSubnetwork]] = None, + usable_subnets: list[compute_v1.UsableSubnetwork], ) -> str: """ Returns resource name of any usable subnet in a given VPC (e.g. "projects/example-project/regions/europe-west4/subnetworks/example-subnet") """ - if usable_subnets is None: - usable_subnets = list_project_usable_subnets(subnetworks_client, vpc_project_id) + vpc_subnets = get_vpc_subnets(vpc_name, region, usable_subnets) + if vpc_subnets: + return vpc_subnets[0] + raise ComputeError( + f"No usable subnetwork found in region {region} for VPC {vpc_name}." + f" Ensure that VPC {vpc_name} exists and has usable subnetworks." + ) + + +def get_vpc_subnets( + vpc_name: str, + region: str, + usable_subnets: list[compute_v1.UsableSubnetwork], +) -> list[str]: + """ + Returns resource names of all usable subnets in a given VPC + (e.g. ["projects/example-project/regions/europe-west4/subnetworks/example-subnet"]) + """ + result = [] for subnet in usable_subnets: network_name = subnet.network.split("/")[-1] subnet_url = subnet.subnetwork subnet_resource_name = remove_prefix(subnet_url, "https://www.googleapis.com/compute/v1/") subnet_region = subnet_resource_name.split("/")[3] if network_name == vpc_name and subnet_region == region: - return subnet_resource_name - raise ComputeError( - f"No usable subnetwork found in region {region} for VPC {vpc_name} in project {vpc_project_id}." - f" Ensure that VPC {vpc_name} exists and has usable subnetworks." - ) + result.append(subnet_resource_name) + return result def create_runner_firewall_rules(