Skip to content

Commit 6a58e18

Browse files
authored
Update ssh_tunnel.py
Logger, types, custom exception
1 parent 3c3c7e8 commit 6a58e18

1 file changed

Lines changed: 43 additions & 8 deletions

File tree

ssh_tunnel.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
1+
import logging
2+
import os
13
import subprocess
24
import tempfile
35
from contextlib import contextmanager
6+
from typing import Generator, Optional
7+
8+
logger = logging.getLogger(__name__)
9+
logger.addHandler(logging.NullHandler())
10+
11+
12+
class SSHTunnelConnectionError(Exception):
13+
pass
414

515

616
@contextmanager
7-
def create_ssh_tunnel(hostname, local_socket, remote_socket):
17+
def create_ssh_tunnel(
18+
hostname: str, local_socket: str, remote_socket: str, timeout: int = 10
19+
) -> Generator[str, None, None]:
820
ssh_socket_filename = gen_temp_socket_filename(f"{hostname}.")
921
ssh_tunnel_cmd = [
1022
"ssh",
11-
"-qfN",
23+
"-fN",
1224
"-M",
1325
"-S",
1426
ssh_socket_filename,
@@ -26,30 +38,53 @@ def create_ssh_tunnel(hostname, local_socket, remote_socket):
2638
]
2739
ssh_tunnel_terminate_cmd = [
2840
"ssh",
29-
"-q",
3041
"-S",
3142
ssh_socket_filename,
3243
"-O",
3344
"exit",
3445
hostname,
3546
]
3647
try:
37-
yield subprocess.run(ssh_tunnel_cmd, check=True)
48+
logger.debug(f"Execute cmd: {' '.join(ssh_tunnel_cmd)}")
49+
subprocess.run(
50+
ssh_tunnel_cmd,
51+
check=True,
52+
stdout=subprocess.DEVNULL,
53+
stderr=subprocess.DEVNULL,
54+
timeout=timeout,
55+
)
56+
yield ssh_socket_filename
57+
except (subprocess.CalledProcessError, subprocess.TimeoutExpired, ValueError) as ex:
58+
logger.exception(
59+
f"Exception occurred when trying to open SSH tunnel:\n{ex}",
60+
exc_info=False,
61+
)
62+
raise SSHTunnelConnectionError(ex) from ex
3863
finally:
3964
try:
40-
subprocess.run(ssh_tunnel_terminate_cmd)
65+
logger.debug(
66+
f"Execute cmd: {' '.join(ssh_tunnel_terminate_cmd)}",
67+
)
68+
subprocess.run(
69+
ssh_tunnel_terminate_cmd,
70+
check=True,
71+
stdout=subprocess.DEVNULL,
72+
stderr=subprocess.DEVNULL,
73+
)
74+
logger.debug("Deleting socket file")
4175
os.remove(local_socket)
4276
except (subprocess.CalledProcessError, FileNotFoundError):
4377
pass
4478

4579

46-
def gen_temp_socket_filename(prefix=None, suffix=None):
80+
def gen_temp_socket_filename(
81+
prefix: Optional[str] = None, suffix: Optional[str] = None
82+
) -> str:
4783
temp_socket_filename = None
4884
with tempfile.NamedTemporaryFile(
4985
suffix=suffix, prefix=prefix, dir=tempfile.gettempdir()
5086
) as tmpfile:
5187
temp_socket_filename = tmpfile.name
5288
if temp_socket_filename is not None:
5389
return temp_socket_filename
54-
else:
55-
raise RuntimeError("Unable to create temp file")
90+
raise RuntimeError("Unable to create temp file")

0 commit comments

Comments
 (0)