diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py index 2a0bbd63c..d4cdee981 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py @@ -62,8 +62,10 @@ def cli(self): help="Run SSH command with arguments", ) @click.option("--direct", is_flag=True, help="Use direct TCP address") + @click.option("-u", "--user", help="Username to use for SSH connection") @click.argument("args", nargs=-1) - def ssh(direct, args): + def ssh(direct, user, args): + """Run SSH command with arguments.""" options = SSHCommandRunOptions( direct=direct, # For the CLI, we never capture output so that interactive shells @@ -71,7 +73,7 @@ def ssh(direct, args): capture_output=False, ) - result = self.run(options, args) + result = self.run(options, args, user=user) self.logger.debug("SSH exit code: %s", result.return_code) if result.stdout: @@ -118,8 +120,15 @@ def username(self) -> str: """Get the default SSH username""" return self.call("get_default_username") - def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: - """Run SSH command with the given parameters and arguments""" + def run(self, options: SSHCommandRunOptions, args, user: str | None = None) -> SSHCommandRunResult: + """ + Run SSH command with the given parameters and arguments + + Args: + options: SSH command run options. + args: Command arguments. + user: Optional username to override the default. + """ # Get SSH command and default username from driver if options.direct: # Use direct TCP address @@ -131,14 +140,14 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: if not host or not port: raise ValueError(f"Invalid address format: {address}") self.logger.debug("Using direct TCP connection for SSH - host: %s, port: %s", host, port) - return self._run_ssh_local(host, port, options, args) + return self._run_ssh_local(host, port, options, args, user) except (DriverMethodNotImplemented, ValueError) as e: self.logger.error("Direct address connection failed (%s), falling back to SSH port forwarding", e) return self.run(SSHCommandRunOptions( direct=False, capture_output=options.capture_output, capture_as_text=options.capture_as_text, - ), args) + ), args, user=user) else: # Use SSH port forwarding (default behavior) self.logger.debug("Using SSH port forwarding for SSH connection") @@ -147,9 +156,9 @@ def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult: ) as addr: host, port = addr self.logger.debug("SSH port forward established - host: %s, port: %s", host, port) - return self._run_ssh_local(host, port, options, args) + return self._run_ssh_local(host, port, options, args, user) - def _run_ssh_local(self, host, port, options, args): + def _run_ssh_local(self, host, port, options, args, user: str | None = None): """Run SSH command with the given host, port, and arguments""" # Create temporary identity file if needed ssh_identity = self.identity @@ -175,7 +184,7 @@ def _run_ssh_local(self, host, port, options, args): try: # Build SSH command arguments - ssh_args = self._build_ssh_command_args(port, identity_file, args) + ssh_args = self._build_ssh_command_args(port, identity_file, args, user) # Separate SSH options from command arguments ssh_options, command_args = self._separate_ssh_options_and_command_args(args) @@ -194,11 +203,11 @@ def _run_ssh_local(self, host, port, options, args): except Exception as e: self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e)) - def _build_ssh_command_args(self, port, identity_file, args): + def _build_ssh_command_args(self, port, identity_file, args, user: str | None = None): """Build initial SSH command arguments""" # Split the SSH command into individual arguments ssh_args = shlex.split(self.command) - default_username = self.username + default_username = user or self.username # Add identity file if provided if identity_file: diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py index 92a540406..c18903015 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver_test.py @@ -105,6 +105,61 @@ def test_ssh_command_without_default_username(): assert result.stdout == "some stdout" +def test_ssh_command_with_explicit_user_parameter(): + """Test SSH command execution with the user parameter overriding the default.""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ) + + with serve(instance) as client: + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") + + # Call run with an explicit user. + result = client.run(SSHCommandRunOptions(direct=False), ["hostname"], user="overrideuser") + assert isinstance(result, SSHCommandRunResult) + + # Verify subprocess.run was called. + assert mock_run.called + call_args = mock_run.call_args[0][0] + + # Check that the override user is present. + assert "-l" in call_args + assert "overrideuser" in call_args + assert "testuser" not in call_args + assert call_args[call_args.index("-l") + 1] == "overrideuser" + + assert "127.0.0.1" in call_args + assert "hostname" in call_args + + assert result.return_code == 0 + assert result.stdout == "some stdout" + + +def test_ssh_command_with_explicit_user_parameter_fallback(): + """Test that user parameter is preserved during direct-to-portforward fallback.""" + instance = SSHWrapper( + children={"tcp": TcpNetwork(host="127.0.0.1", port=22)}, + default_username="testuser", + ) + + with serve(instance) as client: + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="some stdout", stderr="") + + # Mock tcp.address() to fail, triggering fallback + with patch.object(client.tcp, 'address', side_effect=ValueError("Connection failed")): + client.run(SSHCommandRunOptions(direct=True), ["hostname"], user="overrideuser") + + # Verify that overrideuser is still used after fallback + call_args = mock_run.call_args[0][0] + assert "-l" in call_args + assert "overrideuser" in call_args + assert "testuser" not in call_args + assert call_args[call_args.index("-l") + 1] == "overrideuser" + + def test_ssh_command_with_user_override(): """Test SSH command execution with -l flag overriding default username""" instance = SSHWrapper(