Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,18 @@ def cli(self):
help="Run SSH command with arguments",
)
@click.option("--direct", is_flag=True, help="Use direct TCP address")
@click.option("-u", "--user", help="Username to use for SSH connection")
@click.argument("args", nargs=-1)
def ssh(direct, args):
def ssh(direct, user, args):
"""Run SSH command with arguments."""
options = SSHCommandRunOptions(
direct=direct,
# For the CLI, we never capture output so that interactive shells
# and long-running commands stream their output directly.
capture_output=False,
)

result = self.run(options, args)
result = self.run(options, args, user=user)
self.logger.debug("SSH exit code: %s", result.return_code)

if result.stdout:
Expand Down Expand Up @@ -118,8 +120,15 @@ def username(self) -> str:
"""Get the default SSH username"""
return self.call("get_default_username")

def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
"""Run SSH command with the given parameters and arguments"""
def run(self, options: SSHCommandRunOptions, args, user: str | None = None) -> SSHCommandRunResult:
"""
Run SSH command with the given parameters and arguments

Args:
options: SSH command run options.
args: Command arguments.
user: Optional username to override the default.
"""
# Get SSH command and default username from driver
if options.direct:
# Use direct TCP address
Expand All @@ -131,14 +140,14 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
if not host or not port:
raise ValueError(f"Invalid address format: {address}")
self.logger.debug("Using direct TCP connection for SSH - host: %s, port: %s", host, port)
return self._run_ssh_local(host, port, options, args)
return self._run_ssh_local(host, port, options, args, user)
except (DriverMethodNotImplemented, ValueError) as e:
self.logger.error("Direct address connection failed (%s), falling back to SSH port forwarding", e)
return self.run(SSHCommandRunOptions(
direct=False,
capture_output=options.capture_output,
capture_as_text=options.capture_as_text,
), args)
), args, user=user)
else:
# Use SSH port forwarding (default behavior)
self.logger.debug("Using SSH port forwarding for SSH connection")
Expand All @@ -147,9 +156,9 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
) as addr:
host, port = addr
self.logger.debug("SSH port forward established - host: %s, port: %s", host, port)
return self._run_ssh_local(host, port, options, args)
return self._run_ssh_local(host, port, options, args, user)

def _run_ssh_local(self, host, port, options, args):
def _run_ssh_local(self, host, port, options, args, user: str | None = None):
"""Run SSH command with the given host, port, and arguments"""
# Create temporary identity file if needed
ssh_identity = self.identity
Expand All @@ -175,7 +184,7 @@ def _run_ssh_local(self, host, port, options, args):

try:
# Build SSH command arguments
ssh_args = self._build_ssh_command_args(port, identity_file, args)
ssh_args = self._build_ssh_command_args(port, identity_file, args, user)

# Separate SSH options from command arguments
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)
Expand All @@ -194,11 +203,11 @@ def _run_ssh_local(self, host, port, options, args):
except Exception as e:
self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e))

def _build_ssh_command_args(self, port, identity_file, args):
def _build_ssh_command_args(self, port, identity_file, args, user: str | None = None):
"""Build initial SSH command arguments"""
# Split the SSH command into individual arguments
ssh_args = shlex.split(self.command)
default_username = self.username
default_username = user or self.username

# Add identity file if provided
if identity_file:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,61 @@ def test_ssh_command_without_default_username():
assert result.stdout == "some stdout"


def test_ssh_command_with_explicit_user_parameter():
"""Test SSH command execution with the user parameter overriding the default."""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
)

with serve(instance) as client:
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")

# Call run with an explicit user.
result = client.run(SSHCommandRunOptions(direct=False), ["hostname"], user="overrideuser")
assert isinstance(result, SSHCommandRunResult)

# Verify subprocess.run was called.
assert mock_run.called
call_args = mock_run.call_args[0][0]

# Check that the override user is present.
assert "-l" in call_args
assert "overrideuser" in call_args
assert "testuser" not in call_args
assert call_args[call_args.index("-l") + 1] == "overrideuser"

assert "127.0.0.1" in call_args
assert "hostname" in call_args

assert result.return_code == 0
assert result.stdout == "some stdout"


def test_ssh_command_with_explicit_user_parameter_fallback():
"""Test that user parameter is preserved during direct-to-portforward fallback."""
instance = SSHWrapper(
children={"tcp": TcpNetwork(host="127.0.0.1", port=22)},
default_username="testuser",
)

with serve(instance) as client:
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="")

# Mock tcp.address() to fail, triggering fallback
with patch.object(client.tcp, 'address', side_effect=ValueError("Connection failed")):
client.run(SSHCommandRunOptions(direct=True), ["hostname"], user="overrideuser")

# Verify that overrideuser is still used after fallback
call_args = mock_run.call_args[0][0]
assert "-l" in call_args
assert "overrideuser" in call_args
assert "testuser" not in call_args
assert call_args[call_args.index("-l") + 1] == "overrideuser"


def test_ssh_command_with_user_override():
"""Test SSH command execution with -l flag overriding default username"""
instance = SSHWrapper(
Expand Down