diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index e4f7f276e..5fc75ccfc 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -70,6 +70,7 @@ def __init__( ssh_config_path: Union[PathLike, Literal["none"]] = "none", port: Optional[int] = None, ssh_proxies: Iterable[tuple[SSHConnectionParams, Optional[FilePathOrContent]]] = (), + batch_mode: bool = False, ): """ :param forwarded_sockets: Connections to the specified local sockets will be @@ -79,11 +80,22 @@ def __init__( :param ssh_proxies: pairs of SSH connections params and optional identities, in order from outer to inner. If an identity is `None`, the `identity` param is used instead. + :param batch_mode: If enabled, "user interaction such as password prompts and host key + confirmation requests will be disabled", see `ssh_config(5)`, `BatchMode`. + Although this is probably the desired behavior in all use cases, the default value + is `False` for gradual adoption. + Note, this option is only applied to the `destination` and `ssh_proxies`. If you + configured `destination` with `ProxyJump` in the `ssh_config_path` config, the proxy + jump connection will ignore this option -- in that case, you should replace `ProxyJump` + with explicit `ProxyCommand=ssh [...] -o BatchMode=yes` in your config. """ self.destination = destination self.forwarded_sockets = list(forwarded_sockets) self.reverse_forwarded_sockets = list(reverse_forwarded_sockets) self.options = options + # A copy of options with names normalized to lowercase. Used only internally, the actual + # ssh command is built from user-provided options as is. + self.options_normalized = {k.lower(): v for k, v in options.items()} self.port = port self.ssh_config_path = normalize_path(ssh_config_path) temp_dir = tempfile.TemporaryDirectory() @@ -101,6 +113,7 @@ def __init__( proxy_identity, f"proxy_identity_{proxy_index}" ) self.ssh_proxies.append((proxy_params, proxy_identity_path)) + self.batch_mode = batch_mode self.log_path = normalize_path(os.path.join(temp_dir.name, "tunnel.log")) self.ssh_client_info = get_ssh_client_info() self.ssh_exec_path = str(self.ssh_client_info.path) @@ -145,6 +158,14 @@ def open_command(self) -> List[str]: command += ["-p", str(self.port)] for k, v in self.options.items(): command += ["-o", f"{k}={v}"] + if self.batch_mode: + command += ["-o", "BatchMode=yes"] + if "serveraliveinterval" not in self.options_normalized: + # Revert Debian-specific patch effect: + # > The default is 0, indicating that these messages will not be sent + # > to the server, or 300 if the BatchMode option is set (Debian-specific). + # https://salsa.debian.org/ssh-team/openssh/-/blob/d87b69641b533b892b87e2eea02dbee796682d64/debian/patches/keepalive-extensions.patch#L69-77 + command += ["-o", "ServerAliveInterval=0"] if proxy_command := self._get_proxy_command(): command += ["-o", proxy_command] for socket_pair in self.forwarded_sockets: @@ -290,6 +311,14 @@ def _build_proxy_command( "-o", "UserKnownHostsFile=/dev/null", ] + if self.batch_mode: + # ServerAliveInterval is explained in the open_command() comment + command += [ + "-o", + "BatchMode=yes", + "-o", + "ServerAliveInterval=0", + ] if prev_proxy_command is not None: command += ["-o", prev_proxy_command.replace("%", "%%")] command += [ diff --git a/src/dstack/_internal/server/services/runner/ssh.py b/src/dstack/_internal/server/services/runner/ssh.py index 669ea2181..a4ef98686 100644 --- a/src/dstack/_internal/server/services/runner/ssh.py +++ b/src/dstack/_internal/server/services/runner/ssh.py @@ -96,6 +96,7 @@ def wrapper( forwarded_sockets=ports_to_forwarded_sockets(tunnel_ports_map), identity=identity, ssh_proxies=ssh_proxies, + batch_mode=True, ): return func(runner_ports_map, *args, **kwargs) except SSHError: