@@ -69,44 +69,38 @@ def __init__(
6969 options : Dict [str , str ] = SSH_DEFAULT_OPTIONS ,
7070 ssh_config_path : Union [PathLike , Literal ["none" ]] = "none" ,
7171 port : Optional [int ] = None ,
72- ssh_proxy : Optional [SSHConnectionParams ] = None ,
73- ssh_proxy_identity : Optional [FilePathOrContent ] = None ,
72+ ssh_proxies : Iterable [tuple [SSHConnectionParams , Optional [FilePathOrContent ]]] = (),
7473 ):
7574 """
7675 :param forwarded_sockets: Connections to the specified local sockets will be
7776 forwarded to their corresponding remote sockets
7877 :param reverse_forwarded_sockets: Connections to the specified remote sockets
7978 will be forwarded to their corresponding local sockets
79+ :param ssh_proxies: pairs of SSH connections params and optional identities,
80+ in order from outer to inner. If an identity is `None`, the `identity` param
81+ is used instead.
8082 """
8183 self .destination = destination
8284 self .forwarded_sockets = list (forwarded_sockets )
8385 self .reverse_forwarded_sockets = list (reverse_forwarded_sockets )
8486 self .options = options
8587 self .port = port
8688 self .ssh_config_path = normalize_path (ssh_config_path )
87- self .ssh_proxy = ssh_proxy
8889 temp_dir = tempfile .TemporaryDirectory ()
8990 self .temp_dir = temp_dir
9091 if control_sock_path is None :
9192 control_sock_path = os .path .join (temp_dir .name , "control.sock" )
9293 self .control_sock_path = normalize_path (control_sock_path )
93- if isinstance (identity , FilePath ):
94- identity_path = identity .path
95- else :
96- identity_path = os .path .join (temp_dir .name , "identity" )
97- with open (
98- identity_path , opener = lambda path , flags : os .open (path , flags , 0o600 ), mode = "w"
99- ) as f :
100- f .write (identity .content )
10194 self .identity_path = normalize_path (self ._get_identity_path (identity , "identity" ))
102- if ssh_proxy_identity is not None :
103- self .ssh_proxy_identity_path = normalize_path (
104- self ._get_identity_path (ssh_proxy_identity , "proxy_identity" )
105- )
106- elif ssh_proxy is not None :
107- self .ssh_proxy_identity_path = self .identity_path
108- else :
109- self .ssh_proxy_identity_path = None
95+ self .ssh_proxies : list [tuple [SSHConnectionParams , PathLike ]] = []
96+ for proxy_index , (proxy_params , proxy_identity ) in enumerate (ssh_proxies ):
97+ if proxy_identity is None :
98+ proxy_identity_path = self .identity_path
99+ else :
100+ proxy_identity_path = self ._get_identity_path (
101+ proxy_identity , f"proxy_identity_{ proxy_index } "
102+ )
103+ self .ssh_proxies .append ((proxy_params , proxy_identity_path ))
110104 self .log_path = normalize_path (os .path .join (temp_dir .name , "tunnel.log" ))
111105 self .ssh_client_info = get_ssh_client_info ()
112106 self .ssh_exec_path = str (self .ssh_client_info .path )
@@ -151,8 +145,8 @@ def open_command(self) -> List[str]:
151145 command += ["-p" , str (self .port )]
152146 for k , v in self .options .items ():
153147 command += ["-o" , f"{ k } ={ v } " ]
154- if proxy_command := self .proxy_command ():
155- command += ["-o" , "ProxyCommand=" + shlex . join ( proxy_command ) ]
148+ if proxy_command := self ._get_proxy_command ():
149+ command += ["-o" , proxy_command ]
156150 for socket_pair in self .forwarded_sockets :
157151 command += ["-L" , f"{ socket_pair .local .render ()} :{ socket_pair .remote .render ()} " ]
158152 for socket_pair in self .reverse_forwarded_sockets :
@@ -169,24 +163,6 @@ def check_command(self) -> List[str]:
169163 def exec_command (self ) -> List [str ]:
170164 return [self .ssh_exec_path , "-S" , self .control_sock_path , self .destination ]
171165
172- def proxy_command (self ) -> Optional [List [str ]]:
173- if self .ssh_proxy is None :
174- return None
175- return [
176- self .ssh_exec_path ,
177- "-i" ,
178- self .ssh_proxy_identity_path ,
179- "-W" ,
180- "%h:%p" ,
181- "-o" ,
182- "StrictHostKeyChecking=no" ,
183- "-o" ,
184- "UserKnownHostsFile=/dev/null" ,
185- "-p" ,
186- str (self .ssh_proxy .port ),
187- f"{ self .ssh_proxy .username } @{ self .ssh_proxy .hostname } " ,
188- ]
189-
190166 def open (self ) -> None :
191167 # We cannot use `stderr=subprocess.PIPE` here since the forked process (daemon) does not
192168 # close standard streams if ProxyJump is used, therefore we will wait EOF from the pipe
@@ -260,6 +236,38 @@ def __enter__(self):
260236 def __exit__ (self , exc_type , exc_val , exc_tb ):
261237 self .close ()
262238
239+ def _get_proxy_command (self ) -> Optional [str ]:
240+ proxy_command : Optional [str ] = None
241+ for params , identity_path in self .ssh_proxies :
242+ proxy_command = self ._build_proxy_command (params , identity_path , proxy_command )
243+ return proxy_command
244+
245+ def _build_proxy_command (
246+ self ,
247+ params : SSHConnectionParams ,
248+ identity_path : PathLike ,
249+ prev_proxy_command : Optional [str ],
250+ ) -> Optional [str ]:
251+ command = [
252+ self .ssh_exec_path ,
253+ "-i" ,
254+ identity_path ,
255+ "-W" ,
256+ "%h:%p" ,
257+ "-o" ,
258+ "StrictHostKeyChecking=no" ,
259+ "-o" ,
260+ "UserKnownHostsFile=/dev/null" ,
261+ ]
262+ if prev_proxy_command is not None :
263+ command += ["-o" , prev_proxy_command .replace ("%" , "%%" )]
264+ command += [
265+ "-p" ,
266+ str (params .port ),
267+ f"{ params .username } @{ params .hostname } " ,
268+ ]
269+ return "ProxyCommand=" + shlex .join (command )
270+
263271 def _read_log_file (self ) -> bytes :
264272 with open (self .log_path , "rb" ) as f :
265273 return f .read ()
0 commit comments