Skip to content

Commit c10b1fe

Browse files
authored
Support specifying internal_ip for SSH fleet hosts (#2056)
* Support specifying internal_ip for SSH fleet hosts * Validate internal_ip * Handle client backward compatibility * Remove extra space
1 parent 3193792 commit c10b1fe

6 files changed

Lines changed: 105 additions & 19 deletions

File tree

src/dstack/_internal/core/models/fleets.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,29 @@ class SSHHostParams(CoreModel):
5050
identity_file: Annotated[
5151
Optional[str], Field(description="The private key to use for this host")
5252
] = None
53+
internal_ip: Annotated[
54+
Optional[str],
55+
Field(
56+
description=(
57+
"The internal IP of the host used for communication inside the cluster."
58+
" If not specified, `dstack` will use the IP address from `network` or from the first found internal network."
59+
)
60+
),
61+
] = None
5362
ssh_key: Optional[SSHKey] = None
5463

64+
@validator("internal_ip")
65+
def validate_internal_ip(cls, value):
66+
if value is None:
67+
return value
68+
try:
69+
internal_ip = ipaddress.ip_address(value)
70+
except ValueError as e:
71+
raise ValueError("Invalid IP address") from e
72+
if not internal_ip.is_private:
73+
raise ValueError("IP address is not private")
74+
return value
75+
5576

5677
class SSHParams(CoreModel):
5778
user: Annotated[Optional[str], Field(description="The user to log in with on all hosts")] = (
@@ -70,7 +91,13 @@ class SSHParams(CoreModel):
7091
]
7192
network: Annotated[
7293
Optional[str],
73-
Field(description="The network address for cluster setup in the format `<ip>/<netmask>`"),
94+
Field(
95+
description=(
96+
"The network address for cluster setup in the format `<ip>/<netmask>`."
97+
" `dstack` will use IP addresses from this network for communication between hosts."
98+
" If not specified, `dstack` will use IPs from the first found internal network."
99+
)
100+
),
74101
]
75102

76103
@validator("network")

src/dstack/_internal/server/background/tasks/process_instances.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
from dstack._internal.server.utils.common import run_async
9393
from dstack._internal.utils.common import get_current_datetime
9494
from dstack._internal.utils.logging import get_logger
95-
from dstack._internal.utils.network import get_ip_from_network
95+
from dstack._internal.utils.network import get_ip_from_network, is_ip_among_addresses
9696
from dstack._internal.utils.ssh import (
9797
pkey_from_str,
9898
)
@@ -290,16 +290,20 @@ async def _add_remote(instance: InstanceModel) -> None:
290290

291291
instance_type = host_info_to_instance_type(host_info)
292292
instance_network = None
293+
internal_ip = None
293294
try:
294295
default_jpd = JobProvisioningData.__response__.parse_raw(instance.job_provisioning_data)
295296
instance_network = default_jpd.instance_network
297+
internal_ip = default_jpd.internal_ip
296298
except ValidationError:
297299
pass
298300

299-
internal_ip = get_ip_from_network(
300-
network=instance_network,
301-
addresses=host_info.get("addresses", []),
302-
)
301+
host_network_addresses = host_info.get("addresses", [])
302+
if internal_ip is None:
303+
internal_ip = get_ip_from_network(
304+
network=instance_network,
305+
addresses=host_network_addresses,
306+
)
303307
if instance_network is not None and internal_ip is None:
304308
instance.status = InstanceStatus.TERMINATED
305309
instance.termination_reason = "Failed to locate internal IP address on the given network"
@@ -312,6 +316,21 @@ async def _add_remote(instance: InstanceModel) -> None:
312316
},
313317
)
314318
return
319+
if internal_ip is not None:
320+
if not is_ip_among_addresses(ip_address=internal_ip, addresses=host_network_addresses):
321+
instance.status = InstanceStatus.TERMINATED
322+
instance.termination_reason = (
323+
"Specified internal IP not found among instance interfaces"
324+
)
325+
logger.warning(
326+
"Failed to add instance %s: specified internal IP not found among instance interfaces",
327+
instance.name,
328+
extra={
329+
"instance_name": instance.name,
330+
"instance_status": InstanceStatus.TERMINATED.value,
331+
},
332+
)
333+
return
315334

316335
region = instance.region
317336
jpd = JobProvisioningData(

src/dstack/_internal/server/services/fleets.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,11 +402,13 @@ async def create_fleet_ssh_instance_model(
402402
ssh_user = ssh_params.user
403403
ssh_key = ssh_params.ssh_key
404404
port = ssh_params.port
405+
internal_ip = None
405406
else:
406407
hostname = host.hostname
407408
ssh_user = host.user or ssh_params.user
408409
ssh_key = host.ssh_key or ssh_params.ssh_key
409410
port = host.port or ssh_params.port
411+
internal_ip = host.internal_ip
410412

411413
if ssh_user is None or ssh_key is None:
412414
# This should not be reachable but checked by fleet spec validation
@@ -422,6 +424,7 @@ async def create_fleet_ssh_instance_model(
422424
ssh_user=ssh_user,
423425
ssh_keys=[ssh_key],
424426
env=env,
427+
internal_ip=internal_ip,
425428
instance_network=ssh_params.network,
426429
port=port or 22,
427430
)
@@ -678,6 +681,7 @@ def _validate_fleet_spec(spec: FleetSpec):
678681
for host in spec.configuration.ssh_config.hosts:
679682
if is_core_model_instance(host, SSHHostParams) and host.ssh_key is not None:
680683
_validate_ssh_key(host.ssh_key)
684+
_validate_internal_ips(spec.configuration.ssh_config)
681685

682686

683687
def _validate_all_ssh_params_specified(ssh_config: SSHParams):
@@ -706,6 +710,17 @@ def _validate_ssh_key(ssh_key: SSHKey):
706710
)
707711

708712

713+
def _validate_internal_ips(ssh_config: SSHParams):
714+
internal_ips_num = 0
715+
for host in ssh_config.hosts:
716+
if not isinstance(host, str) and host.internal_ip is not None:
717+
internal_ips_num += 1
718+
if internal_ips_num != 0 and internal_ips_num != len(ssh_config.hosts):
719+
raise ServerClientError("internal_ip must be specified for all hosts")
720+
if internal_ips_num > 0 and ssh_config.network is not None:
721+
raise ServerClientError("internal_ip is mutually exclusive with network")
722+
723+
709724
def _get_fleet_nodes_to_provision(spec: FleetSpec) -> int:
710725
if spec.configuration.nodes is None or spec.configuration.nodes.min is None:
711726
return 0

src/dstack/_internal/server/services/pools.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ async def create_ssh_instance_model(
656656
pool: PoolModel,
657657
instance_name: str,
658658
instance_num: int,
659+
internal_ip: Optional[str],
659660
instance_network: Optional[str],
660661
region: Optional[str],
661662
host: str,
@@ -676,7 +677,7 @@ async def create_ssh_instance_model(
676677
instance_id=instance_name,
677678
hostname=host,
678679
region=host_region,
679-
internal_ip=None,
680+
internal_ip=internal_ip,
680681
instance_network=instance_network,
681682
price=0,
682683
username=ssh_user,

src/dstack/_internal/utils/network.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import ipaddress
2-
from typing import Optional, Sequence
2+
from typing import List, Optional, Sequence
33

44

55
def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Optional[str]:
@@ -32,3 +32,19 @@ def get_ip_from_network(network: Optional[str], addresses: Sequence[str]) -> Opt
3232
# return any ipv4
3333
internal_ip = str(ip_addresses[0]) if ip_addresses else None
3434
return internal_ip
35+
36+
37+
def is_ip_among_addresses(ip_address: str, addresses: Sequence[str]) -> bool:
38+
ip_addresses = get_ips_from_addresses(addresses)
39+
return ip_address in ip_addresses
40+
41+
42+
def get_ips_from_addresses(addresses: Sequence[str]) -> List[str]:
43+
ip_addresses = []
44+
for address in addresses:
45+
try:
46+
interface = ipaddress.IPv4Interface(address)
47+
ip_addresses.append(interface.ip)
48+
except ipaddress.AddressValueError:
49+
continue
50+
return [str(ip) for ip in ip_addresses]

src/dstack/api/server/_fleets.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
from pydantic import parse_obj_as
44

@@ -29,11 +29,7 @@ def get_plan(
2929
spec: FleetSpec,
3030
) -> FleetPlan:
3131
body = GetFleetPlanRequest(spec=spec)
32-
body_json = body.json()
33-
if spec.configuration_path is None:
34-
# Handle old server versions that do not accept configuration_path
35-
# TODO: Can be removed in 0.19
36-
body_json = body.json(exclude={"spec": {"configuration_path"}})
32+
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
3733
resp = self._request(f"/api/project/{project_name}/fleets/get_plan", body=body_json)
3834
return parse_obj_as(FleetPlan.__response__, resp.json())
3935

@@ -43,11 +39,7 @@ def create(
4339
spec: FleetSpec,
4440
) -> Fleet:
4541
body = CreateFleetRequest(spec=spec)
46-
body_json = body.json()
47-
if spec.configuration_path is None:
48-
# Handle old server versions that do not accept configuration_path
49-
# TODO: Can be removed in 0.19
50-
body_json = body.json(exclude={"spec": {"configuration_path"}})
42+
body_json = body.json(exclude=_get_fleet_spec_excludes(spec))
5143
resp = self._request(f"/api/project/{project_name}/fleets/create", body=body_json)
5244
return parse_obj_as(Fleet.__response__, resp.json())
5345

@@ -58,3 +50,19 @@ def delete(self, project_name: str, names: List[str]) -> None:
5850
def delete_instances(self, project_name: str, name: str, instance_nums: List[int]) -> None:
5951
body = DeleteFleetInstancesRequest(name=name, instance_nums=instance_nums)
6052
self._request(f"/api/project/{project_name}/fleets/delete_instances", body=body.json())
53+
54+
55+
def _get_fleet_spec_excludes(fleet_spec: FleetSpec) -> Optional[dict]:
56+
exclude = {}
57+
# TODO: Can be removed in 0.19
58+
if fleet_spec.configuration_path is None:
59+
exclude["spec"] = {"configuration_path"}
60+
if fleet_spec.configuration.ssh_config is not None:
61+
if all(
62+
isinstance(h, str) or h.internal_ip is None
63+
for h in fleet_spec.configuration.ssh_config.hosts
64+
):
65+
exclude["spec"] = {
66+
"configuration": {"ssh_config": {"hosts": {"__all__": {"internal_ip"}}}}
67+
}
68+
return exclude or None

0 commit comments

Comments
 (0)