Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/dstack/_internal/core/services/ssh/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
ssh_config_path: Union[PathLike, Literal["none"]] = "none",
port: Optional[int] = None,
ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (),
batch_mode: bool = False,
):
"""
:param forwarded_sockets: Connections to the specified local sockets will be
Expand All @@ -79,11 +80,22 @@ def __init__(
:param ssh_proxies: pairs of SSH connections params and optional identities,
in order from outer to inner. If an identity is `None`, the `identity` param
is used instead.
:param batch_mode: If enabled, "user interaction such as password prompts and host key
confirmation requests will be disabled", see `ssh_config(5)`, `BatchMode`.
Although this is probably the desired behavior in all use cases, the default value
is `False` for gradual adoption.
Note, this option is only applied to the `destination` and `ssh_proxies`. If you
configured `destination` with `ProxyJump` in the `ssh_config_path` config, the proxy
jump connection will ignore this option -- in that case, you should replace `ProxyJump`
with explicit `ProxyCommand=ssh [...] -o BatchMode=yes` in your config.
"""
self.destination = destination
self.forwarded_sockets = list(forwarded_sockets)
self.reverse_forwarded_sockets = list(reverse_forwarded_sockets)
self.options = options
# A copy of options with names normalized to lowercase. Used only internally, the actual
# ssh command is built from user-provided options as is.
self.options_normalized = {k.lower(): v for k, v in options.items()}
self.port = port
self.ssh_config_path = normalize_path(ssh_config_path)
temp_dir = tempfile.TemporaryDirectory()
Expand All @@ -101,6 +113,7 @@ def __init__(
proxy_identity, f"proxy_identity_{proxy_index}"
)
self.ssh_proxies.append((proxy_params, proxy_identity_path))
self.batch_mode = batch_mode
self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log"))
self.ssh_client_info = get_ssh_client_info()
self.ssh_exec_path = str(self.ssh_client_info.path)
Expand Down Expand Up @@ -145,6 +158,14 @@ def open_command(self) -> List[str]:
command += ["-p", str(self.port)]
for k, v in self.options.items():
command += ["-o", f"{k}={v}"]
if self.batch_mode:
command += ["-o", "BatchMode=yes"]
if "serveraliveinterval" not in self.options_normalized:
# Revert Debian-specific patch effect:
# > The default is 0, indicating that these messages will not be sent
# > to the server, or 300 if the BatchMode option is set (Debian-specific).
# https://salsa.debian.org/ssh-team/openssh/-/blob/d87b69641b533b892b87e2eea02dbee796682d64/debian/patches/keepalive-extensions.patch#L69-77
command += ["-o", "ServerAliveInterval=0"]
if proxy_command := self._get_proxy_command():
command += ["-o", proxy_command]
for socket_pair in self.forwarded_sockets:
Expand Down Expand Up @@ -290,6 +311,14 @@ def _build_proxy_command(
"-o",
"UserKnownHostsFile=/dev/null",
]
if self.batch_mode:
# ServerAliveInterval is explained in the open_command() comment
command += [
"-o",
"BatchMode=yes",
"-o",
"ServerAliveInterval=0",
]
if prev_proxy_command is not None:
command += ["-o", prev_proxy_command.replace("%", "%%")]
command += [
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/server/services/runner/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def wrapper(
forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map),
identity=identity,
ssh_proxies=ssh_proxies,
batch_mode=True,
):
return func(runner_ports_map, *args, **kwargs)
except SSHError:
Expand Down
Loading