Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import google.cloud.compute_v1 as compute_v1
from cachetools import TTLCache, cachedmethod
from google.cloud import tpu_v2
from google.cloud.compute_v1.types.compute import Instance
from gpuhunt import KNOWN_TPUS

import dstack._internal.core.backends.gcp.auth as auth
Expand All @@ -19,6 +20,7 @@
ComputeWithGatewaySupport,
ComputeWithMultinodeSupport,
ComputeWithPlacementGroupSupport,
ComputeWithPrivateGatewaySupport,
ComputeWithVolumeSupport,
generate_unique_gateway_instance_name,
generate_unique_instance_name,
Expand Down Expand Up @@ -83,6 +85,7 @@ class GCPCompute(
ComputeWithMultinodeSupport,
ComputeWithPlacementGroupSupport,
ComputeWithGatewaySupport,
ComputeWithPrivateGatewaySupport,
ComputeWithVolumeSupport,
Compute,
):
Expand Down Expand Up @@ -395,11 +398,7 @@ def update_provisioning_data(
if instance.status in ["PROVISIONING", "STAGING"]:
return
if instance.status == "RUNNING":
if allocate_public_ip:
hostname = instance.network_interfaces[0].access_configs[0].nat_i_p
else:
hostname = instance.network_interfaces[0].network_i_p
provisioning_data.hostname = hostname
provisioning_data.hostname = _get_instance_ip(instance, allocate_public_ip)
provisioning_data.internal_ip = instance.network_interfaces[0].network_i_p
return
raise ProvisioningError(
Expand Down Expand Up @@ -512,6 +511,7 @@ def create_gateway(
service_account=self.config.vm_service_account,
network=self.config.vpc_resource_name,
subnetwork=subnetwork,
allocate_public_ip=configuration.public_ip,
)
operation = self.instances_client.insert(request=request)
gcp_resources.wait_for_extended_operation(operation, "instance creation")
Expand All @@ -522,7 +522,7 @@ def create_gateway(
instance_id=instance_name,
region=configuration.region, # used for instance termination
availability_zone=zone,
ip_address=instance.network_interfaces[0].access_configs[0].nat_i_p,
ip_address=_get_instance_ip(instance, configuration.public_ip),
backend_data=json.dumps({"zone": zone}),
)

Expand Down Expand Up @@ -1024,3 +1024,9 @@ def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
backend_data_dict = json.loads(provisioning_data.backend_data)
is_tpu = backend_data_dict.get("is_tpu", False)
return is_tpu


def _get_instance_ip(instance: Instance, public_ip: bool) -> str:
if public_ip:
return instance.network_interfaces[0].access_configs[0].nat_i_p
return instance.network_interfaces[0].network_i_p
Loading