Skip to content

Commit ae0835e

Browse files
authored
Support services with head node setup (#2299)
* Add proxy chain support to `SSHTunnel` * Add optional head proxy fields to `Replica` * Extend gateway API to support head proxy fields Closes: #2010
1 parent 5cc97ae commit ae0835e

14 files changed

Lines changed: 214 additions & 62 deletions

File tree

docs/docs/concepts/fleets.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,6 @@ add a front node key (`~/.ssh/head_node_key`) to an SSH agent or configure a key
339339
where `Host` must match `ssh_config.proxy_jump.hostname` or `ssh_config.hosts[n].proxy_jump.hostname` if you configure head nodes
340340
on a per-worker basis.
341341

342-
> Currently, [services](services.md) do not work on instances with a head node setup.
343-
344342
!!! info "Reference"
345343
For all SSH fleet configuration options, refer to the [reference](../reference/dstack.yml/fleet.md).
346344

src/dstack/_internal/core/services/ssh/tunnel.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -69,44 +69,38 @@ def __init__(
6969
options: Dict[str, str] = SSH_DEFAULT_OPTIONS,
7070
ssh_config_path: Union[PathLike, Literal["none"]] = "none",
7171
port: Optional[int] = None,
72-
ssh_proxy: Optional[SSHConnectionParams] = None,
73-
ssh_proxy_identity: Optional[FilePathOrContent] = None,
72+
ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (),
7473
):
7574
"""
7675
:param forwarded_sockets: Connections to the specified local sockets will be
7776
forwarded to their corresponding remote sockets
7877
:param reverse_forwarded_sockets: Connections to the specified remote sockets
7978
will be forwarded to their corresponding local sockets
79+
:param ssh_proxies: pairs of SSH connections params and optional identities,
80+
in order from outer to inner. If an identity is `None`, the `identity` param
81+
is used instead.
8082
"""
8183
self.destination = destination
8284
self.forwarded_sockets = list(forwarded_sockets)
8385
self.reverse_forwarded_sockets = list(reverse_forwarded_sockets)
8486
self.options = options
8587
self.port = port
8688
self.ssh_config_path = normalize_path(ssh_config_path)
87-
self.ssh_proxy = ssh_proxy
8889
temp_dir = tempfile.TemporaryDirectory()
8990
self.temp_dir = temp_dir
9091
if control_sock_path is None:
9192
control_sock_path = os.path.join(temp_dir.name, "control.sock")
9293
self.control_sock_path = normalize_path(control_sock_path)
93-
if isinstance(identity, FilePath):
94-
identity_path = identity.path
95-
else:
96-
identity_path = os.path.join(temp_dir.name, "identity")
97-
with open(
98-
identity_path, opener=lambda path, flags: os.open(path, flags, 0o600), mode="w"
99-
) as f:
100-
f.write(identity.content)
10194
self.identity_path = normalize_path(self._get_identity_path(identity, "identity"))
102-
if ssh_proxy_identity is not None:
103-
self.ssh_proxy_identity_path = normalize_path(
104-
self._get_identity_path(ssh_proxy_identity, "proxy_identity")
105-
)
106-
elif ssh_proxy is not None:
107-
self.ssh_proxy_identity_path = self.identity_path
108-
else:
109-
self.ssh_proxy_identity_path = None
95+
self.ssh_proxies: list[tuple[SSHConnectionParams, PathLike]] = []
96+
for proxy_index, (proxy_params, proxy_identity) in enumerate(ssh_proxies):
97+
if proxy_identity is None:
98+
proxy_identity_path = self.identity_path
99+
else:
100+
proxy_identity_path = self._get_identity_path(
101+
proxy_identity, f"proxy_identity_{proxy_index}"
102+
)
103+
self.ssh_proxies.append((proxy_params, proxy_identity_path))
110104
self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log"))
111105
self.ssh_client_info = get_ssh_client_info()
112106
self.ssh_exec_path = str(self.ssh_client_info.path)
@@ -151,8 +145,8 @@ def open_command(self) -> List[str]:
151145
command += ["-p", str(self.port)]
152146
for k, v in self.options.items():
153147
command += ["-o", f"{k}={v}"]
154-
if proxy_command := self.proxy_command():
155-
command += ["-o", "ProxyCommand=" + shlex.join(proxy_command)]
148+
if proxy_command := self._get_proxy_command():
149+
command += ["-o", proxy_command]
156150
for socket_pair in self.forwarded_sockets:
157151
command += ["-L", f"{socket_pair.local.render()}:{socket_pair.remote.render()}"]
158152
for socket_pair in self.reverse_forwarded_sockets:
@@ -169,24 +163,6 @@ def check_command(self) -> List[str]:
169163
def exec_command(self) -> List[str]:
170164
return [self.ssh_exec_path, "-S", self.control_sock_path, self.destination]
171165

172-
def proxy_command(self) -> Optional[List[str]]:
173-
if self.ssh_proxy is None:
174-
return None
175-
return [
176-
self.ssh_exec_path,
177-
"-i",
178-
self.ssh_proxy_identity_path,
179-
"-W",
180-
"%h:%p",
181-
"-o",
182-
"StrictHostKeyChecking=no",
183-
"-o",
184-
"UserKnownHostsFile=/dev/null",
185-
"-p",
186-
str(self.ssh_proxy.port),
187-
f"{self.ssh_proxy.username}@{self.ssh_proxy.hostname}",
188-
]
189-
190166
def open(self) -> None:
191167
# We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not
192168
# close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe
@@ -260,6 +236,38 @@ def __enter__(self):
260236
def __exit__(self, exc_type, exc_val, exc_tb):
261237
self.close()
262238

239+
def _get_proxy_command(self) -> Optional[str]:
240+
proxy_command: Optional[str] = None
241+
for params, identity_path in self.ssh_proxies:
242+
proxy_command = self._build_proxy_command(params, identity_path, proxy_command)
243+
return proxy_command
244+
245+
def _build_proxy_command(
246+
self,
247+
params: SSHConnectionParams,
248+
identity_path: PathLike,
249+
prev_proxy_command: Optional[str],
250+
) -> Optional[str]:
251+
command = [
252+
self.ssh_exec_path,
253+
"-i",
254+
identity_path,
255+
"-W",
256+
"%h:%p",
257+
"-o",
258+
"StrictHostKeyChecking=no",
259+
"-o",
260+
"UserKnownHostsFile=/dev/null",
261+
]
262+
if prev_proxy_command is not None:
263+
command += ["-o", prev_proxy_command.replace("%", "%%")]
264+
command += [
265+
"-p",
266+
str(params.port),
267+
f"{params.username}@{params.hostname}",
268+
]
269+
return "ProxyCommand=" + shlex.join(command)
270+
263271
def _read_log_file(self) -> bytes:
264272
with open(self.log_path, "rb") as f:
265273
return f.read()

src/dstack/_internal/proxy/gateway/routers/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ async def register_replica(
7676
ssh_destination=body.ssh_host,
7777
ssh_port=body.ssh_port,
7878
ssh_proxy=body.ssh_proxy,
79+
ssh_head_proxy=body.ssh_head_proxy,
80+
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
7981
repo=repo,
8082
nginx=nginx,
8183
service_conn_pool=service_conn_pool,

src/dstack/_internal/proxy/gateway/schemas/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class RegisterReplicaRequest(BaseModel):
5050
ssh_host: str
5151
ssh_port: int
5252
ssh_proxy: Optional[SSHConnectionParams]
53+
ssh_head_proxy: Optional[SSHConnectionParams]
54+
ssh_head_proxy_private_key: Optional[str]
5355

5456

5557
class RegisterEntrypointRequest(BaseModel):

src/dstack/_internal/proxy/gateway/services/registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ async def register_replica(
123123
ssh_destination: str,
124124
ssh_port: int,
125125
ssh_proxy: Optional[SSHConnectionParams],
126+
ssh_head_proxy: Optional[SSHConnectionParams],
127+
ssh_head_proxy_private_key: Optional[str],
126128
repo: GatewayProxyRepo,
127129
nginx: Nginx,
128130
service_conn_pool: ServiceConnectionPool,
@@ -133,6 +135,8 @@ async def register_replica(
133135
ssh_destination=ssh_destination,
134136
ssh_port=ssh_port,
135137
ssh_proxy=ssh_proxy,
138+
ssh_head_proxy=ssh_head_proxy,
139+
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
136140
)
137141

138142
async with lock:

src/dstack/_internal/proxy/lib/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ class Replica(ImmutableModel):
2323
ssh_destination: str
2424
ssh_port: int
2525
ssh_proxy: Optional[SSHConnectionParams]
26+
# Optional outer proxy, a head node/bastion
27+
ssh_head_proxy: Optional[SSHConnectionParams] = None
28+
ssh_head_proxy_private_key: Optional[str] = None
2629

2730

2831
class Service(ImmutableModel):

src/dstack/_internal/proxy/lib/services/service_connection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dstack._internal.proxy.lib.errors import UnexpectedProxyError
1919
from dstack._internal.proxy.lib.models import Project, Replica, Service
2020
from dstack._internal.proxy.lib.repo import BaseProxyRepo
21+
from dstack._internal.utils.common import get_or_error
2122
from dstack._internal.utils.logging import get_logger
2223
from dstack._internal.utils.path import FileContent
2324

@@ -45,10 +46,16 @@ def __init__(self, project: Project, service: Service, replica: Replica) -> None
4546
os.chmod(self._temp_dir.name, 0o755)
4647
options["StreamLocalBindMask"] = "0111"
4748
self._app_socket_path = (Path(self._temp_dir.name) / "replica.sock").absolute()
49+
ssh_proxies = []
50+
if replica.ssh_head_proxy is not None:
51+
ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key)
52+
ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
53+
if replica.ssh_proxy is not None:
54+
ssh_proxies.append((replica.ssh_proxy, None))
4855
self._tunnel = SSHTunnel(
4956
destination=replica.ssh_destination,
5057
port=replica.ssh_port,
51-
ssh_proxy=replica.ssh_proxy,
58+
ssh_proxies=ssh_proxies,
5259
identity=FileContent(project.ssh_private_key),
5360
forwarded_sockets=[
5461
SocketPair(

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from dstack._internal.core.models.backends.base import BackendType
1212
from dstack._internal.core.models.common import NetworkMode, RegistryAuth, is_core_model_instance
1313
from dstack._internal.core.models.configurations import DevEnvironmentConfiguration
14-
from dstack._internal.core.models.instances import InstanceStatus, RemoteConnectionInfo
14+
from dstack._internal.core.models.instances import (
15+
InstanceStatus,
16+
RemoteConnectionInfo,
17+
SSHConnectionParams,
18+
)
1519
from dstack._internal.core.models.repos import RemoteRepoCreds
1620
from dstack._internal.core.models.runs import (
1721
ClusterInfo,
@@ -308,8 +312,24 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
308312
and job_model.job_num == 0 # gateway connects only to the first node
309313
and run.run_spec.configuration.type == "service"
310314
):
315+
ssh_head_proxy: Optional[SSHConnectionParams] = None
316+
ssh_head_proxy_private_key: Optional[str] = None
317+
instance = common_utils.get_or_error(job_model.instance)
318+
if instance.remote_connection_info is not None:
319+
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
320+
if rci.ssh_proxy is not None:
321+
ssh_head_proxy = rci.ssh_proxy
322+
ssh_head_proxy_keys = common_utils.get_or_error(rci.ssh_proxy_keys)
323+
ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private
311324
try:
312-
await services.register_replica(session, run_model.gateway_id, run, job_model)
325+
await services.register_replica(
326+
session,
327+
run_model.gateway_id,
328+
run,
329+
job_model,
330+
ssh_head_proxy,
331+
ssh_head_proxy_private_key,
332+
)
313333
except GatewayError as e:
314334
logger.warning(
315335
"%s: failed to register service replica: %s, age=%s",

src/dstack/_internal/server/services/gateways/client.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,18 @@ async def unregister_service(self, project: str, run_name: str):
7474
resp.raise_for_status()
7575
self.is_server_ready = True
7676

77-
async def register_replica(self, run: Run, job_submission: JobSubmission):
77+
async def register_replica(
78+
self,
79+
run: Run,
80+
job_submission: JobSubmission,
81+
ssh_head_proxy: Optional[SSHConnectionParams],
82+
ssh_head_proxy_private_key: Optional[str],
83+
):
7884
payload = {
7985
"job_id": job_submission.id.hex,
8086
"app_port": run.run_spec.configuration.port.container_port,
87+
"ssh_head_proxy": ssh_head_proxy.dict() if ssh_head_proxy is not None else None,
88+
"ssh_head_proxy_private_key": ssh_head_proxy_private_key,
8189
}
8290
jpd = job_submission.job_provisioning_data
8391
if not jpd.dockerized:

src/dstack/_internal/server/services/proxy/repo.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
1010
from dstack._internal.core.models.common import is_core_model_instance
1111
from dstack._internal.core.models.configurations import ServiceConfiguration
12-
from dstack._internal.core.models.instances import SSHConnectionParams
12+
from dstack._internal.core.models.instances import RemoteConnectionInfo, SSHConnectionParams
1313
from dstack._internal.core.models.runs import (
1414
JobProvisioningData,
1515
JobStatus,
@@ -30,6 +30,7 @@
3030
from dstack._internal.proxy.lib.repo import BaseProxyRepo
3131
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
3232
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
33+
from dstack._internal.utils.common import get_or_error
3334

3435

3536
class ServerProxyRepo(BaseProxyRepo):
@@ -53,9 +54,12 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
5354
JobModel.status == JobStatus.RUNNING,
5455
JobModel.job_num == 0,
5556
)
56-
.options(joinedload(JobModel.run))
57+
.options(
58+
joinedload(JobModel.run),
59+
joinedload(JobModel.instance),
60+
)
5761
)
58-
jobs = res.scalars().all()
62+
jobs = res.unique().scalars().all()
5963
if not len(jobs):
6064
return None
6165
run = jobs[0].run
@@ -83,12 +87,22 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
8387
username=jpd.username,
8488
port=jpd.ssh_port,
8589
)
90+
ssh_head_proxy: Optional[SSHConnectionParams] = None
91+
ssh_head_proxy_private_key: Optional[str] = None
92+
instance = get_or_error(job.instance)
93+
if instance.remote_connection_info is not None:
94+
rci = RemoteConnectionInfo.__response__.parse_raw(instance.remote_connection_info)
95+
if rci.ssh_proxy is not None:
96+
ssh_head_proxy = rci.ssh_proxy
97+
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
8698
replica = Replica(
8799
id=job.id.hex,
88100
app_port=run_spec.configuration.port.container_port,
89101
ssh_destination=ssh_destination,
90102
ssh_port=ssh_port,
91103
ssh_proxy=ssh_proxy,
104+
ssh_head_proxy=ssh_head_proxy,
105+
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
92106
)
93107
replicas.append(replica)
94108
return Service(

0 commit comments

Comments
 (0)