Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 85 additions & 29 deletions packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,40 @@
from jumpstarter.client.decorators import driver_click_command


@dataclass
class SSHCommandRunResult:
"""Result of executing an SSH command"""
return_code: int
stdout: str | bytes
stderr: str | bytes

@staticmethod
def from_completed_process(result: subprocess.CompletedProcess) -> "SSHCommandRunResult":
return SSHCommandRunResult(
return_code=result.returncode,
stdout=result.stdout or "",
stderr=result.stderr or "",
)


@dataclass
class SSHCommandRunOptions:
"""
Options for running an SSH command

Attributes:
direct: If True, connect directly to the host's TCP address.
If False, use SSH port forwarding.
capture_output: If True, capture stdout and stderr.
If False, they are inherited from the parent process.
capture_as_text: If True and output is captured, decode stdout and
stderr as text. Otherwise, they are captured as bytes.
"""
direct: bool = False
capture_output: bool = True
capture_as_text: bool = True


@dataclass(kw_only=True)
class SSHWrapperClient(CompositeClient):
"""
Expand All @@ -30,11 +64,25 @@ def cli(self):
@click.option("--direct", is_flag=True, help="Use direct TCP address")
@click.argument("args", nargs=-1)
def ssh(direct, args):
result = self.run(direct, args)
self.logger.debug(f"SSH result: {result}")
if result != 0:
click.get_current_context().exit(result)
return result
options = SSHCommandRunOptions(
direct=direct,
# For the CLI, we never capture output so that interactive shells
# and long-running commands stream their output directly.
capture_output=False,
)

result = self.run(options, args)
self.logger.debug("SSH exit code: %s", result.return_code)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of the CLI, we should pass the new option structure, and use direct i/o, i.e. not capturing stdout/stderr, and printing through click.

Otherwise plain "j ssh" (which drops you into a shell) won't work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, sorry I missed it at first!

I tested it now, and it seems to work fine.


if result.stdout:
click.echo(result.stdout, nl=False)
if result.stderr:
click.echo(result.stderr, nl=False, err=True)

if result.return_code != 0:
click.get_current_context().exit(result.return_code)

return result.return_code

return ssh

Expand All @@ -46,14 +94,14 @@ def stream(self, method="connect"):
async def stream_async(self, method):
return await self.tcp.stream_async(method)

def run(self, direct, args):
def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
"""Run SSH command with the given parameters and arguments"""
# Get SSH command and default username from driver
ssh_command = self.call("get_ssh_command")
default_username = self.call("get_default_username")
ssh_identity = self.call("get_ssh_identity")

if direct:
if options.direct:
# Use direct TCP address
try:
address = self.tcp.address() # (format: "tcp://host:port")
Expand All @@ -62,23 +110,26 @@ def run(self, direct, args):
port = parsed.port
if not host or not port:
raise ValueError(f"Invalid address format: {address}")
self.logger.debug(f"Using direct TCP connection for SSH - host: {host}, port: {port}")
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
self.logger.debug("Using direct TCP connection for SSH - host: %s, port: %s", host, port)
return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args)
except (DriverMethodNotImplemented, ValueError) as e:
self.logger.error(f"Direct address connection failed ({e}), falling back to SSH port forwarding")
return self.run(False, args)
self.logger.error("Direct address connection failed (%s), falling back to SSH port forwarding", e)
return self.run(SSHCommandRunOptions(
direct=False,
capture_output=options.capture_output,
capture_as_text=options.capture_as_text,
), args)
else:
# Use SSH port forwarding (default behavior)
self.logger.debug("Using SSH port forwarding for SSH connection")
with TcpPortforwardAdapter(
client=self.tcp,
) as addr:
host = addr[0]
port = addr[1]
self.logger.debug(f"SSH port forward established - host: {host}, port: {port}")
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
host, port = addr
self.logger.debug("SSH port forward established - host: %s, port: %s", host, port)
return self._run_ssh_local(host, port, ssh_command, options, default_username, ssh_identity, args)

def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity, args):
def _run_ssh_local(self, host, port, ssh_command, options, default_username, ssh_identity, args):
"""Run SSH command with the given host, port, and arguments"""
# Create temporary identity file if needed
identity_file = None
Expand All @@ -91,9 +142,9 @@ def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity
# Set proper permissions (600) for SSH key
os.chmod(temp_file.name, 0o600)
identity_file = temp_file.name
self.logger.debug(f"Created temporary identity file: {identity_file}")
self.logger.debug("Created temporary identity file: %s", identity_file)
except Exception as e:
self.logger.error(f"Failed to create temporary identity file: {e}")
self.logger.error("Failed to create temporary identity file: %s", e)
if temp_file:
try:
os.unlink(temp_file.name)
Expand All @@ -112,15 +163,15 @@ def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)

# Execute the command
return self._execute_ssh_command(ssh_args)
return self._execute_ssh_command(ssh_args, options)
finally:
# Clean up temporary identity file
if identity_file:
try:
os.unlink(identity_file)
self.logger.debug(f"Cleaned up temporary identity file: {identity_file}")
self.logger.debug("Cleaned up temporary identity file: %s", identity_file)
except Exception as e:
self.logger.warning(f"Failed to clean up temporary identity file {identity_file}: {e}")
self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e))

def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args):
"""Build initial SSH command arguments"""
Expand Down Expand Up @@ -192,8 +243,8 @@ def _separate_ssh_options_and_command_args(self, args):
i += 1

# Debug output
self.logger.debug(f"SSH options: {ssh_options}")
self.logger.debug(f"Command args: {command_args}")
self.logger.debug("SSH options: %s", ssh_options)
self.logger.debug("Command args: %s", command_args)
return ssh_options, command_args


Expand All @@ -209,16 +260,21 @@ def _build_final_ssh_command(self, ssh_args, ssh_options, host, command_args):
# Add command arguments
ssh_args.extend(command_args)

self.logger.debug(f"Running SSH command: {ssh_args}")
self.logger.debug("Running SSH command: %s", ssh_args)
return ssh_args

def _execute_ssh_command(self, ssh_args):
def _execute_ssh_command(self, ssh_args, options: SSHCommandRunOptions) -> SSHCommandRunResult:
"""Execute the SSH command and return the result"""
try:
result = subprocess.run(ssh_args)
return result.returncode
result = subprocess.run(ssh_args, capture_output=options.capture_output, text=options.capture_as_text)
return SSHCommandRunResult.from_completed_process(result)
except FileNotFoundError:
self.logger.error(
f"SSH command '{ssh_args[0]}' not found. Please ensure SSH is installed and available in PATH."
"SSH command '%s' not found. Please ensure SSH is installed and available in PATH.",
ssh_args[0],
)
return SSHCommandRunResult(
return_code=127, # Standard exit code for "command not found"
stdout="",
stderr=f"SSH command '{ssh_args[0]}' not found",
)
return 127 # Standard exit code for "command not found"
Loading
Loading